From 45dd5977143069b5eaffcaca1aa030d0f527da2e Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 8 Dec 2025 13:46:44 -0500 Subject: [PATCH 001/279] fix: fix broken tool spec with composition keywords (#1301) --- src/strands/tools/registry.py | 5 +- src/strands/tools/tools.py | 10 ++- tests/strands/tools/test_registry.py | 122 +++++++++++++++++++++++++++ tests/strands/tools/test_tools.py | 15 ++++ 4 files changed, 149 insertions(+), 3 deletions(-) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 91f0bf870..15150847d 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -22,7 +22,7 @@ from ..tools.decorator import DecoratedFunctionTool from ..types.tools import AgentTool, ToolSpec from .loader import load_tool_from_string, load_tools_from_module -from .tools import PythonAgentTool, normalize_schema, normalize_tool_spec +from .tools import _COMPOSITION_KEYWORDS, PythonAgentTool, normalize_schema, normalize_tool_spec logger = logging.getLogger(__name__) @@ -604,7 +604,8 @@ def validate_tool_spec(self, tool_spec: ToolSpec) -> None: if "$ref" in prop_def: continue - if "type" not in prop_def: + has_composition = any(kw in prop_def for kw in _COMPOSITION_KEYWORDS) + if "type" not in prop_def and not has_composition: prop_def["type"] = "string" if "description" not in prop_def: prop_def["description"] = f"Property {prop_name}" diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index 48b969bc3..39e2f3723 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -17,6 +17,12 @@ logger = logging.getLogger(__name__) +_COMPOSITION_KEYWORDS = ("anyOf", "oneOf", "allOf", "not") +"""JSON Schema composition keywords that define type constraints. + +Properties with these should not get a default type: "string" added. 
+""" + class InvalidToolUseNameException(Exception): """Exception raised when a tool use has an invalid name.""" @@ -88,7 +94,9 @@ def _normalize_property(prop_name: str, prop_def: Any) -> dict[str, Any]: if "$ref" in normalized_prop: return normalized_prop - normalized_prop.setdefault("type", "string") + has_composition = any(kw in normalized_prop for kw in _COMPOSITION_KEYWORDS) + if not has_composition: + normalized_prop.setdefault("type", "string") normalized_prop.setdefault("description", f"Property {prop_name}") return normalized_prop diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index c700016f6..9ae51dcfe 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -389,3 +389,125 @@ async def track_load_tools(*args, **kwargs): # Verify add_consumer was called with the registry ID mock_provider.add_consumer.assert_called_once_with(registry._registry_id) + + +def test_validate_tool_spec_with_anyof_property(): + """Test that validate_tool_spec does not add type: 'string' to anyOf properties. + + This is important for MCP tools that use anyOf for optional/union types like + Optional[List[str]]. Adding type: 'string' causes models to return string-encoded + JSON instead of proper arrays/objects. 
+ """ + tool_spec = { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "regular_field": {}, # Should get type: "string" + "anyof_field": { + "anyOf": [ + {"type": "array", "items": {"type": "string"}}, + {"type": "null"}, + ] + }, + }, + } + }, + } + + registry = ToolRegistry() + registry.validate_tool_spec(tool_spec) + + props = tool_spec["inputSchema"]["json"]["properties"] + + # Regular field should get default type: "string" + assert props["regular_field"]["type"] == "string" + assert props["regular_field"]["description"] == "Property regular_field" + + # anyOf field should NOT get type: "string" added + assert "type" not in props["anyof_field"], "anyOf property should not have type added" + assert "anyOf" in props["anyof_field"], "anyOf should be preserved" + assert props["anyof_field"]["description"] == "Property anyof_field" + + +def test_validate_tool_spec_with_composition_keywords(): + """Test that validate_tool_spec does not add type: 'string' to composition keyword properties. + + JSON Schema composition keywords (anyOf, oneOf, allOf, not) define type constraints. + Properties using these should not get a default type added. 
+ """ + tool_spec = { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "regular_field": {}, # Should get type: "string" + "oneof_field": { + "oneOf": [ + {"type": "string"}, + {"type": "integer"}, + ] + }, + "allof_field": { + "allOf": [ + {"minimum": 0}, + {"maximum": 100}, + ] + }, + "not_field": {"not": {"type": "null"}}, + }, + } + }, + } + + registry = ToolRegistry() + registry.validate_tool_spec(tool_spec) + + props = tool_spec["inputSchema"]["json"]["properties"] + + # Regular field should get default type: "string" + assert props["regular_field"]["type"] == "string" + + # Composition keyword fields should NOT get type: "string" added + assert "type" not in props["oneof_field"], "oneOf property should not have type added" + assert "oneOf" in props["oneof_field"], "oneOf should be preserved" + + assert "type" not in props["allof_field"], "allOf property should not have type added" + assert "allOf" in props["allof_field"], "allOf should be preserved" + + assert "type" not in props["not_field"], "not property should not have type added" + assert "not" in props["not_field"], "not should be preserved" + + # All should have descriptions + for field in ["oneof_field", "allof_field", "not_field"]: + assert props[field]["description"] == f"Property {field}" + + +def test_validate_tool_spec_with_ref_property(): + """Test that validate_tool_spec does not modify $ref properties.""" + tool_spec = { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "ref_field": {"$ref": "#/$defs/SomeType"}, + }, + } + }, + } + + registry = ToolRegistry() + registry.validate_tool_spec(tool_spec) + + props = tool_spec["inputSchema"]["json"]["properties"] + + # $ref field should not be modified + assert props["ref_field"] == {"$ref": "#/$defs/SomeType"} + assert "type" not in props["ref_field"] + assert "description" not in 
props["ref_field"] diff --git a/tests/strands/tools/test_tools.py b/tests/strands/tools/test_tools.py index 60460f464..e20274523 100644 --- a/tests/strands/tools/test_tools.py +++ b/tests/strands/tools/test_tools.py @@ -509,3 +509,18 @@ async def test_stream(identity_tool, alist): tru_events = await alist(stream) exp_events = [ToolResultEvent(({"tool_use": 1}, 2))] assert tru_events == exp_events + + +def test_normalize_schema_with_anyof(): + """Test that anyOf properties don't get default type.""" + schema = { + "type": "object", + "properties": { + "optional_field": {"anyOf": [{"items": {"type": "string"}, "type": "array"}, {"type": "null"}]}, + "regular_field": {}, + }, + } + normalized = normalize_schema(schema) + + assert "type" not in normalized["properties"]["optional_field"] + assert normalized["properties"]["regular_field"]["type"] == "string" From 6543097096b968730e205d620b94a95ea09f3cca Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 9 Dec 2025 15:02:20 -0500 Subject: [PATCH 002/279] bidi - tests - lint (#1307) --- .../experimental/bidi/agent/__init__.py | 2 +- .../experimental/bidi/agent/test_agent.py | 142 +++++++++--------- .../experimental/bidi/agent/test_loop.py | 14 +- .../experimental/bidi/io/test_audio.py | 3 +- .../strands/experimental/bidi/io/test_text.py | 2 +- .../bidi/models/test_gemini_live.py | 7 +- .../bidi/models/test_nova_sonic.py | 18 +-- .../bidi/models/test_openai_realtime.py | 10 +- tests_integ/bidi/tools/test_direct.py | 5 +- 9 files changed, 100 insertions(+), 103 deletions(-) diff --git a/tests/strands/experimental/bidi/agent/__init__.py b/tests/strands/experimental/bidi/agent/__init__.py index 3359c6565..dd401a83d 100644 --- a/tests/strands/experimental/bidi/agent/__init__.py +++ b/tests/strands/experimental/bidi/agent/__init__.py @@ -1 +1 @@ -"""Bidirectional streaming agent tests.""" \ No newline at end of file +"""Bidirectional streaming agent tests.""" diff --git a/tests/strands/experimental/bidi/agent/test_agent.py 
b/tests/strands/experimental/bidi/agent/test_agent.py index 19d3525d7..7b03ab717 100644 --- a/tests/strands/experimental/bidi/agent/test_agent.py +++ b/tests/strands/experimental/bidi/agent/test_agent.py @@ -1,21 +1,23 @@ """Unit tests for BidiAgent.""" -import unittest.mock import asyncio -import pytest +import unittest.mock from uuid import uuid4 +import pytest + from strands.experimental.bidi.agent.agent import BidiAgent from strands.experimental.bidi.models.nova_sonic import BidiNovaSonicModel from strands.experimental.bidi.types.events import ( - BidiTextInputEvent, BidiAudioInputEvent, BidiAudioStreamEvent, - BidiTranscriptStreamEvent, - BidiConnectionStartEvent, BidiConnectionCloseEvent, + BidiConnectionStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, ) + class MockBidiModel: """Mock bidirectional model for testing.""" @@ -46,14 +48,14 @@ async def receive(self): """Async generator yielding mock events.""" if not self._started: raise RuntimeError("model not started | call start before sending/receiving") - + # Yield connection start event yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) - + # Yield any configured events for event in self._events_to_yield: yield event - + # Yield connection end event yield BidiConnectionCloseEvent(connection_id=self._connection_id, reason="complete") @@ -61,11 +63,13 @@ def set_events(self, events): """Helper to set events this mock model will yield.""" self._events_to_yield = events + @pytest.fixture def mock_model(): """Create a mock BidiModel instance.""" return MockBidiModel() + @pytest.fixture def mock_tool_registry(): """Mock tool registry with some basic tools.""" @@ -73,15 +77,15 @@ def mock_tool_registry(): registry.get_all_tool_specs.return_value = [ { "name": "calculator", - "description": "Perform calculations", - "inputSchema": {"json": {"type": "object", "properties": {}}} + "description": "Perform calculations", + "inputSchema": {"json": {"type": "object", 
"properties": {}}}, } ] registry.get_all_tools_config.return_value = {"calculator": {}} return registry -@pytest.fixture +@pytest.fixture def mock_tool_caller(): """Mock tool caller for testing tool execution.""" caller = unittest.mock.AsyncMock() @@ -94,203 +98,194 @@ def agent(mock_model, mock_tool_registry, mock_tool_caller): """Create a BidiAgent instance for testing.""" with unittest.mock.patch("strands.experimental.bidi.agent.agent.ToolRegistry") as mock_registry_class: mock_registry_class.return_value = mock_tool_registry - + with unittest.mock.patch("strands.experimental.bidi.agent.agent._ToolCaller") as mock_caller_class: mock_caller_class.return_value = mock_tool_caller - + # Don't pass tools to avoid real tool loading agent = BidiAgent(model=mock_model) return agent + def test_bidi_agent_init_with_various_configurations(): """Test agent initialization with various configurations.""" # Test default initialization mock_model = MockBidiModel() agent = BidiAgent(model=mock_model) - + assert agent.model == mock_model assert agent.system_prompt is None assert not agent._started assert agent.model._connection_id is None - + # Test with configuration system_prompt = "You are a helpful assistant." 
- agent_with_config = BidiAgent( - model=mock_model, - system_prompt=system_prompt, - agent_id="test_agent" - ) - + agent_with_config = BidiAgent(model=mock_model, system_prompt=system_prompt, agent_id="test_agent") + assert agent_with_config.system_prompt == system_prompt assert agent_with_config.agent_id == "test_agent" - + # Test with string model ID model_id = "amazon.nova-sonic-v1:0" agent_with_string = BidiAgent(model=model_id) - + assert isinstance(agent_with_string.model, BidiNovaSonicModel) assert agent_with_string.model.model_id == model_id - + # Test model config access config = agent.model.config assert config["audio"]["input_rate"] == 16000 assert config["audio"]["output_rate"] == 24000 assert config["audio"]["channels"] == 1 + @pytest.mark.asyncio async def test_bidi_agent_start_stop_lifecycle(agent): """Test agent start/stop lifecycle and state management.""" # Initial state assert not agent._started assert agent.model._connection_id is None - + # Start agent await agent.start() assert agent._started assert agent.model._connection_id is not None connection_id = agent.model._connection_id - + # Double start should error with pytest.raises(RuntimeError, match="agent already started"): await agent.start() - + # Stop agent await agent.stop() assert not agent._started assert agent.model._connection_id is None - + # Multiple stops should be safe await agent.stop() await agent.stop() - + # Restart should work with new connection ID await agent.start() assert agent._started assert agent.model._connection_id != connection_id + @pytest.mark.asyncio async def test_bidi_agent_send_with_input_types(agent): """Test sending various input types through agent.send().""" await agent.start() - + # Test text input with TypedEvent text_input = BidiTextInputEvent(text="Hello", role="user") await agent.send(text_input) assert len(agent.messages) == 1 assert agent.messages[0]["content"][0]["text"] == "Hello" - + # Test string input (shorthand) await agent.send("World") 
assert len(agent.messages) == 2 assert agent.messages[1]["content"][0]["text"] == "World" - + # Test audio input (doesn't add to messages) audio_input = BidiAudioInputEvent( audio="dGVzdA==", # base64 "test" format="pcm", sample_rate=16000, - channels=1 + channels=1, ) await agent.send(audio_input) assert len(agent.messages) == 2 # Still 2, audio doesn't add - + # Test concurrent sends - sends = [ - agent.send(BidiTextInputEvent(text=f"Message {i}", role="user")) - for i in range(3) - ] + sends = [agent.send(BidiTextInputEvent(text=f"Message {i}", role="user")) for i in range(3)] await asyncio.gather(*sends) assert len(agent.messages) == 5 # 2 + 3 new messages + @pytest.mark.asyncio async def test_bidi_agent_receive_events_from_model(agent): """Test receiving events from model.""" # Configure mock model to yield events events = [ - BidiAudioStreamEvent( - audio="dGVzdA==", - format="pcm", - sample_rate=24000, - channels=1 - ), + BidiAudioStreamEvent(audio="dGVzdA==", format="pcm", sample_rate=24000, channels=1), BidiTranscriptStreamEvent( text="Hello world", role="assistant", is_final=True, delta={"text": "Hello world"}, - current_transcript="Hello world" - ) + current_transcript="Hello world", + ), ] agent.model.set_events(events) - + await agent.start() - + received_events = [] async for event in agent.receive(): received_events.append(event) if len(received_events) >= 4: # Stop after getting expected events break - + # Verify event types and order assert len(received_events) >= 3 assert isinstance(received_events[0], BidiConnectionStartEvent) assert isinstance(received_events[1], BidiAudioStreamEvent) assert isinstance(received_events[2], BidiTranscriptStreamEvent) - + # Test empty events agent.model.set_events([]) await agent.stop() await agent.start() - + empty_events = [] async for event in agent.receive(): empty_events.append(event) if len(empty_events) >= 2: break - + assert len(empty_events) >= 1 assert isinstance(empty_events[0], BidiConnectionStartEvent) 
+ def test_bidi_agent_tool_integration(agent, mock_tool_registry): """Test agent tool integration and properties.""" # Test tool property access - assert hasattr(agent, 'tool') + assert hasattr(agent, "tool") assert agent.tool is not None assert agent.tool == agent._tool_caller - + # Test tool names property - mock_tool_registry.get_all_tools_config.return_value = { - "calculator": {}, - "weather": {} - } - + mock_tool_registry.get_all_tools_config.return_value = {"calculator": {}, "weather": {}} + tool_names = agent.tool_names assert isinstance(tool_names, list) assert len(tool_names) == 2 assert "calculator" in tool_names assert "weather" in tool_names + @pytest.mark.asyncio async def test_bidi_agent_send_receive_error_before_start(agent): """Test error handling in various scenarios.""" # Test send before start with pytest.raises(RuntimeError, match="call start before"): await agent.send(BidiTextInputEvent(text="Hello", role="user")) - + # Test receive before start with pytest.raises(RuntimeError, match="call start before"): - async for event in agent.receive(): + async for _ in agent.receive(): pass - + # Test send after stop await agent.start() await agent.stop() with pytest.raises(RuntimeError, match="call start before"): await agent.send(BidiTextInputEvent(text="Hello", role="user")) - + # Test receive after stop with pytest.raises(RuntimeError, match="call start before"): - async for event in agent.receive(): + async for _ in agent.receive(): pass @@ -301,43 +296,44 @@ async def test_bidi_agent_start_receive_propagates_model_errors(): mock_model = MockBidiModel() mock_model.start = unittest.mock.AsyncMock(side_effect=Exception("Connection failed")) error_agent = BidiAgent(model=mock_model) - + with pytest.raises(Exception, match="Connection failed"): await error_agent.start() - + # Test model receive error mock_model2 = MockBidiModel() agent2 = BidiAgent(model=mock_model2) await agent2.start() - + async def failing_receive(): yield 
BidiConnectionStartEvent(connection_id="test", model="test-model") raise Exception("Receive failed") - + agent2.model.receive = failing_receive with pytest.raises(Exception, match="Receive failed"): - async for event in agent2.receive(): + async for _ in agent2.receive(): pass + @pytest.mark.asyncio async def test_bidi_agent_state_consistency(agent): """Test that agent state remains consistent across operations.""" # Initial state assert not agent._started assert agent.model._connection_id is None - + # Start await agent.start() assert agent._started assert agent.model._connection_id is not None connection_id = agent.model._connection_id - + # Send operations shouldn't change connection state await agent.send(BidiTextInputEvent(text="Hello", role="user")) assert agent._started assert agent.model._connection_id == connection_id - + # Stop await agent.stop() assert not agent._started - assert agent.model._connection_id is None \ No newline at end of file + assert agent.model._connection_id is None diff --git a/tests/strands/experimental/bidi/agent/test_loop.py b/tests/strands/experimental/bidi/agent/test_loop.py index 0ce8d6658..da8578f55 100644 --- a/tests/strands/experimental/bidi/agent/test_loop.py +++ b/tests/strands/experimental/bidi/agent/test_loop.py @@ -5,7 +5,6 @@ from strands import tool from strands.experimental.bidi import BidiAgent -from strands.experimental.bidi.agent.loop import _BidiAgentLoop from strands.experimental.bidi.models import BidiModelTimeoutError from strands.experimental.bidi.types.events import BidiConnectionRestartEvent, BidiTextInputEvent from strands.types._events import ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent @@ -38,19 +37,19 @@ async def test_bidi_agent_loop_receive_restart_connection(loop, agent, agenerato agent.model.receive = unittest.mock.Mock(side_effect=[timeout_error, agenerator([text_event])]) await loop.start() - + tru_events = [] async for event in loop.receive(): tru_events.append(event) if 
len(tru_events) >= 2: break - + exp_events = [ BidiConnectionRestartEvent(timeout_error), text_event, ] assert tru_events == exp_events - + agent.model.stop.assert_called_once() assert agent.model.start.call_count == 2 agent.model.start.assert_called_with( @@ -63,7 +62,6 @@ async def test_bidi_agent_loop_receive_restart_connection(loop, agent, agenerato @pytest.mark.asyncio async def test_bidi_agent_loop_receive_tool_use(loop, agent, agenerator): - tool_use = {"toolUseId": "t1", "name": "time_tool", "input": {}} tool_result = {"toolUseId": "t1", "status": "success", "content": [{"text": "12:00"}]} @@ -71,9 +69,9 @@ async def test_bidi_agent_loop_receive_tool_use(loop, agent, agenerator): tool_result_event = ToolResultEvent(tool_result) agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event])) - + await loop.start() - + tru_events = [] async for event in loop.receive(): tru_events.append(event) @@ -86,7 +84,7 @@ async def test_bidi_agent_loop_receive_tool_use(loop, agent, agenerator): ToolResultMessageEvent({"role": "user", "content": [{"toolResult": tool_result}]}), ] assert tru_events == exp_events - + tru_messages = agent.messages exp_messages = [ {"role": "assistant", "content": [{"toolUse": tool_use}]}, diff --git a/tests/strands/experimental/bidi/io/test_audio.py b/tests/strands/experimental/bidi/io/test_audio.py index 459faa78a..9b502700b 100644 --- a/tests/strands/experimental/bidi/io/test_audio.py +++ b/tests/strands/experimental/bidi/io/test_audio.py @@ -29,7 +29,7 @@ def agent(): "voice": "test-voice", }, } - return mock + return mock @pytest.fixture @@ -49,6 +49,7 @@ def config(): "output_frames_per_buffer": 2048, } + @pytest.fixture def audio_io(py_audio, config): _ = py_audio diff --git a/tests/strands/experimental/bidi/io/test_text.py b/tests/strands/experimental/bidi/io/test_text.py index 5507a8c0f..e21e149bd 100644 --- a/tests/strands/experimental/bidi/io/test_text.py +++ b/tests/strands/experimental/bidi/io/test_text.py @@ 
-42,7 +42,7 @@ async def test_bidi_text_io_input(prompt_session, text_input): (BidiInterruptionEvent(reason="user_speech"), "interrupted"), (BidiTranscriptStreamEvent(text="test text", delta="", is_final=False, role="user"), "Preview: test text"), (BidiTranscriptStreamEvent(text="test text", delta="", is_final=True, role="user"), "test text"), - ] + ], ) @pytest.mark.asyncio async def test_bidi_text_io_output(event, exp_print, text_output, capsys): diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index da516d4a0..79bb29d41 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -13,8 +13,8 @@ import pytest from google.genai import types as genai_types -from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel +from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, @@ -572,7 +572,6 @@ def test_tool_formatting(model, tool_spec): assert formatted_empty == [] - # Tool Result Content Tests @@ -601,7 +600,7 @@ async def test_custom_audio_rates_in_events(mock_genai_client, model_id, api_key assert isinstance(audio_event, BidiAudioStreamEvent) # Should use configured rates, not constants assert audio_event.sample_rate == 48000 # Custom config - assert audio_event.channels == 2 # Custom config + assert audio_event.channels == 2 # Custom config assert audio_event.format == "pcm" await model.stop() @@ -631,7 +630,7 @@ async def test_default_audio_rates_in_events(mock_genai_client, model_id, api_ke assert isinstance(audio_event, BidiAudioStreamEvent) # Should use default rates assert audio_event.sample_rate == 24000 # Default output rate - assert audio_event.channels == 1 # Default channels + assert 
audio_event.channels == 1 # Default channels assert audio_event.format == "pcm" await model.stop() diff --git a/tests/strands/experimental/bidi/models/test_nova_sonic.py b/tests/strands/experimental/bidi/models/test_nova_sonic.py index 04f8043be..933fd2088 100644 --- a/tests/strands/experimental/bidi/models/test_nova_sonic.py +++ b/tests/strands/experimental/bidi/models/test_nova_sonic.py @@ -13,10 +13,10 @@ import pytest_asyncio from aws_sdk_bedrock_runtime.models import ModelTimeoutException, ValidationException +from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.models.nova_sonic import ( BidiNovaSonicModel, ) -from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, @@ -538,12 +538,12 @@ async def test_custom_audio_rates_in_events(model_id, region): audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") nova_event = {"audioOutput": {"content": audio_base64}} result = model._convert_nova_event(nova_event) - + assert result is not None assert isinstance(result, BidiAudioStreamEvent) # Should use configured rates, not constants assert result.sample_rate == 48000 # Custom config - assert result.channels == 2 # Custom config + assert result.channels == 2 # Custom config assert result.format == "pcm" @@ -558,12 +558,12 @@ async def test_default_audio_rates_in_events(model_id, region): audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") nova_event = {"audioOutput": {"content": audio_base64}} result = model._convert_nova_event(nova_event) - + assert result is not None assert isinstance(result, BidiAudioStreamEvent) # Should use default rates assert result.sample_rate == 16000 # Default output rate - assert result.channels == 1 # Default channels + assert result.channels == 1 # Default channels assert result.format == "pcm" @@ -573,9 +573,9 @@ async def 
test_bidi_nova_sonic_model_receive_timeout(nova_model, mock_stream): mock_output = AsyncMock() mock_output.receive.side_effect = ModelTimeoutException("Connection timeout") mock_stream.await_output.return_value = (None, mock_output) - + await nova_model.start() - + with pytest.raises(BidiModelTimeoutError, match=r"Connection timeout"): async for _ in nova_model.receive(): pass @@ -586,9 +586,9 @@ async def test_bidi_nova_sonic_model_receive_timeout_validation(nova_model, mock mock_output = AsyncMock() mock_output.receive.side_effect = ValidationException("InternalErrorCode=531: Request timeout") mock_stream.await_output.return_value = (None, mock_output) - + await nova_model.start() - + with pytest.raises(BidiModelTimeoutError, match=r"InternalErrorCode=531"): async for _ in nova_model.receive(): pass diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py index 5c9c0900d..a973e80aa 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -131,7 +131,9 @@ def test_audio_config_defaults(api_key, model_name): def test_audio_config_partial_override(api_key, model_name): """Test partial audio configuration override.""" provider_config = {"audio": {"output_rate": 48000, "voice": "echo"}} - model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config) + model = BidiOpenAIRealtimeModel( + model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config + ) # Overridden values assert model.config["audio"]["output_rate"] == 48000 @@ -154,7 +156,9 @@ def test_audio_config_full_override(api_key, model_name): "voice": "shimmer", } } - model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config) + model = BidiOpenAIRealtimeModel( + model_id=model_name, 
client_config={"api_key": api_key}, provider_config=provider_config + ) assert model.config["audio"]["input_rate"] == 48000 assert model.config["audio"]["output_rate"] == 48000 @@ -510,7 +514,7 @@ async def test_receive_lifecycle_events(mock_websocket, model): format="pcm", sample_rate=24000, channels=1, - ) + ), ] assert tru_events == exp_events diff --git a/tests_integ/bidi/tools/test_direct.py b/tests_integ/bidi/tools/test_direct.py index 30320e786..1694d64b6 100644 --- a/tests_integ/bidi/tools/test_direct.py +++ b/tests_integ/bidi/tools/test_direct.py @@ -28,15 +28,14 @@ def test_bidi_agent_tool_direct_call(agent): "toolUseId": unittest.mock.ANY, } assert tru_result == exp_result - + tru_messages = agent.messages exp_messages = [ { "content": [ { "text": ( - "agent.tool.weather_tool direct tool call.\n" - 'Input parameters: {"city_name": "new york"}\n' + 'agent.tool.weather_tool direct tool call.\nInput parameters: {"city_name": "new york"}\n' ), }, ], From e692133fe7effdbaf52cbf11ad3255c9cebcd0d3 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 9 Dec 2025 15:02:50 -0500 Subject: [PATCH 003/279] bidi - fix mypy errors (#1308) --- src/strands/tools/_caller.py | 6 ++- src/strands/tools/executors/_executor.py | 52 +++++++++++++---------- src/strands/tools/executors/concurrent.py | 9 ++-- src/strands/tools/executors/sequential.py | 5 +-- 4 files changed, 39 insertions(+), 33 deletions(-) diff --git a/src/strands/tools/_caller.py b/src/strands/tools/_caller.py index 4a74dec18..1e0ca2c8d 100644 --- a/src/strands/tools/_caller.py +++ b/src/strands/tools/_caller.py @@ -106,8 +106,10 @@ async def acall() -> ToolResult: tool_result = run_async(acall) - # Apply conversation management if agent supports it (traditional agents) - if hasattr(self._agent, "conversation_manager"): + # TODO: https://github.com/strands-agents/sdk-python/issues/1311 + from ..agent import Agent + + if isinstance(self._agent, Agent): 
self._agent.conversation_manager.apply_management(self._agent) return tool_result diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index a4f9e7e1f..5d01c5d48 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -49,16 +49,19 @@ async def _invoke_before_tool_call_hook( invocation_state: dict[str, Any], ) -> tuple[BeforeToolCallEvent | BidiBeforeToolCallEvent, list[Interrupt]]: """Invoke the appropriate before tool call hook based on agent type.""" - event_cls = BeforeToolCallEvent if ToolExecutor._is_agent(agent) else BidiBeforeToolCallEvent - return await agent.hooks.invoke_callbacks_async( - event_cls( - agent=agent, - selected_tool=tool_func, - tool_use=tool_use, - invocation_state=invocation_state, - ) + kwargs = { + "selected_tool": tool_func, + "tool_use": tool_use, + "invocation_state": invocation_state, + } + event = ( + BeforeToolCallEvent(agent=cast("Agent", agent), **kwargs) + if ToolExecutor._is_agent(agent) + else BidiBeforeToolCallEvent(agent=cast("BidiAgent", agent), **kwargs) ) + return await agent.hooks.invoke_callbacks_async(event) + @staticmethod async def _invoke_after_tool_call_hook( agent: "Agent | BidiAgent", @@ -70,19 +73,22 @@ async def _invoke_after_tool_call_hook( cancel_message: str | None = None, ) -> tuple[AfterToolCallEvent | BidiAfterToolCallEvent, list[Interrupt]]: """Invoke the appropriate after tool call hook based on agent type.""" - event_cls = AfterToolCallEvent if ToolExecutor._is_agent(agent) else BidiAfterToolCallEvent - return await agent.hooks.invoke_callbacks_async( - event_cls( - agent=agent, - selected_tool=selected_tool, - tool_use=tool_use, - invocation_state=invocation_state, - result=result, - exception=exception, - cancel_message=cancel_message, - ) + kwargs = { + "selected_tool": selected_tool, + "tool_use": tool_use, + "invocation_state": invocation_state, + "result": result, + "exception": exception, + 
"cancel_message": cancel_message, + } + event = ( + AfterToolCallEvent(agent=cast("Agent", agent), **kwargs) + if ToolExecutor._is_agent(agent) + else BidiAfterToolCallEvent(agent=cast("BidiAgent", agent), **kwargs) ) + return await agent.hooks.invoke_callbacks_async(event) + @staticmethod async def _stream( agent: "Agent | BidiAgent", @@ -247,7 +253,7 @@ async def _stream( @staticmethod async def _stream_with_trace( - agent: "Agent | BidiAgent", + agent: "Agent", tool_use: ToolUse, tool_results: list[ToolResult], cycle_trace: Trace, @@ -259,7 +265,7 @@ async def _stream_with_trace( """Execute tool with tracing and metrics collection. Args: - agent: The agent (Agent or BidiAgent) for which the tool is being executed. + agent: The agent for which the tool is being executed. tool_use: Metadata and inputs for the tool to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. @@ -308,7 +314,7 @@ async def _stream_with_trace( # pragma: no cover def _execute( self, - agent: "Agent | BidiAgent", + agent: "Agent", tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, @@ -319,7 +325,7 @@ def _execute( """Execute the given tools according to this executor's strategy. Args: - agent: The agent (Agent or BidiAgent) for which tools are being executed. + agent: The agent for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. 
diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py index da5c1ff10..216eee379 100644 --- a/src/strands/tools/executors/concurrent.py +++ b/src/strands/tools/executors/concurrent.py @@ -12,7 +12,6 @@ if TYPE_CHECKING: # pragma: no cover from ...agent import Agent - from ...experimental.bidi import BidiAgent from ..structured_output._structured_output_context import StructuredOutputContext @@ -22,7 +21,7 @@ class ConcurrentToolExecutor(ToolExecutor): @override async def _execute( self, - agent: "Agent | BidiAgent", + agent: "Agent", tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, @@ -33,7 +32,7 @@ async def _execute( """Execute tools concurrently. Args: - agent: The agent (Agent or BidiAgent) for which tools are being executed. + agent: The agent for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. @@ -79,7 +78,7 @@ async def _execute( async def _task( self, - agent: "Agent | BidiAgent", + agent: "Agent", tool_use: ToolUse, tool_results: list[ToolResult], cycle_trace: Trace, @@ -94,7 +93,7 @@ async def _task( """Execute a single tool and put results in the task queue. Args: - agent: The agent (Agent or BidiAgent) executing the tool. + agent: The agent executing the tool. tool_use: Tool use metadata and inputs. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. 
diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py index 6163fc195..f78e60872 100644 --- a/src/strands/tools/executors/sequential.py +++ b/src/strands/tools/executors/sequential.py @@ -11,7 +11,6 @@ if TYPE_CHECKING: # pragma: no cover from ...agent import Agent - from ...experimental.bidi import BidiAgent from ..structured_output._structured_output_context import StructuredOutputContext @@ -21,7 +20,7 @@ class SequentialToolExecutor(ToolExecutor): @override async def _execute( self, - agent: "Agent | BidiAgent", + agent: "Agent", tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, @@ -34,7 +33,7 @@ async def _execute( Breaks early if an interrupt is raised by the user. Args: - agent: The agent (Agent or BidiAgent) for which tools are being executed. + agent: The agent for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. From 9f70298a3036c0dbfe2b4b3f48aea9a065812a7c Mon Sep 17 00:00:00 2001 From: ratish <114130421+Ratish1@users.noreply.github.com> Date: Wed, 10 Dec 2025 21:41:45 +0400 Subject: [PATCH 004/279] feat(hooks): add AgentResult to AfterInvocationEvent (#1125) The AfterInvocationEvent hook event did not provide access to the AgentResult, making it difficult for hooks to perform actions based on the outcome of an agent's invocation. This PR updates the AfterInvocationEvent to include an optional AgentResult. The Agent.invoke method now captures the AgentResult and passes it to the AfterInvocationEvent. New tests have been added to verify the functionality and immutability of the event. 
--- src/strands/agent/agent.py | 9 +++++++-- src/strands/hooks/events.py | 12 +++++++++++- tests/strands/agent/hooks/test_events.py | 21 +++++++++++++++++++++ tests/strands/agent/test_agent_hooks.py | 12 +++++++----- 4 files changed, 46 insertions(+), 8 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index d6b08eff0..8fc5be6ca 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -56,7 +56,7 @@ from ..tools.registry import ToolRegistry from ..tools.structured_output._structured_output_context import StructuredOutputContext from ..tools.watcher import ToolWatcher -from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent +from ..types._events import AgentResultEvent, EventLoopStopEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent from ..types.agent import AgentInput from ..types.content import ContentBlock, Message, Messages, SystemContentBlock from ..types.exceptions import ContextWindowOverflowException @@ -621,6 +621,7 @@ async def _run_loop( """ await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self)) + agent_result: AgentResult | None = None try: yield InitEventLoopEvent() @@ -648,9 +649,13 @@ async def _run_loop( self._session_manager.redact_latest_message(self.messages[-1], self) yield event + # Capture the result from the final event if available + if isinstance(event, EventLoopStopEvent): + agent_result = AgentResult(*event["stop"]) + finally: self.conversation_manager.apply_management(self) - await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self)) + await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self, result=agent_result)) async def _execute_event_loop_cycle( self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 05be255f6..ebc508f24 100644 --- 
a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -5,10 +5,13 @@ import uuid from dataclasses import dataclass -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from typing_extensions import override +if TYPE_CHECKING: + from ..agent.agent_result import AgentResult + from ..types.content import Message from ..types.interrupt import _Interruptible from ..types.streaming import StopReason @@ -60,8 +63,15 @@ class AfterInvocationEvent(HookEvent): - Agent.__call__ - Agent.stream_async - Agent.structured_output + + Attributes: + result: The result of the agent invocation, if available. + This will be None when invoked from structured_output methods, as those return typed output directly rather + than AgentResult. """ + result: "AgentResult | None" = None + @property def should_reverse_callbacks(self) -> bool: """True to invoke callbacks in reverse order.""" diff --git a/tests/strands/agent/hooks/test_events.py b/tests/strands/agent/hooks/test_events.py index 8bbd89c17..9203478b2 100644 --- a/tests/strands/agent/hooks/test_events.py +++ b/tests/strands/agent/hooks/test_events.py @@ -2,6 +2,7 @@ import pytest +from strands.agent.agent_result import AgentResult from strands.hooks import ( AfterInvocationEvent, AfterToolCallEvent, @@ -10,6 +11,7 @@ BeforeToolCallEvent, MessageAddedEvent, ) +from strands.types.content import Message from strands.types.tools import ToolResult, ToolUse @@ -138,3 +140,22 @@ def test_after_tool_invocation_event_cannot_write_properties(after_tool_event): after_tool_event.invocation_state = {} with pytest.raises(AttributeError, match="Property exception is not writable"): after_tool_event.exception = Exception("test") + + +def test_after_invocation_event_properties_not_writable(agent): + """Test that properties are not writable after initialization.""" + mock_message: Message = {"role": "assistant", "content": [{"text": "test"}]} + mock_result = AgentResult( + stop_reason="end_turn", + 
message=mock_message, + metrics={}, + state={}, + ) + + event = AfterInvocationEvent(agent=agent, result=None) + + with pytest.raises(AttributeError, match="Property result is not writable"): + event.result = mock_result + + with pytest.raises(AttributeError, match="Property agent is not writable"): + event.agent = Mock() diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 32266c3eb..d82329e95 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -147,7 +147,7 @@ def test_agent_tool_call(agent, hook_provider, agent_tool): def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_use): """Verify that the correct hook events are emitted as part of __call__.""" - agent("test message") + result = agent("test message") length, events = hook_provider.get_events() @@ -197,7 +197,7 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3]) - assert next(events) == AfterInvocationEvent(agent=agent) + assert next(events) == AfterInvocationEvent(agent=agent, result=result) assert len(agent.messages) == 4 @@ -210,8 +210,10 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m assert hook_provider.events_received == [BeforeInvocationEvent(agent=agent)] # iterate the rest - async for _ in iterator: - pass + result = None + async for item in iterator: + if "result" in item: + result = item["result"] length, events = hook_provider.get_events() @@ -261,7 +263,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3]) - assert next(events) == AfterInvocationEvent(agent=agent) + assert next(events) == AfterInvocationEvent(agent=agent, result=result) assert len(agent.messages) == 4 From 
a64a8513c61898446bc881ba0fcb05693be6b705 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 10 Dec 2025 15:09:50 -0500 Subject: [PATCH 005/279] feat(docs): Create agent.md and docs folder (#1312) --- AGENTS.md | 477 ++++++++++++++++++ CONTRIBUTING.md | 2 +- docs/HOOKS.md | 24 + .../MCP_CLIENT_ARCHITECTURE.md | 0 docs/README.md | 15 + STYLE_GUIDE.md => docs/STYLE_GUIDE.md | 0 src/strands/hooks/rules.md | 21 - 7 files changed, 517 insertions(+), 22 deletions(-) create mode 100644 AGENTS.md create mode 100644 docs/HOOKS.md rename _MCP_CLIENT_ARCHITECTURE.md => docs/MCP_CLIENT_ARCHITECTURE.md (100%) create mode 100644 docs/README.md rename STYLE_GUIDE.md => docs/STYLE_GUIDE.md (100%) delete mode 100644 src/strands/hooks/rules.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..49ea8a656 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,477 @@ +# AGENTS.md + +This document provides context, patterns, and guidelines for AI coding assistants working in this repository. For human contributors, see [CONTRIBUTING.md](./CONTRIBUTING.md). + +## Product Overview + +Strands Agents is an open-source Python SDK for building AI agents with a model-driven approach. It provides a lightweight, flexible framework that scales from simple conversational assistants to complex autonomous workflows. + +**Core Features:** +- Model Agnostic: Multiple model providers (Amazon Bedrock, Anthropic, OpenAI, Gemini, Ollama, etc.) 
+- Python-Based Tools: Simple `@tool` decorator with hot reloading +- MCP Integration: Native Model Context Protocol support +- Multi-Agent Systems: Agent-to-agent, swarms, and graph patterns +- Streaming Support: Real-time response streaming +- Hooks: Event-driven extensibility for agent lifecycle +- Session Management: Pluggable session managers (file, S3, custom) +- Observability: OpenTelemetry tracing and metrics + +## Directory Structure + +``` +strands-agents/ +│ +├── src/strands/ # Main package source code +│ ├── agent/ # Core agent implementation +│ │ ├── agent.py # Main Agent class +│ │ ├── agent_result.py # Agent execution results +│ │ ├── state.py # Agent state management +│ │ └── conversation_manager/ # Message history strategies +│ │ ├── conversation_manager.py # Base conversation manager +│ │ ├── null_conversation_manager.py # No-op manager +│ │ ├── sliding_window_conversation_manager.py # Window-based +│ │ └── summarizing_conversation_manager.py # Summarization-based +│ │ +│ ├── event_loop/ # Agent execution loop +│ │ ├── event_loop.py # Main loop logic +│ │ ├── streaming.py # Streaming response handling +│ │ └── _recover_message_on_max_tokens_reached.py +│ │ +│ ├── models/ # Model provider implementations +│ │ ├── model.py # Base model interface +│ │ ├── bedrock.py # Amazon Bedrock +│ │ ├── anthropic.py # Anthropic Claude +│ │ ├── openai.py # OpenAI +│ │ ├── gemini.py # Google Gemini +│ │ ├── ollama.py # Ollama local models +│ │ ├── litellm.py # LiteLLM unified interface +│ │ ├── mistral.py # Mistral AI +│ │ ├── llamaapi.py # LlamaAPI +│ │ ├── llamacpp.py # llama.cpp local +│ │ ├── sagemaker.py # AWS SageMaker +│ │ ├── writer.py # Writer AI +│ │ └── _validation.py # Validation utilities +│ │ +│ ├── tools/ # Tool system +│ │ ├── decorator.py # @tool decorator +│ │ ├── tools.py # Tool base classes +│ │ ├── registry.py # Tool registration +│ │ ├── loader.py # Dynamic tool loading +│ │ ├── watcher.py # Hot reload +│ │ ├── _caller.py # Tool invocation +│ 
│ ├── _validator.py # Tool validation +│ │ ├── _tool_helpers.py # Helper utilities +│ │ ├── executors/ # Tool execution environments +│ │ │ ├── _executor.py # Base executor +│ │ │ ├── concurrent.py # Thread/process pool +│ │ │ └── sequential.py # Sequential execution +│ │ ├── mcp/ # Model Context Protocol +│ │ │ ├── mcp_client.py # MCP client implementation +│ │ │ ├── mcp_agent_tool.py # MCP tool wrapper +│ │ │ ├── mcp_types.py # MCP type definitions +│ │ │ └── mcp_instrumentation.py # MCP telemetry +│ │ └── structured_output/ # Structured output handling +│ │ ├── structured_output_tool.py +│ │ ├── structured_output_utils.py +│ │ └── _structured_output_context.py +│ │ +│ ├── multiagent/ # Multi-agent patterns +│ │ ├── base.py # Base multi-agent classes +│ │ ├── graph.py # Graph-based orchestration +│ │ ├── swarm.py # Swarm pattern +│ │ ├── a2a/ # Agent-to-agent protocol +│ │ │ ├── executor.py # A2A executor +│ │ │ └── server.py # A2A server +│ │ └── nodes/ # Graph node implementations +│ │ +│ ├── types/ # Type definitions +│ │ ├── content.py # Content types (text, images, etc.) 
+│ │ ├── tools.py # Tool-related types +│ │ ├── streaming.py # Streaming event types +│ │ ├── exceptions.py # Custom exceptions +│ │ ├── agent.py # Agent types +│ │ ├── session.py # Session types +│ │ ├── multiagent.py # Multi-agent types +│ │ ├── guardrails.py # Guardrail types +│ │ ├── interrupt.py # Interrupt types +│ │ ├── media.py # Media types +│ │ ├── citations.py # Citation types +│ │ ├── traces.py # Trace types +│ │ ├── event_loop.py # Event loop types +│ │ ├── json_dict.py # JSON dict utilities +│ │ ├── collections.py # Collection types +│ │ ├── _events.py # Internal event types +│ │ └── models/ # Model-specific types +│ │ +│ ├── session/ # Session management +│ │ ├── session_manager.py # Base interface +│ │ ├── file_session_manager.py # File-based storage +│ │ ├── s3_session_manager.py # S3 storage +│ │ ├── repository_session_manager.py # Repository pattern +│ │ └── session_repository.py # Storage interface +│ │ +│ ├── telemetry/ # Observability (OpenTelemetry) +│ │ ├── tracer.py # Tracing +│ │ ├── metrics.py # Metrics collection +│ │ ├── metrics_constants.py # Metric definitions +│ │ └── config.py # Configuration +│ │ +│ ├── hooks/ # Event hooks system +│ │ ├── events.py # Hook event definitions +│ │ └── registry.py # Hook registration +│ │ +│ ├── handlers/ # Event handlers +│ │ └── callback_handler.py # Callback handling +│ │ +│ ├── experimental/ # Experimental features (API may change) +│ │ ├── agent_config.py # Experimental agent config +│ │ ├── bidi/ # Bidirectional streaming +│ │ │ ├── agent/ # Bidi agent implementation +│ │ │ ├── io/ # Input/output handling +│ │ │ ├── models/ # Bidi model providers +│ │ │ ├── tools/ # Bidi tools +│ │ │ ├── types/ # Bidi types +│ │ │ └── _async/ # Async utilities +│ │ ├── hooks/ # Experimental hooks +│ │ │ ├── events.py +│ │ │ └── multiagent/ +│ │ ├── steering/ # Agent steering +│ │ │ ├── context_providers/ +│ │ │ ├── core/ +│ │ │ └── handlers/ +│ │ └── tools/ # Experimental tools +│ │ └── tool_provider.py +│ │ +│ 
├── __init__.py # Public API exports +│ ├── interrupt.py # Interrupt handling +│ ├── _async.py # Async utilities +│ ├── _exception_notes.py # Exception helpers +│ ├── _identifier.py # ID generation +│ └── py.typed # PEP 561 marker +│ +├── tests/ # Unit tests (mirrors src/) +│ ├── conftest.py # Pytest fixtures +│ ├── fixtures/ # Test fixtures +│ │ ├── mocked_model_provider.py # Mock model for testing +│ │ ├── mock_agent_tool.py +│ │ ├── mock_hook_provider.py +│ │ └── ... +│ └── strands/ # Tests mirror src/strands/ +│ ├── agent/ +│ ├── event_loop/ +│ ├── models/ +│ ├── tools/ +│ ├── multiagent/ +│ ├── types/ +│ ├── session/ +│ ├── telemetry/ +│ ├── hooks/ +│ ├── handlers/ +│ ├── experimental/ +│ └── utils/ +│ +├── tests_integ/ # Integration tests +│ ├── conftest.py +│ ├── models/ # Model provider tests +│ │ ├── test_model_bedrock.py +│ │ ├── test_model_anthropic.py +│ │ ├── test_model_openai.py +│ │ ├── test_model_gemini.py +│ │ ├── test_model_ollama.py +│ │ └── ... +│ ├── mcp/ # MCP integration tests +│ │ ├── test_mcp_client.py +│ │ ├── echo_server.py +│ │ └── ... +│ ├── tools/ # Tool system tests +│ ├── hooks/ # Hook tests +│ ├── interrupts/ # Interrupt tests +│ ├── steering/ # Steering tests +│ ├── bidi/ # Bidirectional streaming tests +│ ├── test_multiagent_graph.py +│ ├── test_multiagent_swarm.py +│ ├── test_stream_agent.py +│ ├── test_session.py +│ └── ... 
+│ +├── docs/ # Developer documentation +│ ├── README.md # Docs folder overview +│ ├── STYLE_GUIDE.md # Code style conventions +│ └── MCP_CLIENT_ARCHITECTURE.md # MCP threading architecture +│ +├── pyproject.toml # Project config (build, deps, tools) +├── AGENTS.md # This file +└── CONTRIBUTING.md # Human contributor guidelines +``` + +### Directory Purposes + +- **`src/strands/`**: All production code +- **`tests/`**: Unit tests mirroring src/ structure +- **`tests_integ/`**: Integration tests with real model providers +- **`docs/`**: Developer documentation for contributors + +**IMPORTANT**: After making changes that affect the directory structure (adding new directories, moving files, or adding significant new files), you MUST update this directory structure section to reflect the current state of the repository. + +## Development Workflow + +### 1. Environment Setup + +```bash +hatch shell # Enter dev environment +pre-commit install -t pre-commit -t commit-msg # Install hooks +``` + +### 2. Making Changes + +1. Create feature branch +2. Implement changes following the patterns below +3. Run quality checks before committing +4. Commit with conventional commits (`feat:`, `fix:`, `docs:`, `refactor:`, `test:`, `chore:`) +5. Push and open PR + +### 3. Quality Gates + +Pre-commit hooks run automatically on commit: +- Formatting (ruff) +- Linting (ruff + mypy) +- Tests (pytest) +- Commit message validation (commitizen) + +All checks must pass before commit is allowed. 
+ +## Coding Patterns and Best Practices + +### Logging Style + +Use structured logging with field-value pairs followed by human-readable messages: + +```python +logger.debug("field1=<%s>, field2=<%s> | human readable message", field1, field2) +``` + +**Guidelines:** +- Add context as `FIELD=` pairs at the beginning +- Separate pairs with commas +- Enclose values in `<>` for readability (especially for empty values) +- Use `%s` string interpolation (not f-strings) for performance +- Use lowercase messages, no punctuation +- Separate multiple statements with pipe `|` + +**Good:** +```python +logger.debug("user_id=<%s>, action=<%s> | user performed action", user_id, action) +logger.info("request_id=<%s>, duration_ms=<%d> | request completed", request_id, duration) +logger.warning("attempt=<%d>, max_attempts=<%d> | retry limit approaching", attempt, max_attempts) +``` + +**Bad:** +```python +logger.debug(f"User {user_id} performed action {action}") # Don't use f-strings +logger.info("Request completed in %d ms.", duration) # Don't add punctuation +``` + +### Type Annotations + +All code must include type annotations: +- Function parameters and return types required +- No implicit optional types +- Use `typing` or `typing_extensions` for complex types +- Mypy strict mode enforced + +```python +def process_message(content: str, max_tokens: int | None = None) -> AgentResult: + ... +``` + +### Docstrings + +Use Google-style docstrings for all public functions, classes, and modules: + +```python +def example_function(param1: str, param2: int) -> bool: + """Brief description of function. + + Longer description if needed. This docstring is used by LLMs + to understand the function's purpose when used as a tool. + + Args: + param1: Description of param1 + param2: Description of param2 + + Returns: + Description of return value + + Raises: + ValueError: When invalid input is provided + """ + pass +``` + +### Import Organization + +Imports must be at the top of the file. 
+ +Imports are automatically organized by ruff/isort: +1. Standard library imports +2. Third-party imports +3. Local application imports + +Use absolute imports for cross-package references, relative imports within packages. + +```python +# Standard library +import logging +from typing import Any + +# Third-party +import boto3 +from pydantic import BaseModel + +# Local +from strands.agent import Agent +from .tools import Tool +``` + +### File Organization + +- Each major feature in its own directory +- Base classes and interfaces defined first +- Implementation-specific code in separate files +- Private modules prefixed with `_` +- Test files prefixed with `test_` + +### Naming Conventions + +- **Variables/Functions**: `snake_case` +- **Classes**: `PascalCase` +- **Constants**: `UPPER_SNAKE_CASE` +- **Private members**: Prefix with `_` + +### Error Handling + +- Use custom exceptions from `strands.types.exceptions` +- Provide clear error messages with context +- Don't swallow exceptions silently + +## Testing Patterns + +### Unit Tests (`tests/`) + +- Mirror the `src/strands/` structure exactly +- Focus on isolated component testing +- Use mocking for external dependencies (models, AWS services) +- Use fixtures from `tests/fixtures/` (e.g., `mocked_model_provider.py`) + +```python +# tests/strands/agent/test_agent.py mirrors src/strands/agent/agent.py +``` + +### Integration Tests (`tests_integ/`) + +- End-to-end testing with real model providers +- Require credentials/API keys (set via environment variables) +- Organized by feature area + +### Test File Naming + +- Unit tests: `test_{module}.py` in `tests/strands/{path}/` +- Integration tests: `test_{feature}.py` in `tests_integ/` + +### Running Tests + +```bash +hatch test # Run unit tests +hatch test -c # Run with coverage +hatch run test-integ # Run integration tests +hatch test tests/strands/agent/ # Run specific directory +hatch test --all # Test all Python versions (3.10-3.13) +``` + +### Writing Tests + +- 
Use pytest fixtures for setup/teardown +- Use `moto` for mocking AWS services +- Use `pytest.mark.asyncio` for async tests +- Keep tests focused and independent + +## Things to Do + +- Use explicit return types for all functions +- Write Google-style docstrings for public APIs +- Use structured logging format +- Add type annotations everywhere +- Use relative imports within packages +- Mirror src/ structure in tests/ +- Run `hatch fmt --formatter` and `hatch fmt --linter` before committing +- Follow conventional commits (`feat:`, `fix:`, `docs:`, etc.) + +## Things NOT to Do + +- Don't use f-strings in logging calls +- Don't use `Any` type without good reason +- Don't skip type annotations +- Don't put unit tests outside `tests/strands/` structure +- Don't commit without running pre-commit hooks +- Don't add punctuation to log messages +- Don't use implicit optional types + +## Development Commands + +```bash +# Environment +hatch shell # Enter dev environment + +# Formatting & Linting +hatch fmt --formatter # Format code +hatch fmt --linter # Run linters (ruff + mypy) + +# Testing +hatch test # Run unit tests +hatch test -c # Run with coverage +hatch run test-integ # Run integration tests +hatch test --all # Test all Python versions + +# Pre-commit +pre-commit run --all-files # Run all hooks manually + +# Readiness Check +hatch run prepare # Run all checks (format, lint, test) + +# Build +hatch build # Build package +``` + +## Agent-Specific Notes + +### Writing Code + +- Make the SMALLEST reasonable changes to achieve the desired outcome +- Prefer simple, clean, maintainable solutions over clever ones +- Reduce code duplication, even if refactoring takes extra effort +- Match the style and formatting of surrounding code +- Fix broken things immediately when you find them + +### Code Comments + +- Comments should explain WHAT the code does or WHY it exists +- NEVER add comments about what used to be there or how something changed +- NEVER refer to temporal context 
("recently refactored", "moved") +- Keep comments concise and evergreen + +### Code Review Considerations + +- Address all review comments +- Test changes thoroughly +- Update documentation if behavior changes +- Maintain test coverage +- Follow conventional commit format for fix commits + +## Additional Resources + +- [Strands Agents Documentation](https://strandsagents.com/) +- [CONTRIBUTING.md](./CONTRIBUTING.md) - Human contributor guidelines +- [docs/](./docs/) - Developer documentation + - [STYLE_GUIDE.md](./docs/STYLE_GUIDE.md) - Code style conventions + - [HOOKS.md](./docs/HOOKS.md) - Hooks system guide + - [MCP_CLIENT_ARCHITECTURE.md](./docs/MCP_CLIENT_ARCHITECTURE.md) - MCP threading design diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index be83ff85b..0e01fc38d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -122,7 +122,7 @@ hatch fmt --linter If you're using an IDE like VS Code or PyCharm, consider configuring it to use these tools automatically. -For additional details on styling, please see our dedicated [Style Guide](./STYLE_GUIDE.md). +For additional details on styling, please see our dedicated [Style Guide](./docs/STYLE_GUIDE.md). ## Contributing via Pull Requests diff --git a/docs/HOOKS.md b/docs/HOOKS.md new file mode 100644 index 000000000..b447c6400 --- /dev/null +++ b/docs/HOOKS.md @@ -0,0 +1,24 @@ +# Hooks System + +The hooks system enables extensible agent functionality through strongly-typed event callbacks. 
+ +## Terminology + +- **Paired events**: Events that denote the beginning and end of an operation +- **Hook callback**: A function that receives a strongly-typed event argument +- **Hook provider**: An object implementing `HookProvider` that registers callbacks via `register_hooks()` + +## Naming Conventions + +- All hook events have a suffix of `Event` +- Paired events follow `Before{Action}Event` and `After{Action}Event` +- Action words come after the lifecycle indicator (e.g., `BeforeToolCallEvent` not `BeforeToolEvent`) + +## Paired Events + +- For every `Before` event there is a corresponding `After` event, even if an exception occurs +- `After` events invoke callbacks in reverse registration order (for proper cleanup) + +## Writable Properties + +Some events have writable properties that modify agent behavior. Values are re-read after callbacks complete. For example, `BeforeToolCallEvent.selected_tool` is writable - after invoking the callback, the modified `selected_tool` takes effect for the tool call. diff --git a/_MCP_CLIENT_ARCHITECTURE.md b/docs/MCP_CLIENT_ARCHITECTURE.md similarity index 100% rename from _MCP_CLIENT_ARCHITECTURE.md rename to docs/MCP_CLIENT_ARCHITECTURE.md diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 000000000..4ad4ee44f --- /dev/null +++ b/docs/README.md @@ -0,0 +1,15 @@ +# Developer Documentation + +This folder contains documentation for contributors and developers working on the Strands Agents SDK. 
+ +## Contents + +- [STYLE_GUIDE.md](./STYLE_GUIDE.md) - Code style conventions and formatting guidelines +- [HOOKS.md](./HOOKS.md) - Hooks system rules and usage guide +- [MCP_CLIENT_ARCHITECTURE.md](./MCP_CLIENT_ARCHITECTURE.md) - MCP client threading architecture and design decisions + +## Related Documentation + +- [AGENTS.md](../AGENTS.md) - Guidance for AI agents and LLMs working with this codebase +- [CONTRIBUTING.md](../CONTRIBUTING.md) - Contribution guidelines for human contributors +- [strandsagents.com](https://strandsagents.com/) - User-facing documentation diff --git a/STYLE_GUIDE.md b/docs/STYLE_GUIDE.md similarity index 100% rename from STYLE_GUIDE.md rename to docs/STYLE_GUIDE.md diff --git a/src/strands/hooks/rules.md b/src/strands/hooks/rules.md deleted file mode 100644 index 4d0f571c6..000000000 --- a/src/strands/hooks/rules.md +++ /dev/null @@ -1,21 +0,0 @@ -# Hook System Rules - -## Terminology - -- **Paired events**: Events that denote the beginning and end of an operation -- **Hook callback**: A function that receives a strongly-typed event argument and performs some action in response - -## Naming Conventions - -- All hook events have a suffix of `Event` -- Paired events follow the naming convention of `Before{Item}Event` and `After{Item}Event` -- Pre actions in the name. i.e. prefer `BeforeToolCallEvent` over `BeforeToolEvent`. - -## Paired Events - -- The final event in a pair returns `True` for `should_reverse_callbacks` -- For every `Before` event there is a corresponding `After` event, even if an exception occurs - -## Writable Properties - -For events with writable properties, those values are re-read after invoking the hook callbacks and used in subsequent processing. For example, `BeforeToolEvent.selected_tool` is writable - after invoking the callback for `BeforeToolEvent`, the `selected_tool` takes effect for the tool call. 
\ No newline at end of file From 60bd2914781099259562b0a19b45cc552c6fa6ed Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 11 Dec 2025 16:41:51 -0500 Subject: [PATCH 006/279] bidi - remove python 3.11+ features (#1302) --- .../experimental/bidi/_async/__init__.py | 9 +-- .../experimental/bidi/_async/_task_group.py | 61 +++++++++++++++++++ src/strands/experimental/bidi/agent/agent.py | 4 +- .../experimental/bidi/_async/test__init__.py | 12 ++-- .../bidi/_async/test_task_group.py | 59 ++++++++++++++++++ .../bidi/models/test_gemini_live.py | 2 +- .../bidi/models/test_openai_realtime.py | 2 +- 7 files changed, 136 insertions(+), 13 deletions(-) create mode 100644 src/strands/experimental/bidi/_async/_task_group.py create mode 100644 tests/strands/experimental/bidi/_async/test_task_group.py diff --git a/src/strands/experimental/bidi/_async/__init__.py b/src/strands/experimental/bidi/_async/__init__.py index 6cee3264d..47473115c 100644 --- a/src/strands/experimental/bidi/_async/__init__.py +++ b/src/strands/experimental/bidi/_async/__init__.py @@ -2,9 +2,10 @@ from typing import Awaitable, Callable +from ._task_group import _TaskGroup from ._task_pool import _TaskPool -__all__ = ["_TaskPool"] +__all__ = ["_TaskGroup", "_TaskPool"] async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: @@ -16,14 +17,14 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: funcs: Stop functions to call in sequence. Raises: - ExceptionGroup: If any stop function raises an exception. + RuntimeError: If any stop function raises an exception. 
""" exceptions = [] for func in funcs: try: await func() except Exception as exception: - exceptions.append(exception) + exceptions.append({"func_name": func.__name__, "exception": repr(exception)}) if exceptions: - raise ExceptionGroup("failed stop sequence", exceptions) + raise RuntimeError(f"exceptions={exceptions} | failed stop sequence") diff --git a/src/strands/experimental/bidi/_async/_task_group.py b/src/strands/experimental/bidi/_async/_task_group.py new file mode 100644 index 000000000..26c67326d --- /dev/null +++ b/src/strands/experimental/bidi/_async/_task_group.py @@ -0,0 +1,61 @@ +"""Manage a group of async tasks. + +This is intended to mimic the behaviors of asyncio.TaskGroup released in Python 3.11. + +- Docs: https://docs.python.org/3/library/asyncio-task.html#task-groups +""" + +import asyncio +from typing import Any, Coroutine + + +class _TaskGroup: + """Shim of asyncio.TaskGroup for use in Python 3.10. + + Attributes: + _tasks: List of tasks in group. + """ + + _tasks: list[asyncio.Task] + + def create_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task: + """Create an async task and add to group. + + Returns: + The created task. + """ + task = asyncio.create_task(coro) + self._tasks.append(task) + return task + + async def __aenter__(self) -> "_TaskGroup": + """Setup self managed task group context.""" + self._tasks = [] + return self + + async def __aexit__(self, *_: Any) -> None: + """Execute tasks in group. + + The following execution rules are enforced: + - The context stops executing all tasks if at least one task raises an Exception or the context is cancelled. + - The context re-raises Exceptions to the caller. + - The context re-raises CancelledErrors to the caller only if the context itself was cancelled. 
+ """ + try: + await asyncio.gather(*self._tasks) + + except (Exception, asyncio.CancelledError) as error: + for task in self._tasks: + task.cancel() + + await asyncio.gather(*self._tasks, return_exceptions=True) + + if not isinstance(error, asyncio.CancelledError): + raise + + context_task = asyncio.current_task() + if context_task and context_task.cancelling() > 0: # context itself was cancelled + raise + + finally: + self._tasks = [] diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 4012d5e2d..5ddb181ea 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -30,7 +30,7 @@ from ....types.tools import AgentTool from ...hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent from ...tools import ToolProvider -from .._async import stop_all +from .._async import _TaskGroup, stop_all from ..models.model import BidiModel from ..models.nova_sonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput @@ -390,7 +390,7 @@ async def run_outputs(inputs_task: asyncio.Task) -> None: for start in [*input_starts, *output_starts]: await start(self) - async with asyncio.TaskGroup() as task_group: + async with _TaskGroup() as task_group: inputs_task = task_group.create_task(run_inputs()) task_group.create_task(run_outputs(inputs_task)) diff --git a/tests/strands/experimental/bidi/_async/test__init__.py b/tests/strands/experimental/bidi/_async/test__init__.py index f8df25e14..a121ddecc 100644 --- a/tests/strands/experimental/bidi/_async/test__init__.py +++ b/tests/strands/experimental/bidi/_async/test__init__.py @@ -10,17 +10,19 @@ async def test_stop_exception(): func1 = AsyncMock() func2 = AsyncMock(side_effect=ValueError("stop 2 failed")) func3 = AsyncMock() + func4 = AsyncMock(side_effect=ValueError("stop 4 failed")) - with pytest.raises(ExceptionGroup) as exc_info: - await stop_all(func1, func2, func3) + with pytest.raises(Exception, 
match=r"failed stop sequence") as exc_info: + await stop_all(func1, func2, func3, func4) func1.assert_called_once() func2.assert_called_once() func3.assert_called_once() + func4.assert_called_once() - assert len(exc_info.value.exceptions) == 1 - with pytest.raises(ValueError, match=r"stop 2 failed"): - raise exc_info.value.exceptions[0] + tru_message = str(exc_info.value) + assert "ValueError('stop 2 failed')" in tru_message + assert "ValueError('stop 4 failed')" in tru_message @pytest.mark.asyncio diff --git a/tests/strands/experimental/bidi/_async/test_task_group.py b/tests/strands/experimental/bidi/_async/test_task_group.py new file mode 100644 index 000000000..23ff821f9 --- /dev/null +++ b/tests/strands/experimental/bidi/_async/test_task_group.py @@ -0,0 +1,59 @@ +import asyncio +import unittest.mock + +import pytest + +from strands.experimental.bidi._async._task_group import _TaskGroup + + +@pytest.mark.asyncio +async def test_task_group__aexit__(): + coro = unittest.mock.AsyncMock() + + async with _TaskGroup() as task_group: + task_group.create_task(coro()) + + coro.assert_called_once() + + +@pytest.mark.asyncio +async def test_task_group__aexit__exception(): + wait_event = asyncio.Event() + async def wait(): + await wait_event.wait() + + async def fail(): + raise ValueError("test error") + + with pytest.raises(ValueError, match=r"test error"): + async with _TaskGroup() as task_group: + wait_task = task_group.create_task(wait()) + fail_task = task_group.create_task(fail()) + + assert wait_task.cancelled() + assert not fail_task.cancelled() + + +@pytest.mark.asyncio +async def test_task_group__aexit__cancelled(): + wait_event = asyncio.Event() + async def wait(): + await wait_event.wait() + + tasks = [] + + run_event = asyncio.Event() + async def run(): + async with _TaskGroup() as task_group: + tasks.append(task_group.create_task(wait())) + run_event.set() + + run_task = asyncio.create_task(run()) + await run_event.wait() + run_task.cancel() + + with 
pytest.raises(asyncio.CancelledError): + await run_task + + wait_task = tasks[0] + assert wait_task.cancelled() diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index 79bb29d41..3a9d7e3dc 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -185,7 +185,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): model4 = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) await model4.start() mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") - with pytest.raises(ExceptionGroup): + with pytest.raises(Exception, match=r"failed stop sequence"): await model4.stop() diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py index a973e80aa..1cabbc92b 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -353,7 +353,7 @@ async def async_connect(*args, **kwargs): model4 = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) await model4.start() mock_ws.close.side_effect = Exception("Close failed") - with pytest.raises(ExceptionGroup): + with pytest.raises(Exception, match=r"failed stop sequence"): await model4.stop() From 2a02388493fab20dd2ad816a15d815d37f90e1c9 Mon Sep 17 00:00:00 2001 From: David Padbury Date: Mon, 15 Dec 2025 09:52:09 -0600 Subject: [PATCH 007/279] fix(mcp): close mcp client event loop (#1321) --------- Co-authored-by: David Padbury Co-authored-by: Dean Schmigelski --- src/strands/tools/mcp/mcp_client.py | 3 +++ tests/strands/tools/mcp/test_mcp_client.py | 27 ++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 
bb5dca19c..7a26cdd6b 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -330,6 +330,9 @@ async def _set_close_event() -> None: self._log_debug_with_thread("waiting for background thread to join") self._background_thread.join() + if self._background_thread_event_loop is not None: + self._background_thread_event_loop.close() + self._log_debug_with_thread("background thread is closed, MCPClient context exited") # Reset fields to allow instance reuse diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index ec77b48a2..e72aebd92 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -524,6 +524,33 @@ def test_stop_with_background_thread_but_no_event_loop(): assert client._background_thread is None +def test_stop_closes_event_loop(): + """Test that stop() properly closes the event loop when it exists.""" + client = MCPClient(MagicMock()) + + # Mock a background thread with event loop + mock_thread = MagicMock() + mock_thread.join = MagicMock() + mock_event_loop = MagicMock() + mock_event_loop.close = MagicMock() + + client._background_thread = mock_thread + client._background_thread_event_loop = mock_event_loop + + # Should close the event loop and join the thread + client.stop(None, None, None) + + # Verify thread was joined + mock_thread.join.assert_called_once() + + # Verify event loop was closed + mock_event_loop.close.assert_called_once() + + # Verify cleanup occurred + assert client._background_thread is None + assert client._background_thread_event_loop is None + + def test_mcp_client_state_reset_after_timeout(): """Test that all client state is properly reset after timeout.""" From d6284a6030dadf642245171436c91acf358d0d46 Mon Sep 17 00:00:00 2001 From: afarntrog <47332252+afarntrog@users.noreply.github.com> Date: Tue, 16 Dec 2025 16:05:16 -0500 Subject: [PATCH 008/279] Add issue-responder action (#1319) --- 
.github/workflows/issue-responder.yml | 66 +++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 .github/workflows/issue-responder.yml diff --git a/.github/workflows/issue-responder.yml b/.github/workflows/issue-responder.yml new file mode 100644 index 000000000..318b74361 --- /dev/null +++ b/.github/workflows/issue-responder.yml @@ -0,0 +1,66 @@ +name: Issue Responder + +on: + issues: + types: [opened] + +permissions: + id-token: write + contents: read + +jobs: + respond-to-issue: + runs-on: ubuntu-latest + + steps: + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.STRANDS_AGENTCORE_ACTIONS_ROLE }} + aws-region: us-west-2 + - name: Invoke AgentCore with issue details + env: + GH_ISSUE_AGENTCORE_RUNTIME_ARN: ${{ secrets.GH_ISSUE_AGENTCORE_RUNTIME_ARN }} + ISSUE_NUMBER: ${{ github.event.issue.number }} + ISSUE_TITLE: ${{ github.event.issue.title }} + ISSUE_BODY: ${{ github.event.issue.body }} + ISSUE_URL: ${{ github.event.issue.html_url }} + ISSUE_AUTHOR: ${{ github.event.issue.user.login }} + REPO: ${{ github.repository }} + run: | + npm install @aws-sdk/client-bedrock-agentcore + node - <<'JSEOF' + const { BedrockAgentCoreClient, InvokeAgentRuntimeCommand } = require("@aws-sdk/client-bedrock-agentcore"); + + const payload = JSON.stringify({ + source: "github", + action: "issue_opened", + issue: { + number: parseInt(process.env.ISSUE_NUMBER), + title: process.env.ISSUE_TITLE, + body: process.env.ISSUE_BODY, + url: process.env.ISSUE_URL, + author: process.env.ISSUE_AUTHOR, + repo: process.env.REPO + } + }); + + console.log("Invoking AgentCore with payload:"); + console.log(JSON.stringify(JSON.parse(payload), null, 2)); + + const client = new BedrockAgentCoreClient({ region: "us-west-2" }); + + const sessionId = `github-issue-${process.env.ISSUE_NUMBER}-${Date.now()}-${Math.random().toString(36).slice(2)}`; + + const command = new InvokeAgentRuntimeCommand({ + 
agentRuntimeArn: process.env.GH_ISSUE_AGENTCORE_RUNTIME_ARN, + runtimeSessionId: sessionId, + payload: Buffer.from(payload) + }); + + (async () => { + const response = await client.send(command); + const textResponse = await response.response.transformToString(); + console.log("Response:", textResponse); + })(); + JSEOF From 673789781cb9f62800e10edb4804283297802fde Mon Sep 17 00:00:00 2001 From: Sean Nguyen Date: Wed, 17 Dec 2025 10:36:24 -0800 Subject: [PATCH 009/279] feat(a2a): support passing additional keyword arguments to FastAPI and Starlette constructors (#1250) * feat(a2a): allow passing additional keyword arguments to fastapi constructor * feat(a2a): allow passing additional keyword arguments to starlette constructor * update to use accept dictionary instead of kwargs and add tests. --------- Co-authored-by: Aaron Farntrog --- src/strands/multiagent/a2a/server.py | 18 ++++++++++++---- tests/strands/multiagent/a2a/test_server.py | 24 +++++++++++++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/src/strands/multiagent/a2a/server.py b/src/strands/multiagent/a2a/server.py index bbfbc824d..a9093742f 100644 --- a/src/strands/multiagent/a2a/server.py +++ b/src/strands/multiagent/a2a/server.py @@ -176,16 +176,21 @@ def agent_skills(self, skills: list[AgentSkill]) -> None: """ self._agent_skills = skills - def to_starlette_app(self) -> Starlette: + def to_starlette_app(self, *, app_kwargs: dict[str, Any] | None = None) -> Starlette: """Create a Starlette application for serving this agent via HTTP. Automatically handles path-based mounting if a mount path was derived from the http_url parameter. + Args: + app_kwargs: Additional keyword arguments to pass to the Starlette constructor. + Returns: Starlette: A Starlette application configured to serve this agent. 
""" - a2a_app = A2AStarletteApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + a2a_app = A2AStarletteApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build( + **app_kwargs or {} + ) if self.mount_path: # Create parent app and mount the A2A app at the specified path @@ -196,16 +201,21 @@ def to_starlette_app(self) -> Starlette: return a2a_app - def to_fastapi_app(self) -> FastAPI: + def to_fastapi_app(self, *, app_kwargs: dict[str, Any] | None = None) -> FastAPI: """Create a FastAPI application for serving this agent via HTTP. Automatically handles path-based mounting if a mount path was derived from the http_url parameter. + Args: + app_kwargs: Additional keyword arguments to pass to the FastAPI constructor. + Returns: FastAPI: A FastAPI application configured to serve this agent. """ - a2a_app = A2AFastAPIApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build() + a2a_app = A2AFastAPIApplication(agent_card=self.public_agent_card, http_handler=self.request_handler).build( + **app_kwargs or {} + ) if self.mount_path: # Create parent app and mount the A2A app at the specified path diff --git a/tests/strands/multiagent/a2a/test_server.py b/tests/strands/multiagent/a2a/test_server.py index 00dd164b5..647fce230 100644 --- a/tests/strands/multiagent/a2a/test_server.py +++ b/tests/strands/multiagent/a2a/test_server.py @@ -852,3 +852,27 @@ def test_serve_at_root_edge_cases(mock_strands_agent): ) assert server3.mount_path == "" assert server3.http_url == "http://api.example.com/v1/agents/team1/agent1/" + + +def test_to_starlette_app_with_app_kwargs(mock_strands_agent): + """Test that to_starlette_app passes app_kwargs to the Starlette constructor.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, skills=[]) + + app = a2a_agent.to_starlette_app(app_kwargs={"debug": True}) + + assert isinstance(app, 
Starlette) + assert app.debug is True + + +def test_to_fastapi_app_with_app_kwargs(mock_strands_agent): + """Test that to_fastapi_app passes app_kwargs to the FastAPI constructor.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, skills=[]) + + app = a2a_agent.to_fastapi_app(app_kwargs={"title": "Custom Agent Title"}) + + assert isinstance(app, FastAPI) + assert app.title == "Custom Agent Title" From bb46ab7bc07527cc23b7f57289a63c3c87ece685 Mon Sep 17 00:00:00 2001 From: ratish <114130421+Ratish1@users.noreply.github.com> Date: Wed, 17 Dec 2025 23:58:34 +0400 Subject: [PATCH 010/279] feat(tools): add replace method to ToolRegistry (#1182) --- src/strands/tools/registry.py | 26 ++++++++ tests/strands/tools/test_registry.py | 92 ++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 15150847d..2547aabcc 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -279,6 +279,32 @@ def register_tool(self, tool: AgentTool) -> None: list(self.dynamic_tools.keys()), ) + def replace(self, new_tool: AgentTool) -> None: + """Replace an existing tool with a new implementation. + + This performs a swap of the tool implementation in the registry. + The replacement takes effect on the next agent invocation. + + Args: + new_tool: New tool implementation. Its name must match the tool being replaced. + + Raises: + ValueError: If the tool doesn't exist. 
+ """ + tool_name = new_tool.tool_name + + if tool_name not in self.registry: + raise ValueError(f"Cannot replace tool '{tool_name}' - tool does not exist") + + # Update main registry + self.registry[tool_name] = new_tool + + # Update dynamic_tools to match new tool's dynamic status + if new_tool.is_dynamic: + self.dynamic_tools[tool_name] = new_tool + elif tool_name in self.dynamic_tools: + del self.dynamic_tools[tool_name] + def get_tools_dirs(self) -> List[Path]: """Get all tool directory paths. diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index 9ae51dcfe..d44936f3e 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -511,3 +511,95 @@ def test_validate_tool_spec_with_ref_property(): assert props["ref_field"] == {"$ref": "#/$defs/SomeType"} assert "type" not in props["ref_field"] assert "description" not in props["ref_field"] + + +def test_tool_registry_replace_existing_tool(): + """Test replacing an existing tool.""" + old_tool = MagicMock() + old_tool.tool_name = "my_tool" + old_tool.is_dynamic = False + old_tool.supports_hot_reload = False + + new_tool = MagicMock() + new_tool.tool_name = "my_tool" + new_tool.is_dynamic = False + + registry = ToolRegistry() + registry.register_tool(old_tool) + registry.replace(new_tool) + + assert registry.registry["my_tool"] == new_tool + + +def test_tool_registry_replace_nonexistent_tool(): + """Test replacing a tool that doesn't exist raises ValueError.""" + new_tool = MagicMock() + new_tool.tool_name = "my_tool" + + registry = ToolRegistry() + + with pytest.raises(ValueError, match="Cannot replace tool 'my_tool' - tool does not exist"): + registry.replace(new_tool) + + +def test_tool_registry_replace_dynamic_tool(): + """Test replacing a dynamic tool updates both registries.""" + old_tool = MagicMock() + old_tool.tool_name = "dynamic_tool" + old_tool.is_dynamic = True + old_tool.supports_hot_reload = True + + new_tool = MagicMock() + 
new_tool.tool_name = "dynamic_tool" + new_tool.is_dynamic = True + + registry = ToolRegistry() + registry.register_tool(old_tool) + registry.replace(new_tool) + + assert registry.registry["dynamic_tool"] == new_tool + assert registry.dynamic_tools["dynamic_tool"] == new_tool + + +def test_tool_registry_replace_dynamic_with_non_dynamic(): + """Test replacing a dynamic tool with non-dynamic tool removes from dynamic_tools.""" + old_tool = MagicMock() + old_tool.tool_name = "my_tool" + old_tool.is_dynamic = True + old_tool.supports_hot_reload = True + + new_tool = MagicMock() + new_tool.tool_name = "my_tool" + new_tool.is_dynamic = False + + registry = ToolRegistry() + registry.register_tool(old_tool) + + assert "my_tool" in registry.dynamic_tools + + registry.replace(new_tool) + + assert registry.registry["my_tool"] == new_tool + assert "my_tool" not in registry.dynamic_tools + + +def test_tool_registry_replace_non_dynamic_with_dynamic(): + """Test replacing a non-dynamic tool with dynamic tool adds to dynamic_tools.""" + old_tool = MagicMock() + old_tool.tool_name = "my_tool" + old_tool.is_dynamic = False + old_tool.supports_hot_reload = False + + new_tool = MagicMock() + new_tool.tool_name = "my_tool" + new_tool.is_dynamic = True + + registry = ToolRegistry() + registry.register_tool(old_tool) + + assert "my_tool" not in registry.dynamic_tools + + registry.replace(new_tool) + + assert registry.registry["my_tool"] == new_tool + assert registry.dynamic_tools["my_tool"] == new_tool From bd17e9553c778c7e8848aa631836fa943703c679 Mon Sep 17 00:00:00 2001 From: Vamil Gandhi Date: Wed, 17 Dec 2025 15:15:15 -0500 Subject: [PATCH 011/279] feat(mcp): add meta field support to MCP tool results (#1237) * feat: add meta field support to MCP tool results Add support for the _meta field in MCP tool results to enable MCP servers to pass arbitrary metadata alongside tool outputs. This allows tracking of token usage, performance metrics, and other business-specific information. 
--------- Co-authored-by: Vamil Gandhi Co-authored-by: Dean Schmigelski --- src/strands/tools/mcp/mcp_client.py | 2 + src/strands/tools/mcp/mcp_types.py | 8 +- tests/strands/tools/mcp/test_mcp_client.py | 26 ++++- tests_integ/mcp/echo_server.py | 12 ++- ..._client_structured_content_and_metadata.py | 95 +++++++++++++++++++ ...cp_client_structured_content_with_hooks.py | 64 ------------- 6 files changed, 138 insertions(+), 69 deletions(-) create mode 100644 tests_integ/mcp/test_mcp_client_structured_content_and_metadata.py delete mode 100644 tests_integ/mcp/test_mcp_client_structured_content_with_hooks.py diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 7a26cdd6b..6ce591bc5 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -566,6 +566,8 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes if call_tool_result.structuredContent: result["structuredContent"] = call_tool_result.structuredContent + if call_tool_result.meta: + result["metadata"] = call_tool_result.meta return result diff --git a/src/strands/tools/mcp/mcp_types.py b/src/strands/tools/mcp/mcp_types.py index 66eda08ae..8fbf573be 100644 --- a/src/strands/tools/mcp/mcp_types.py +++ b/src/strands/tools/mcp/mcp_types.py @@ -1,7 +1,7 @@ """Type definitions for MCP integration.""" from contextlib import AbstractAsyncContextManager -from typing import Any, Dict +from typing import Any from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp.client.streamable_http import GetSessionIdCallback @@ -58,6 +58,10 @@ class MCPToolResult(ToolResult): structuredContent: Optional JSON object containing structured data returned by the MCP tool. This allows MCP tools to return complex data structures that can be processed programmatically by agents or other tools. + metadata: Optional arbitrary metadata returned by the MCP tool. 
This field allows + MCP servers to attach custom metadata to tool results (e.g., token usage, + performance metrics, or business-specific tracking information). """ - structuredContent: NotRequired[Dict[str, Any]] + structuredContent: NotRequired[dict[str, Any]] + metadata: NotRequired[dict[str, Any]] diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index e72aebd92..f5040de1b 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -533,7 +533,7 @@ def test_stop_closes_event_loop(): mock_thread.join = MagicMock() mock_event_loop = MagicMock() mock_event_loop.close = MagicMock() - + client._background_thread = mock_thread client._background_thread_event_loop = mock_event_loop @@ -542,7 +542,7 @@ def test_stop_closes_event_loop(): # Verify thread was joined mock_thread.join.assert_called_once() - + # Verify event loop was closed mock_event_loop.close.assert_called_once() @@ -750,3 +750,25 @@ async def test_handle_error_message_non_exception(): # This should not raise an exception await client._handle_error_message("normal message") + + +def test_call_tool_sync_with_meta_and_structured_content(mock_transport, mock_session): + """Test that call_tool_sync correctly handles both meta and structuredContent fields.""" + mock_content = MCPTextContent(type="text", text="Test message") + metadata = {"tokenUsage": {"inputTokens": 100, "outputTokens": 50}} + structured_content = {"result": 42, "status": "completed"} + mock_session.call_tool.return_value = MCPCallToolResult( + isError=False, content=[mock_content], _meta=metadata, structuredContent=structured_content + ) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) + + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None) + + assert result["status"] == "success" + assert 
result["toolUseId"] == "test-123" + assert "metadata" in result + assert result["metadata"] == metadata + assert "structuredContent" in result + assert result["structuredContent"] == structured_content diff --git a/tests_integ/mcp/echo_server.py b/tests_integ/mcp/echo_server.py index e15065a4a..a23a87b5c 100644 --- a/tests_integ/mcp/echo_server.py +++ b/tests_integ/mcp/echo_server.py @@ -19,7 +19,7 @@ from typing import Literal from mcp.server import FastMCP -from mcp.types import BlobResourceContents, EmbeddedResource, TextResourceContents +from mcp.types import BlobResourceContents, CallToolResult, EmbeddedResource, TextContent, TextResourceContents from pydantic import BaseModel @@ -50,6 +50,16 @@ def echo(to_echo: str) -> str: def echo_with_structured_content(to_echo: str) -> EchoResponse: return EchoResponse(echoed=to_echo, message_length=len(to_echo)) + @mcp.tool(description="Echos response back with metadata") + def echo_with_metadata(to_echo: str): + """Return structured content and metadata in the tool result.""" + + return CallToolResult( + content=[TextContent(type="text", text=to_echo)], + isError=False, + _meta={"metadata": {"nested": 1}, "shallow": "val"}, + ) + @mcp.tool(description="Get current weather information for a location") def get_weather(location: Literal["New York", "London", "Tokyo"] = "New York"): """Get weather data including forecasts and alerts for the specified location""" diff --git a/tests_integ/mcp/test_mcp_client_structured_content_and_metadata.py b/tests_integ/mcp/test_mcp_client_structured_content_and_metadata.py new file mode 100644 index 000000000..3e6132b38 --- /dev/null +++ b/tests_integ/mcp/test_mcp_client_structured_content_and_metadata.py @@ -0,0 +1,95 @@ +"""Integration test for MCP client structured content and metadata support. + +This test verifies that MCP tools can return structured content and metadata, +and that the MCP client properly handles and exposes these fields in tool results. 
+""" + +import json + +from mcp import StdioServerParameters, stdio_client + +from strands import Agent +from strands.hooks import AfterToolCallEvent, HookProvider, HookRegistry +from strands.tools.mcp.mcp_client import MCPClient + + +class ToolResultCapture(HookProvider): + """Captures tool results for inspection.""" + + def __init__(self): + self.captured_results = {} + + def register_hooks(self, registry: HookRegistry) -> None: + """Register callback for after tool invocation events.""" + registry.add_callback(AfterToolCallEvent, self.on_after_tool_invocation) + + def on_after_tool_invocation(self, event: AfterToolCallEvent) -> None: + """Capture tool results by tool name.""" + tool_name = event.tool_use["name"] + self.captured_results[tool_name] = event.result + + +def test_structured_content(): + """Test that MCP tools can return structured content.""" + # Set up result capture + result_capture = ToolResultCapture() + + # Set up MCP client for echo server + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with stdio_mcp_client: + # Create agent with MCP tools and result capture + agent = Agent(tools=stdio_mcp_client.list_tools_sync(), hooks=[result_capture]) + + # Test structured content functionality + test_data = "STRUCTURED_TEST" + agent(f"Use the echo_with_structured_content tool to echo: {test_data}") + + # Verify result was captured + assert "echo_with_structured_content" in result_capture.captured_results + result = result_capture.captured_results["echo_with_structured_content"] + + # Verify basic result structure + assert result["status"] == "success" + assert len(result["content"]) == 1 + + # Verify structured content is present and correct + assert "structuredContent" in result + assert result["structuredContent"] == {"echoed": test_data, "message_length": 15} + + # Verify text content matches structured content + text_content = 
json.loads(result["content"][0]["text"]) + assert text_content == {"echoed": test_data, "message_length": 15} + + +def test_metadata(): + """Test that MCP tools can return metadata.""" + # Set up result capture + result_capture = ToolResultCapture() + + # Set up MCP client for echo server + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with stdio_mcp_client: + # Create agent with MCP tools and result capture + agent = Agent(tools=stdio_mcp_client.list_tools_sync(), hooks=[result_capture]) + + # Test metadata functionality + test_data = "METADATA_TEST" + agent(f"Use the echo_with_metadata tool to echo: {test_data}") + + # Verify result was captured + assert "echo_with_metadata" in result_capture.captured_results + result = result_capture.captured_results["echo_with_metadata"] + + # Verify basic result structure + assert result["status"] == "success" + + # Verify metadata is present and correct + assert "metadata" in result + expected_metadata = {"metadata": {"nested": 1}, "shallow": "val"} + assert result["metadata"] == expected_metadata diff --git a/tests_integ/mcp/test_mcp_client_structured_content_with_hooks.py b/tests_integ/mcp/test_mcp_client_structured_content_with_hooks.py deleted file mode 100644 index ef4993b05..000000000 --- a/tests_integ/mcp/test_mcp_client_structured_content_with_hooks.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Integration test demonstrating hooks system with MCP client structured content tool. - -This test shows how to use the hooks system to capture and inspect tool invocation -results, specifically testing the echo_with_structured_content tool from echo_server. 
-""" - -import json - -from mcp import StdioServerParameters, stdio_client - -from strands import Agent -from strands.hooks import AfterToolCallEvent, HookProvider, HookRegistry -from strands.tools.mcp.mcp_client import MCPClient - - -class StructuredContentHookProvider(HookProvider): - """Hook provider that captures structured content tool results.""" - - def __init__(self): - self.captured_result = None - - def register_hooks(self, registry: HookRegistry) -> None: - """Register callback for after tool invocation events.""" - registry.add_callback(AfterToolCallEvent, self.on_after_tool_invocation) - - def on_after_tool_invocation(self, event: AfterToolCallEvent) -> None: - """Capture structured content tool results.""" - if event.tool_use["name"] == "echo_with_structured_content": - self.captured_result = event.result - - -def test_mcp_client_hooks_structured_content(): - """Test using hooks to inspect echo_with_structured_content tool result.""" - # Create hook provider to capture tool result - hook_provider = StructuredContentHookProvider() - - # Set up MCP client for echo server - stdio_mcp_client = MCPClient( - lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) - ) - - with stdio_mcp_client: - # Create agent with MCP tools and hook provider - agent = Agent(tools=stdio_mcp_client.list_tools_sync(), hooks=[hook_provider]) - - # Test structured content functionality - test_data = "HOOKS_TEST_DATA" - agent(f"Use the echo_with_structured_content tool to echo: {test_data}") - - # Verify hook captured the tool result - assert hook_provider.captured_result is not None - result = hook_provider.captured_result - - # Verify basic result structure - assert result["status"] == "success" - assert len(result["content"]) == 1 - - # Verify structured content is present and correct - assert "structuredContent" in result - assert result["structuredContent"] == {"echoed": test_data, "message_length": 15} - - # Verify text 
content matches structured content - text_content = json.loads(result["content"][0]["text"]) - assert text_content == {"echoed": test_data, "message_length": 15} From 583b10e48d4bd0b7f83d1d47d214bbbefc70cf80 Mon Sep 17 00:00:00 2001 From: ratish <114130421+Ratish1@users.noreply.github.com> Date: Thu, 18 Dec 2025 19:24:16 +0400 Subject: [PATCH 012/279] style: remove redundant None from dict.get() calls (#956) Removes explicit 'None' default values from '.get()' calls in 'agent.py' and 'sagemaker.py'. The 'dict.get()' method defaults to 'None' already, so this change makes the code more concise and idiomatic without changing functionality. --- src/strands/models/sagemaker.py | 18 +++++++++--------- src/strands/tools/_caller.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 7f8b8ff51..1fe630fdc 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -353,7 +353,7 @@ async def stream( logger.info("choice=<%s>", json.dumps(choice, indent=2)) # Handle text content - if choice["delta"].get("content", None): + if choice["delta"].get("content"): if not text_content_started: yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) text_content_started = True @@ -367,7 +367,7 @@ async def stream( ) # Handle reasoning content - if choice["delta"].get("reasoning_content", None): + if choice["delta"].get("reasoning_content"): if not reasoning_content_started: yield self.format_chunk( {"chunk_type": "content_start", "data_type": "reasoning_content"} @@ -392,7 +392,7 @@ async def stream( finish_reason = choice["finish_reason"] break - if choice.get("usage", None): + if choice.get("usage"): yield self.format_chunk( {"chunk_type": "metadata", "data": UsageMetadata(**choice["usage"])} ) @@ -412,7 +412,7 @@ async def stream( # Handle tool calling logger.info("tool_calls=<%s>", json.dumps(tool_calls, indent=2)) for tool_deltas in 
tool_calls.values(): - if not tool_deltas[0]["function"].get("name", None): + if not tool_deltas[0]["function"].get("name"): raise Exception("The model did not provide a tool name.") yield self.format_chunk( {"chunk_type": "content_start", "data_type": "tool", "data": ToolCall(**tool_deltas[0])} @@ -453,7 +453,7 @@ async def stream( yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) # Handle reasoning content - if message.get("reasoning_content", None): + if message.get("reasoning_content"): yield self.format_chunk({"chunk_type": "content_start", "data_type": "reasoning_content"}) yield self.format_chunk( { @@ -465,7 +465,7 @@ async def stream( yield self.format_chunk({"chunk_type": "content_stop", "data_type": "reasoning_content"}) # Handle the tool calling, if any - if message.get("tool_calls", None) or message_stop_reason == "tool_calls": + if message.get("tool_calls") or message_stop_reason == "tool_calls": if not isinstance(message["tool_calls"], list): message["tool_calls"] = [message["tool_calls"]] for tool_call in message["tool_calls"]: @@ -484,9 +484,9 @@ async def stream( # Message close yield self.format_chunk({"chunk_type": "message_stop", "data": message_stop_reason}) # Handle usage metadata - if final_response_json.get("usage", None): + if final_response_json.get("usage"): yield self.format_chunk( - {"chunk_type": "metadata", "data": UsageMetadata(**final_response_json.get("usage", None))} + {"chunk_type": "metadata", "data": UsageMetadata(**final_response_json.get("usage"))} ) except ( self.client.exceptions.InternalFailure, @@ -556,7 +556,7 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> "thinking": content["reasoningContent"].get("reasoningText", {}).get("text", ""), "type": "thinking", } - elif not content.get("reasoningContent", None): + elif not content.get("reasoningContent"): content.pop("reasoningContent", None) if "video" in content: diff --git a/src/strands/tools/_caller.py 
b/src/strands/tools/_caller.py index 1e0ca2c8d..97485d068 100644 --- a/src/strands/tools/_caller.py +++ b/src/strands/tools/_caller.py @@ -120,7 +120,7 @@ def _find_normalized_tool_name(self, name: str) -> str: """Lookup the tool represented by name, replacing characters with underscores as necessary.""" tool_registry = self._agent.tool_registry.registry - if tool_registry.get(name, None): + if tool_registry.get(name): return name # If the desired name contains underscores, it might be a placeholder for characters that can't be From 82f5bcf24c871b2ce8f626d4c523dc311e7f71d6 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Thu, 18 Dec 2025 12:30:39 -0500 Subject: [PATCH 013/279] chore: Expose Status from .base for easier imports (#1356) Currently if you want to check the Status using the enum, you have to import from strands.multiagent.base which is a little odd, so expose it at that level. Co-authored-by: Mackenzie Zastrow --- src/strands/multiagent/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/strands/multiagent/__init__.py b/src/strands/multiagent/__init__.py index e251e9318..ad99944a8 100644 --- a/src/strands/multiagent/__init__.py +++ b/src/strands/multiagent/__init__.py @@ -8,7 +8,7 @@ standardized communication between agents. 
""" -from .base import MultiAgentBase, MultiAgentResult +from .base import MultiAgentBase, MultiAgentResult, Status from .graph import GraphBuilder, GraphResult from .swarm import Swarm, SwarmResult @@ -17,6 +17,7 @@ "GraphResult", "MultiAgentBase", "MultiAgentResult", + "Status", "Swarm", "SwarmResult", ] From 1792ddb1285d9809321c1a5cbe331fa0fc98240d Mon Sep 17 00:00:00 2001 From: Eric Zhu <73148494+ericfzhu@users.noreply.github.com> Date: Fri, 19 Dec 2025 04:42:04 +1100 Subject: [PATCH 014/279] fix(bedrock): CitationLocation is UnionType, and correctly joining citation chunks when streaming is being used (#1341) --------- Co-authored-by: Eric Zhu Co-authored-by: Dean Schmigelski --- src/strands/event_loop/streaming.py | 7 +- src/strands/models/bedrock.py | 11 +- src/strands/types/_events.py | 2 +- src/strands/types/citations.py | 15 +- tests/strands/event_loop/test_streaming.py | 227 ++++++++++++++++++++- tests/strands/models/test_bedrock.py | 73 +++++++ tests/strands/types/test__events.py | 4 +- tests_integ/models/test_model_bedrock.py | 6 + 8 files changed, 323 insertions(+), 22 deletions(-) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 43836fe34..804f90a1d 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -289,12 +289,13 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: state["current_tool_use"] = {} elif text: - content.append({"text": text}) - state["text"] = "" if citations_content: - citations_block: CitationsContentBlock = {"citations": citations_content} + citations_block: CitationsContentBlock = {"citations": citations_content, "content": [{"text": text}]} content.append({"citationsContent": citations_block}) state["citationsContent"] = [] + else: + content.append({"text": text}) + state["text"] = "" elif reasoning_text: content_block: ContentBlock = { diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 
4a7c81672..08d8f400c 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -500,16 +500,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An for citation in citations["citations"]: filtered_citation: dict[str, Any] = {} if "location" in citation: - location = citation["location"] - filtered_location = {} - # Filter location fields to only include Bedrock-supported ones - if "documentIndex" in location: - filtered_location["documentIndex"] = location["documentIndex"] - if "start" in location: - filtered_location["start"] = location["start"] - if "end" in location: - filtered_location["end"] = location["end"] - filtered_citation["location"] = filtered_location + filtered_citation["location"] = citation["location"] if "sourceContent" in citation: filtered_source_content: list[dict[str, Any]] = [] for source_content in citation["sourceContent"]: diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index c3890f428..d64357cf8 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -161,7 +161,7 @@ class CitationStreamEvent(ModelStreamEvent): def __init__(self, delta: ContentBlockDelta, citation: Citation) -> None: """Initialize with delta and citation content.""" - super().__init__({"callback": {"citation": citation, "delta": delta}}) + super().__init__({"citation": citation, "delta": delta}) class ReasoningTextStreamEvent(ModelStreamEvent): diff --git a/src/strands/types/citations.py b/src/strands/types/citations.py index b0e28f655..41f2fa4e0 100644 --- a/src/strands/types/citations.py +++ b/src/strands/types/citations.py @@ -3,7 +3,7 @@ These types are modeled after the Bedrock API. 
""" -from typing import List, Union +from typing import List, Literal, Union from typing_extensions import TypedDict @@ -77,8 +77,17 @@ class DocumentPageLocation(TypedDict, total=False): end: int -# Union type for citation locations -CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation] +# Tagged union type aliases following the ToolChoice pattern +DocumentCharLocationDict = dict[Literal["documentChar"], DocumentCharLocation] +DocumentPageLocationDict = dict[Literal["documentPage"], DocumentPageLocation] +DocumentChunkLocationDict = dict[Literal["documentChunk"], DocumentChunkLocation] + +# Union type for citation locations - tagged union format matching AWS Bedrock API +CitationLocation = Union[ + DocumentCharLocationDict, + DocumentPageLocationDict, + DocumentChunkLocationDict, +] class CitationSourceContent(TypedDict, total=False): diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 02be400b1..c6e44b78a 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -215,6 +215,59 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) {}, {}, ), + # Citation - New + ( + { + "delta": { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + } + }, + {}, + {}, + { + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ] + }, + { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + }, + ), + # Citation - Existing + ( + { + "delta": { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + } + }, + {}, + { + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ] 
+ }, + { + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"}, + {"location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, "title": "Another Doc"}, + ] + }, + { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + }, + ), # Empty ( {"delta": {}}, @@ -294,14 +347,49 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, event_type, s "redactedContent": b"", }, ), - # Citations + # Text with Citations + ( + { + "content": [], + "current_tool_use": {}, + "text": "This is cited text", + "reasoningText": "", + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ], + "redactedContent": b"", + }, + { + "content": [ + { + "citationsContent": { + "citations": [ + { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + ], + "content": [{"text": "This is cited text"}], + } + } + ], + "current_tool_use": {}, + "text": "", + "reasoningText": "", + "citationsContent": [], + "redactedContent": b"", + }, + ), + # Citations without text (should not create content block) ( { "content": [], "current_tool_use": {}, "text": "", "reasoningText": "", - "citationsContent": [{"citations": [{"text": "test", "source": "test"}]}], + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ], "redactedContent": b"", }, { @@ -309,7 +397,9 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, event_type, s "current_tool_use": {}, "text": "", "reasoningText": "", - "citationsContent": [{"citations": [{"text": "test", "source": "test"}]}], + "citationsContent": [ + {"location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, "title": "Test Doc"} + ], "redactedContent": b"", }, ), @@ -578,6 +668,137 @@ def 
test_extract_usage_metrics_empty_metadata(): }, ], ), + # Message with Citations + ( + [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "This is cited text"}}}, + { + "contentBlockDelta": { + "delta": { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + } + } + }, + { + "contentBlockDelta": { + "delta": { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + } + } + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + { + "metadata": { + "usage": {"inputTokens": 5, "outputTokens": 10, "totalTokens": 15}, + "metrics": {"latencyMs": 100}, + } + }, + ], + [ + {"event": {"messageStart": {"role": "assistant"}}}, + {"event": {"contentBlockStart": {"start": {}}}}, + {"event": {"contentBlockDelta": {"delta": {"text": "This is cited text"}}}}, + {"data": "This is cited text", "delta": {"text": "This is cited text"}}, + { + "event": { + "contentBlockDelta": { + "delta": { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + } + } + } + }, + { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + }, + "delta": { + "citation": { + "location": {"documentChar": {"documentIndex": 0, "start": 10, "end": 20}}, + "title": "Test Doc", + } + }, + }, + { + "event": { + "contentBlockDelta": { + "delta": { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + } + } + } + }, + { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + }, + "delta": { + "citation": { + "location": {"documentPage": {"documentIndex": 1, "start": 5, "end": 6}}, + "title": "Another Doc", + } + }, + }, + 
{"event": {"contentBlockStop": {}}}, + {"event": {"messageStop": {"stopReason": "end_turn"}}}, + { + "event": { + "metadata": { + "usage": {"inputTokens": 5, "outputTokens": 10, "totalTokens": 15}, + "metrics": {"latencyMs": 100}, + } + } + }, + { + "stop": ( + "end_turn", + { + "role": "assistant", + "content": [ + { + "citationsContent": { + "citations": [ + { + "location": { + "documentChar": {"documentIndex": 0, "start": 10, "end": 20} + }, + "title": "Test Doc", + }, + { + "location": { + "documentPage": {"documentIndex": 1, "start": 5, "end": 6} + }, + "title": "Another Doc", + }, + ], + "content": [{"text": "This is cited text"}], + } + } + ], + }, + {"inputTokens": 5, "outputTokens": 10, "totalTokens": 15}, + {"latencyMs": 100}, + ) + }, + ], + ), # Empty Message ( [{}], diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 2809e8a72..5ec5a7072 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2070,3 +2070,76 @@ async def test_stream_backward_compatibility_system_prompt(bedrock_client, model "system": [{"text": system_prompt}], } bedrock_client.converse_stream.assert_called_once_with(**expected_request) + + +@pytest.mark.asyncio +async def test_citations_content_preserves_tagged_union_structure(bedrock_client, model, alist): + """Test that citationsContent preserves AWS Bedrock's required tagged union structure for citation locations. + + This test verifies that when messages contain citationsContent with tagged union CitationLocation objects, + the structure is preserved when sent to AWS Bedrock API. AWS Bedrock expects CitationLocation to be a + tagged union with exactly one wrapper key (documentChar, documentPage, etc.) containing the location fields. 
+ """ + # Mock the Bedrock response + bedrock_client.converse_stream.return_value = {"stream": []} + + # Messages with citationsContent using tagged union CitationLocation structure + messages = [ + {"role": "user", "content": [{"text": "Analyze this document"}]}, + { + "role": "assistant", + "content": [ + { + "citationsContent": { + "citations": [ + { + "location": {"documentChar": {"documentIndex": 0, "start": 150, "end": 300}}, + "sourceContent": [ + {"text": "Employee benefits include health insurance and retirement plans"} + ], + "title": "Benefits Section", + }, + { + "location": {"documentPage": {"documentIndex": 0, "start": 2, "end": 3}}, + "sourceContent": [{"text": "Vacation policy allows 15 days per year"}], + "title": "Vacation Policy", + }, + ], + "content": [{"text": "Based on the document, employees receive comprehensive benefits."}], + } + } + ], + }, + ] + + # Call the public stream method + await alist(model.stream(messages)) + + # Verify the request sent to Bedrock preserves the tagged union structure + bedrock_client.converse_stream.assert_called_once() + call_args = bedrock_client.converse_stream.call_args[1] + + # Extract the citationsContent from the formatted messages + formatted_messages = call_args["messages"] + citations_content = formatted_messages[1]["content"][0]["citationsContent"] + + # Verify the tagged union structure is preserved + expected_citations = [ + { + "location": {"documentChar": {"documentIndex": 0, "start": 150, "end": 300}}, + "sourceContent": [{"text": "Employee benefits include health insurance and retirement plans"}], + "title": "Benefits Section", + }, + { + "location": {"documentPage": {"documentIndex": 0, "start": 2, "end": 3}}, + "sourceContent": [{"text": "Vacation policy allows 15 days per year"}], + "title": "Vacation Policy", + }, + ] + + assert citations_content["citations"] == expected_citations, ( + "Citation location tagged union structure was not preserved. 
" + "AWS Bedrock requires CitationLocation to have exactly one wrapper key " + "(documentChar, documentPage, documentChunk, searchResultLocation, or web) " + "with the location fields nested inside." + ) diff --git a/tests/strands/types/test__events.py b/tests/strands/types/test__events.py index d64cabb83..6163faeb6 100644 --- a/tests/strands/types/test__events.py +++ b/tests/strands/types/test__events.py @@ -195,8 +195,8 @@ def test_initialization(self): delta = Mock(spec=ContentBlockDelta) citation = Mock(spec=Citation) event = CitationStreamEvent(delta, citation) - assert event["callback"]["citation"] == citation - assert event["callback"]["delta"] == delta + assert event["citation"] == citation + assert event["delta"] == delta class TestReasoningTextStreamEvent: diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index 2c2e125ad..b31f23663 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -210,6 +210,9 @@ def test_document_citations(non_streaming_agent, letter_pdf): assert any("citationsContent" in content for content in non_streaming_agent.messages[-1]["content"]) + # Validate message structure is valid in multi-turn + non_streaming_agent("What is your favorite part?") + def test_document_citations_streaming(streaming_agent, letter_pdf): content: list[ContentBlock] = [ @@ -228,6 +231,9 @@ def test_document_citations_streaming(streaming_agent, letter_pdf): assert any("citationsContent" in content for content in streaming_agent.messages[-1]["content"]) + # Validate message structure is valid in multi-turn + streaming_agent("What is your favorite part?") + def test_structured_output_multi_modal_input(streaming_agent, yellow_img, yellow_color): content = [ From 4342fda23a76bdc20db80eefd65f04e5c60c2357 Mon Sep 17 00:00:00 2001 From: rajib Date: Thu, 18 Dec 2025 10:00:34 -0800 Subject: [PATCH 015/279] fix(telemetry): prevent double counting of usage metrics (#1327) * 
fix(telemetry): remove duplicate accumulation of usage metrics to prevent double counting * add langfuse observation type to invoke_agent to prevent double counting --------- Co-authored-by: RAJIB DEB Co-authored-by: poshinchen --- src/strands/telemetry/tracer.py | 4 ++++ tests/strands/telemetry/test_tracer.py | 30 ++++++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 2f42d9988..d16b37fc8 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -667,6 +667,10 @@ def end_agent_span( ) if hasattr(response, "metrics") and hasattr(response.metrics, "accumulated_usage"): + if "langfuse" in os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "") or "langfuse" in os.getenv( + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "" + ): + attributes.update({"langfuse.observation.type": "span"}) accumulated_usage = response.metrics.accumulated_usage attributes.update( { diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 205748956..cb98b8130 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -794,6 +794,36 @@ def test_end_agent_span(mock_span): mock_span.end.assert_called_once() +def test_end_agent_span_with_langfuse_observation_type(mock_span, monkeypatch): + """Test ending an agent span with Langfuse observation type to prevent double counting the tokens.""" + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "https://us.cloud.langfuse.com") + tracer = Tracer() + + # Mock AgentResult with metrics + mock_metrics = mock.MagicMock() + mock_metrics.accumulated_usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150} + + mock_response = mock.MagicMock() + mock_response.metrics = mock_metrics + mock_response.stop_reason = "end_turn" + mock_response.__str__ = mock.MagicMock(return_value="Agent response") + + tracer.end_agent_span(mock_span, mock_response) + 
mock_span.set_attribute.assert_any_call("langfuse.observation.type", "span") + mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 50) + mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 50) + mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 100) + mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150) + mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 0) + mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 0) + mock_span.add_event.assert_any_call( + "gen_ai.choice", + attributes={"message": "Agent response", "finish_reason": "end_turn"}, + ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() + + def test_end_agent_span_latest_conventions(mock_span, monkeypatch): """Test ending an agent span with the latest semantic conventions.""" monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") From 3d03a3555f18c9df1d343c268652c27a4a0fa11d Mon Sep 17 00:00:00 2001 From: Danilo Poccia Date: Fri, 19 Dec 2025 17:37:25 +0000 Subject: [PATCH 016/279] feat: add support for web and search result citations (#1344) Co-authored-by: Dean Schmigelski --- src/strands/types/citations.py | 39 +++++++++++++++++ tests/strands/models/test_bedrock.py | 63 +++++++++++++++++++++++++--- 2 files changed, 97 insertions(+), 5 deletions(-) diff --git a/src/strands/types/citations.py b/src/strands/types/citations.py index 41f2fa4e0..623f6ddc7 100644 --- a/src/strands/types/citations.py +++ b/src/strands/types/citations.py @@ -77,16 +77,55 @@ class DocumentPageLocation(TypedDict, total=False): end: int +class SearchResultLocation(TypedDict, total=False): + """Specifies a search result location within the content array. + + Provides positioning information for cited content using search result + index and block positions. 
+ + Attributes: + searchResultIndex: The index of the search result content block where + the cited content is found. Minimum value of 0. + start: The starting position in the content array where the cited + content begins. Minimum value of 0. + end: The ending position in the content array where the cited + content ends. Minimum value of 0. + """ + + searchResultIndex: int + start: int + end: int + + +class WebLocation(TypedDict, total=False): + """Provides the URL and domain information for a cited website. + + Contains information about the website that was cited when performing + a web search. + + Attributes: + url: The URL that was cited when performing a web search. + domain: The domain that was cited when performing a web search. + """ + + url: str + domain: str + + # Tagged union type aliases following the ToolChoice pattern DocumentCharLocationDict = dict[Literal["documentChar"], DocumentCharLocation] DocumentPageLocationDict = dict[Literal["documentPage"], DocumentPageLocation] DocumentChunkLocationDict = dict[Literal["documentChunk"], DocumentChunkLocation] +SearchResultLocationDict = dict[Literal["searchResultLocation"], SearchResultLocation] +WebLocationDict = dict[Literal["web"], WebLocation] # Union type for citation locations - tagged union format matching AWS Bedrock API CitationLocation = Union[ DocumentCharLocationDict, DocumentPageLocationDict, DocumentChunkLocationDict, + SearchResultLocationDict, + WebLocationDict, ] diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 5ec5a7072..33be44b1b 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2078,14 +2078,15 @@ async def test_citations_content_preserves_tagged_union_structure(bedrock_client This test verifies that when messages contain citationsContent with tagged union CitationLocation objects, the structure is preserved when sent to AWS Bedrock API. 
AWS Bedrock expects CitationLocation to be a - tagged union with exactly one wrapper key (documentChar, documentPage, etc.) containing the location fields. + tagged union with exactly one wrapper key (documentChar, documentPage, documentChunk, searchResultLocation, web) + containing the location fields. """ # Mock the Bedrock response bedrock_client.converse_stream.return_value = {"stream": []} - # Messages with citationsContent using tagged union CitationLocation structure + # Messages with citationsContent using all tagged union CitationLocation types messages = [ - {"role": "user", "content": [{"text": "Analyze this document"}]}, + {"role": "user", "content": [{"text": "Analyze multiple sources"}]}, { "role": "assistant", "content": [ @@ -2104,8 +2105,34 @@ async def test_citations_content_preserves_tagged_union_structure(bedrock_client "sourceContent": [{"text": "Vacation policy allows 15 days per year"}], "title": "Vacation Policy", }, + { + "location": {"documentChunk": {"documentIndex": 1, "start": 5, "end": 8}}, + "sourceContent": [{"text": "Company culture emphasizes work-life balance"}], + "title": "Culture Section", + }, + { + "location": { + "searchResultLocation": { + "searchResultIndex": 0, + "start": 25, + "end": 150, + } + }, + "sourceContent": [{"text": "Search results show industry best practices"}], + "title": "Search Results", + }, + { + "location": { + "web": { + "url": "https://example.com/hr-policies", + "domain": "example.com", + } + }, + "sourceContent": [{"text": "External HR policy guidelines"}], + "title": "External Reference", + }, ], - "content": [{"text": "Based on the document, employees receive comprehensive benefits."}], + "content": [{"text": "Based on multiple sources, the company offers comprehensive benefits."}], } } ], @@ -2123,7 +2150,7 @@ async def test_citations_content_preserves_tagged_union_structure(bedrock_client formatted_messages = call_args["messages"] citations_content = 
formatted_messages[1]["content"][0]["citationsContent"] - # Verify the tagged union structure is preserved + # Verify the tagged union structure is preserved for all location types expected_citations = [ { "location": {"documentChar": {"documentIndex": 0, "start": 150, "end": 300}}, @@ -2135,6 +2162,32 @@ async def test_citations_content_preserves_tagged_union_structure(bedrock_client "sourceContent": [{"text": "Vacation policy allows 15 days per year"}], "title": "Vacation Policy", }, + { + "location": {"documentChunk": {"documentIndex": 1, "start": 5, "end": 8}}, + "sourceContent": [{"text": "Company culture emphasizes work-life balance"}], + "title": "Culture Section", + }, + { + "location": { + "searchResultLocation": { + "searchResultIndex": 0, + "start": 25, + "end": 150, + } + }, + "sourceContent": [{"text": "Search results show industry best practices"}], + "title": "Search Results", + }, + { + "location": { + "web": { + "url": "https://example.com/hr-policies", + "domain": "example.com", + } + }, + "sourceContent": [{"text": "External HR policy guidelines"}], + "title": "External Reference", + }, ] assert citations_content["citations"] == expected_citations, ( From 3cb39a64ab80c9751dfc922b31e145f93846d46e Mon Sep 17 00:00:00 2001 From: pshiko Date: Sat, 20 Dec 2025 03:04:49 +0900 Subject: [PATCH 017/279] feat: add gemini_tools field to GeminiModel with validation and tests (#1050) Add support for Gemini-specific tools like GoogleSearch and CodeExecution, with validation to prevent FunctionDeclarations. Due to the fundamental differences in how Gemini's built-in tools operate (server-side execution without explicit tool call/result blocks), we don't implement history tracking for gemini_tools - that would require additional design work and a longer discussion on how to normalize this across all model providers. 
--------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Mackenzie Zastrow --- .../steering/handlers/__init__.py | 4 +- src/strands/models/gemini.py | 43 +++++++++- tests/strands/models/test_gemini.py | 83 +++++++++++++++++++ tests_integ/models/test_model_gemini.py | 27 ++++++ 4 files changed, 155 insertions(+), 2 deletions(-) diff --git a/src/strands/experimental/steering/handlers/__init__.py b/src/strands/experimental/steering/handlers/__init__.py index ca529530f..542126ab5 100644 --- a/src/strands/experimental/steering/handlers/__init__.py +++ b/src/strands/experimental/steering/handlers/__init__.py @@ -1,3 +1,5 @@ """Steering handler implementations.""" -__all__ = [] +from typing import Sequence + +__all__: Sequence[str] = [] diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index c24d91a0d..22feecf32 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -40,10 +40,16 @@ class GeminiConfig(TypedDict, total=False): params: Additional model parameters (e.g., temperature). For a complete list of supported parameters, see https://ai.google.dev/api/generate-content#generationconfig. + gemini_tools: Gemini-specific tools that are not FunctionDeclarations + (e.g., GoogleSearch, CodeExecution, ComputerUse, UrlContext, FileSearch). + Use the standard tools interface for function calling tools. 
+ For a complete list of supported tools, see + https://ai.google.dev/api/caching#Tool """ model_id: Required[str] params: dict[str, Any] + gemini_tools: list[genai.types.Tool] def __init__( self, @@ -61,6 +67,10 @@ def __init__( validate_config_keys(model_config, GeminiModel.GeminiConfig) self.config = GeminiModel.GeminiConfig(**model_config) + # Validate gemini_tools if provided + if "gemini_tools" in self.config: + self._validate_gemini_tools(self.config["gemini_tools"]) + logger.debug("config=<%s> | initializing", self.config) self.client_args = client_args or {} @@ -72,6 +82,10 @@ def update_config(self, **model_config: Unpack[GeminiConfig]) -> None: # type: Args: **model_config: Configuration overrides. """ + # Validate gemini_tools if provided + if "gemini_tools" in model_config: + self._validate_gemini_tools(model_config["gemini_tools"]) + self.config.update(model_config) @override @@ -181,7 +195,7 @@ def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[ge Return: Gemini tool list. """ - return [ + tools = [ genai.types.Tool( function_declarations=[ genai.types.FunctionDeclaration( @@ -193,6 +207,9 @@ def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[ge ], ), ] + if self.config.get("gemini_tools"): + tools.extend(self.config["gemini_tools"]) + return tools def _format_request_config( self, @@ -451,3 +468,27 @@ async def structured_output( client = genai.Client(**self.client_args).aio response = await client.models.generate_content(**request) yield {"output": output_model.model_validate(response.parsed)} + + @staticmethod + def _validate_gemini_tools(gemini_tools: list[genai.types.Tool]) -> None: + """Validate that gemini_tools does not contain FunctionDeclarations. + + Gemini-specific tools should only include tools that cannot be represented + as FunctionDeclarations (e.g., GoogleSearch, CodeExecution, ComputerUse). + Standard function calling tools should use the tools interface instead. 
+ + Args: + gemini_tools: List of Gemini tools to validate + + Raises: + ValueError: If any tool contains function_declarations + """ + for tool in gemini_tools: + # Check if the tool has function_declarations attribute and it's not empty + if hasattr(tool, "function_declarations") and tool.function_declarations: + raise ValueError( + "gemini_tools should not contain FunctionDeclarations. " + "Use the standard tools interface for function calling tools. " + "gemini_tools is reserved for Gemini-specific tools like " + "GoogleSearch, CodeExecution, ComputerUse, UrlContext, and FileSearch." + ) diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index a8f5351cc..8e8742f94 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -624,6 +624,89 @@ async def test_structured_output(gemini_client, model, messages, model_id, weath gemini_client.aio.models.generate_content.assert_called_with(**exp_request) +def test_gemini_tools_validation_rejects_function_declarations(model_id): + tool_with_function_declarations = genai.types.Tool( + function_declarations=[ + genai.types.FunctionDeclaration( + name="test_function", + description="A test function", + ) + ] + ) + + with pytest.raises(ValueError, match="gemini_tools should not contain FunctionDeclarations"): + GeminiModel(model_id=model_id, gemini_tools=[tool_with_function_declarations]) + + +def test_gemini_tools_validation_allows_non_function_tools(model_id): + tool_with_google_search = genai.types.Tool(google_search=genai.types.GoogleSearch()) + + model = GeminiModel(model_id=model_id, gemini_tools=[tool_with_google_search]) + assert "gemini_tools" in model.config + + +def test_gemini_tools_validation_on_update_config(model): + tool_with_function_declarations = genai.types.Tool( + function_declarations=[ + genai.types.FunctionDeclaration( + name="test_function", + description="A test function", + ) + ] + ) + + with pytest.raises(ValueError, 
match="gemini_tools should not contain FunctionDeclarations"): + model.update_config(gemini_tools=[tool_with_function_declarations]) + + +@pytest.mark.asyncio +async def test_stream_request_with_gemini_tools(gemini_client, messages, model_id): + google_search_tool = genai.types.Tool(google_search=genai.types.GoogleSearch()) + model = GeminiModel(model_id=model_id, gemini_tools=[google_search_tool]) + + await anext(model.stream(messages)) + + exp_request = { + "config": { + "tools": [ + {"function_declarations": []}, + {"google_search": {}}, + ] + }, + "contents": [{"parts": [{"text": "test"}], "role": "user"}], + "model": model_id, + } + gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request) + + +@pytest.mark.asyncio +async def test_stream_request_with_gemini_tools_and_function_tools(gemini_client, messages, tool_spec, model_id): + code_execution_tool = genai.types.Tool(code_execution=genai.types.ToolCodeExecution()) + model = GeminiModel(model_id=model_id, gemini_tools=[code_execution_tool]) + + await anext(model.stream(messages, tool_specs=[tool_spec])) + + exp_request = { + "config": { + "tools": [ + { + "function_declarations": [ + { + "description": tool_spec["description"], + "name": tool_spec["name"], + "parameters_json_schema": tool_spec["inputSchema"]["json"], + } + ] + }, + {"code_execution": {}}, + ] + }, + "contents": [{"parts": [{"text": "test"}], "role": "user"}], + "model": model_id, + } + gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request) + + @pytest.mark.asyncio async def test_stream_handles_non_json_error(gemini_client, model, messages, caplog, alist): error_message = "Invalid API key" diff --git a/tests_integ/models/test_model_gemini.py b/tests_integ/models/test_model_gemini.py index f9da8490c..5643d159e 100644 --- a/tests_integ/models/test_model_gemini.py +++ b/tests_integ/models/test_model_gemini.py @@ -2,6 +2,7 @@ import pydantic import pytest +from google import genai import strands 
from strands import Agent @@ -21,6 +22,16 @@ def model(): ) +@pytest.fixture +def gemini_tool_model(): + return GeminiModel( + client_args={"api_key": os.getenv("GOOGLE_API_KEY")}, + model_id="gemini-2.5-flash", + params={"temperature": 0.15}, # Lower temperature for consistent test behavior + gemini_tools=[genai.types.Tool(code_execution=genai.types.ToolCodeExecution())], + ) + + @pytest.fixture def tools(): @strands.tool @@ -175,3 +186,19 @@ def test_agent_structured_output_image_input(assistant_agent, yellow_img, yellow tru_color = assistant_agent.structured_output(type(yellow_color), content) exp_color = yellow_color assert tru_color == exp_color + + +def test_agent_with_gemini_code_execution_tool(gemini_tool_model): + system_prompt = "Generate and run code for all calculations" + agent = Agent(model=gemini_tool_model, system_prompt=system_prompt) + # sample prompt taken from https://ai.google.dev/gemini-api/docs/code-execution + result_turn1 = agent( + "What is the sum of the first 50 prime numbers? Generate and run code for the calculation, " + "and make sure you get all 50." + ) + + # NOTE: We don't verify tool history because built-in tools are currently represented in message history + assert "5117" in str(result_turn1) + + result_turn2 = agent("Summarize that into a single number") + assert "5117" in str(result_turn2) From 894ba80c71335b8823b77a6911faa93b92dfbc93 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Fri, 19 Dec 2025 13:57:18 -0500 Subject: [PATCH 018/279] Port PR guidelines from sdk-typescript (#1373) ## Description Port PR guidelines from sdk-typescript so that agents & contributors can use them. Updated the PR.md with appropriate python examples and referenced the doc from other files. 
Related files from sdk-typescript: - https://github.com/strands-agents/sdk-typescript/blob/main/docs/PR.md - https://github.com/strands-agents/sdk-typescript/blob/main/AGENTS.md Co-authored-by: Mackenzie Zastrow --- AGENTS.md | 16 +++- CONTRIBUTING.md | 2 + docs/PR.md | 201 ++++++++++++++++++++++++++++++++++++++++++++++++ docs/README.md | 1 + 4 files changed, 219 insertions(+), 1 deletion(-) create mode 100644 docs/PR.md diff --git a/AGENTS.md b/AGENTS.md index 49ea8a656..8b4394cc5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -197,6 +197,8 @@ strands-agents/ ├── docs/ # Developer documentation │ ├── README.md # Docs folder overview │ ├── STYLE_GUIDE.md # Code style conventions +│ ├── HOOKS.md # Hooks system guide +│ ├── PR.md # PR description guidelines │ └── MCP_CLIENT_ARCHITECTURE.md # MCP threading architecture │ ├── pyproject.toml # Project config (build, deps, tools) @@ -230,7 +232,18 @@ pre-commit install -t pre-commit -t commit-msg # Install hooks 4. Commit with conventional commits (`feat:`, `fix:`, `docs:`, `refactor:`, `test:`, `chore:`) 5. Push and open PR -### 3. Quality Gates +### 3. Pull Request Guidelines + +When creating pull requests, you MUST follow the guidelines in PR.md. Key principles: + +Focus on WHY: Explain motivation and user impact, not implementation details +Document public API changes: Show before/after code examples +Be concise: Use prose over bullet lists; avoid exhaustive checklists +Target senior engineers: Assume familiarity with the SDK +Exclude implementation details: Leave these to code comments and diffs +See PR.md for the complete guidance and template. + +### 4. 
Quality Gates Pre-commit hooks run automatically on commit: - Formatting (ruff) @@ -474,4 +487,5 @@ hatch build # Build package - [docs/](./docs/) - Developer documentation - [STYLE_GUIDE.md](./docs/STYLE_GUIDE.md) - Code style conventions - [HOOKS.md](./docs/HOOKS.md) - Hooks system guide + - [PR.md](./docs/PR.md) - PR description guidelines - [MCP_CLIENT_ARCHITECTURE.md](./docs/MCP_CLIENT_ARCHITECTURE.md) - MCP threading design diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0e01fc38d..86691a2d7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -132,6 +132,8 @@ Contributions via pull requests are much appreciated. Before sending us a pull r 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. +For guidance on writing effective PR descriptions, see our [PR Description Guidelines](./docs/PR.md). + To send us a pull request, please: 1. Create a branch. diff --git a/docs/PR.md b/docs/PR.md new file mode 100644 index 000000000..b4778f2b1 --- /dev/null +++ b/docs/PR.md @@ -0,0 +1,201 @@ +# Pull Request Description Guidelines + +Good PR descriptions help reviewers understand the context and impact of your changes. They enable faster reviews, better decision-making, and serve as valuable historical documentation. + +When creating a PR, follow the [GitHub PR template](../.github/PULL_REQUEST_TEMPLATE.md) and use these guidelines to fill it out effectively. + +## Who's Reading Your PR? + +Write for senior engineers familiar with the SDK. Assume your reader: + +- Understands the SDK's architecture and patterns +- Has context about the broader system +- Can read code diffs to understand implementation details +- Values concise, focused communication + +## What to Include + +Every PR description should have: + +1. **Motivation** — Why is this change needed? +2. 
**Public API Changes** — What changes to the public API (with code snippets)? +3. **Use Cases** (optional) — When would developers use this feature? Only include for non-obvious functionality; skip for trivial changes or obvious fixes. +4. **Breaking Changes** (if applicable) — What breaks and how to migrate? + +## Writing Principles + +**Focus on WHY, not HOW:** + +- ✅ "Hook providers need access to the agent's result to perform post-invocation actions like logging or analytics" +- ❌ "Added result field to AfterInvocationEvent dataclass" + +**Document public API changes with example code snippets:** + +- ✅ Show before/after code snippets for API changes +- ❌ List every file or line changed + +**Be concise:** + +- ✅ Use prose over bullet lists when possible +- ❌ Create exhaustive implementation checklists + +**Emphasize user impact:** + +- ✅ "Enables hooks to log conversation outcomes or trigger follow-up actions based on the result" +- ❌ "Updated AfterInvocationEvent to include optional AgentResult field" + +## What to Skip + +Leave these out of your PR description: + +- **Implementation details** — Code comments and commit messages cover this +- **Test coverage notes** — CI will catch issues; assume tests are comprehensive +- **Line-by-line change lists** — The diff provides this +- **Build/lint/coverage status** — CI handles verification +- **Commit hashes** — GitHub links commits automatically + +## Anti-patterns + +❌ **Over-detailed checklists:** + +```markdown +### Type Definition Updates + +- Added result field to AfterInvocationEvent dataclass +- Updated Agent._run_loop to capture and pass AgentResult +``` + +❌ **Implementation notes reviewers don't need:** + +```markdown +## Implementation Notes + +- Result field defaults to None +- AgentResult is captured from EventLoopStopEvent before invoking hooks +``` + +❌ **Test coverage bullets:** + +```markdown +### Test Coverage + +- Added test: AfterInvocationEvent includes AgentResult +- Added test: result is 
None when structured_output is used +``` + +## Good Examples + +✅ **Motivation section:** + +```markdown +## Motivation + +Hook providers often need to perform actions based on the outcome of an agent's +invocation, such as logging results, updating metrics, or triggering follow-up +workflows. Currently, the `AfterInvocationEvent` doesn't provide access to the +`AgentResult`, forcing hook implementations to track state externally or miss +this information entirely. +``` + +✅ **Public API Changes section:** + +````markdown +## Public API Changes + +`AfterInvocationEvent` now includes an optional `result` attribute containing +the `AgentResult`: + +```python +# Before: no access to result +class MyHook(HookProvider): + def on_after_invocation(self, event: AfterInvocationEvent) -> None: + # Could only access event.agent, no result available + logger.info("Invocation completed") + +# After: result available for inspection +class MyHook(HookProvider): + def on_after_invocation(self, event: AfterInvocationEvent) -> None: + if event.result: + logger.info(f"Completed with stop_reason: {event.result.stop_reason}") +``` + +The `result` field is `None` when invoked from `structured_output` methods. + +```` + +✅ **Use Cases section:** + +```markdown +## Use Cases + +- **Result logging**: Log conversation outcomes including stop reasons and token usage +- **Analytics**: Track agent performance metrics based on invocation results +- **Conditional workflows**: Trigger follow-up actions based on how the agent completed +```` + +## Template + +````markdown +## Motivation + +[Explain WHY this change is needed. What problem does it solve? What limitation +does it address? What user need does it fulfill?] + +Resolves: #[issue-number] + +## Public API Changes + +[Document changes to public APIs with before/after code snippets. 
If no public +API changes, state "No public API changes."] + +```python +# Before +[existing API usage] + +# After +[new API usage] +``` + +[Explain behavior, parameters, return values, and backward compatibility.] + +## Use Cases (optional) + +[Only include for non-obvious functionality. Provide 1-3 concrete use cases +showing when developers would use this feature. Skip for trivial changes or obvious fixes.] + +## Breaking Changes (if applicable) + +[If this is a breaking change, explain what breaks and provide migration guidance.] + +### Migration + +```python +# Before +[old code] + +# After +[new code] +``` + +```` + +## Why These Guidelines? + +**Focus on WHY over HOW** because code diffs show implementation details, commit messages document granular changes, and PR descriptions provide the broader context reviewers need. + +**Skip test/lint/coverage details** because CI pipelines verify these automatically. Including them adds noise without value. + +**Write for senior engineers** to enable concise, technical communication without redundant explanations. + +## References + +- [Conventional Commits](https://www.conventionalcommits.org/) +- [Google's Code Review Guidelines](https://google.github.io/eng-practices/review/) + +## Checklist Items + + - [ ] Does the PR description target a Senior Engineer familiar with the project? + - [ ] Does the PR description give an overview of the feature being implemented, including any notes on key implementation decisions? + - [ ] Does the PR include a "Resolves #" in the body and is not bolded? + - [ ] Does the PR contain the motivation or use-cases behind the change? + - [ ] Does the PR omit irrelevant details not needed for historical reference? 
\ No newline at end of file diff --git a/docs/README.md b/docs/README.md index 4ad4ee44f..857edc4c4 100644 --- a/docs/README.md +++ b/docs/README.md @@ -6,6 +6,7 @@ This folder contains documentation for contributors and developers working on th - [STYLE_GUIDE.md](./STYLE_GUIDE.md) - Code style conventions and formatting guidelines - [HOOKS.md](./HOOKS.md) - Hooks system rules and usage guide +- [PR.md](./PR.md) - Pull request description guidelines - [MCP_CLIENT_ARCHITECTURE.md](./MCP_CLIENT_ARCHITECTURE.md) - MCP client threading architecture and design decisions ## Related Documentation From 0c640e838e4f5ffc31d07a1f961305a7ab38d8b5 Mon Sep 17 00:00:00 2001 From: poshinchen Date: Sun, 21 Dec 2025 11:56:17 -0500 Subject: [PATCH 019/279] feat: allow custom-client for OpenAIModel and GeminiModel (#1366) --- src/strands/models/gemini.py | 44 ++++++++++++++-- src/strands/models/openai.py | 62 ++++++++++++++++++++-- tests/strands/models/test_gemini.py | 74 +++++++++++++++++++++++++++ tests/strands/models/test_openai.py | 79 ++++++++++++++++++++++++++++- 4 files changed, 249 insertions(+), 10 deletions(-) diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 22feecf32..cf7cc604a 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -54,27 +54,44 @@ class GeminiConfig(TypedDict, total=False): def __init__( self, *, + client: Optional[genai.Client] = None, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[GeminiConfig], ) -> None: """Initialize provider instance. Args: + client: Pre-configured Gemini client to reuse across requests. + When provided, this client will be reused for all requests and will NOT be closed + by the model. The caller is responsible for managing the client lifecycle. 
+ This is useful for: + - Injecting custom client wrappers + - Reusing connection pools within a single event loop/worker + - Centralizing observability, retries, and networking policy + Note: The client should not be shared across different asyncio event loops. client_args: Arguments for the underlying Gemini client (e.g., api_key). For a complete list of supported arguments, see https://googleapis.github.io/python-genai/. **model_config: Configuration options for the Gemini model. + + Raises: + ValueError: If both `client` and `client_args` are provided. """ validate_config_keys(model_config, GeminiModel.GeminiConfig) self.config = GeminiModel.GeminiConfig(**model_config) + # Validate that only one client configuration method is provided + if client is not None and client_args is not None and len(client_args) > 0: + raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.") + + self._custom_client = client + self.client_args = client_args or {} + # Validate gemini_tools if provided if "gemini_tools" in self.config: self._validate_gemini_tools(self.config["gemini_tools"]) logger.debug("config=<%s> | initializing", self.config) - self.client_args = client_args or {} - @override def update_config(self, **model_config: Unpack[GeminiConfig]) -> None: # type: ignore[override] """Update the Gemini model configuration with the provided arguments. @@ -97,6 +114,24 @@ def get_config(self) -> GeminiConfig: """ return self.config + def _get_client(self) -> genai.Client: + """Get a Gemini client for making requests. + + This method handles client lifecycle management: + - If an injected client was provided during initialization, it returns that client + without managing its lifecycle (caller is responsible for cleanup). + - Otherwise, creates a new genai.Client from client_args. + + Returns: + genai.Client: A Gemini client instance. 
+ """ + if self._custom_client is not None: + # Use the injected client (caller manages lifecycle) + return self._custom_client + else: + # Create a new client from client_args + return genai.Client(**self.client_args) + def _format_request_content_part(self, content: ContentBlock) -> genai.types.Part: """Format content block into a Gemini part instance. @@ -382,7 +417,8 @@ async def stream( """ request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params")) - client = genai.Client(**self.client_args).aio + client = self._get_client().aio + try: response = await client.models.generate_content_stream(**request) @@ -465,7 +501,7 @@ async def structured_output( "response_schema": output_model.model_json_schema(), } request = self._format_request(prompt, None, system_prompt, params) - client = genai.Client(**self.client_args).aio + client = self._get_client().aio response = await client.models.generate_content(**request) yield {"output": output_model.model_validate(response.parsed)} diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 435c82cab..07246c5d6 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -7,7 +7,8 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast +from contextlib import asynccontextmanager +from typing import Any, AsyncGenerator, AsyncIterator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast import openai from openai.types.chat.parsed_chat_completion import ParsedChatCompletion @@ -55,16 +56,39 @@ class OpenAIConfig(TypedDict, total=False): model_id: str params: Optional[dict[str, Any]] - def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None: + def __init__( + self, + client: Optional[Client] = None, + client_args: Optional[dict[str, Any]] = None, + **model_config: Unpack[OpenAIConfig], + ) -> None: 
"""Initialize provider instance. Args: - client_args: Arguments for the OpenAI client. + client: Pre-configured OpenAI-compatible client to reuse across requests. + When provided, this client will be reused for all requests and will NOT be closed + by the model. The caller is responsible for managing the client lifecycle. + This is useful for: + - Injecting custom client wrappers (e.g., GuardrailsAsyncOpenAI) + - Reusing connection pools within a single event loop/worker + - Centralizing observability, retries, and networking policy + - Pointing to custom model gateways + Note: The client should not be shared across different asyncio event loops. + client_args: Arguments for the OpenAI client (legacy approach). For a complete list of supported arguments, see https://pypi.org/project/openai/. **model_config: Configuration options for the OpenAI model. + + Raises: + ValueError: If both `client` and `client_args` are provided. """ validate_config_keys(model_config, self.OpenAIConfig) self.config = dict(model_config) + + # Validate that only one client configuration method is provided + if client is not None and client_args is not None and len(client_args) > 0: + raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.") + + self._custom_client = client self.client_args = client_args or {} logger.debug("config=<%s> | initializing", self.config) @@ -422,6 +446,34 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: case _: raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") + @asynccontextmanager + async def _get_client(self) -> AsyncIterator[Any]: + """Get an OpenAI client for making requests. + + This context manager handles client lifecycle management: + - If an injected client was provided during initialization, it yields that client + without closing it (caller manages lifecycle). 
+ - Otherwise, creates a new AsyncOpenAI client from client_args and automatically + closes it when the context exits. + + Note: We create a new client per request to avoid connection sharing in the underlying + httpx client, as the asyncio event loop does not allow connections to be shared. + For more details, see https://github.com/encode/httpx/discussions/2959. + + Yields: + Client: An OpenAI-compatible client instance. + """ + if self._custom_client is not None: + # Use the injected client (caller manages lifecycle) + yield self._custom_client + else: + # Create a new client from client_args + # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying + # httpx client. The asyncio event loop does not allow connections to be shared. For more details, please + # refer to https://github.com/encode/httpx/discussions/2959. + async with openai.AsyncOpenAI(**self.client_args) as client: + yield client + @override async def stream( self, @@ -457,7 +509,7 @@ async def stream( # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx # client. The asyncio event loop does not allow connections to be shared. For more details, please refer to # https://github.com/encode/httpx/discussions/2959. - async with openai.AsyncOpenAI(**self.client_args) as client: + async with self._get_client() as client: try: response = await client.chat.completions.create(**request) except openai.BadRequestError as e: @@ -576,7 +628,7 @@ async def structured_output( # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx # client. The asyncio event loop does not allow connections to be shared. For more details, please refer to # https://github.com/encode/httpx/discussions/2959. 
- async with openai.AsyncOpenAI(**self.client_args) as client: + async with self._get_client() as client: try: response: ParsedChatCompletion = await client.beta.chat.completions.parse( model=self.get_config()["model_id"], diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index 8e8742f94..c552a892a 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -720,3 +720,77 @@ async def test_stream_handles_non_json_error(gemini_client, model, messages, cap assert "Gemini API returned non-JSON error" in caplog.text assert f"error_message=<{error_message}>" in caplog.text + + +@pytest.mark.asyncio +async def test_stream_with_injected_client(model_id, agenerator, alist): + """Test that stream works with an injected client and doesn't close it.""" + # Create a mock injected client + mock_injected_client = unittest.mock.Mock() + mock_injected_client.aio = unittest.mock.AsyncMock() + + mock_injected_client.aio.models.generate_content_stream.return_value = agenerator( + [ + genai.types.GenerateContentResponse( + candidates=[ + genai.types.Candidate( + content=genai.types.Content( + parts=[genai.types.Part(text="Hello")], + ), + finish_reason="STOP", + ), + ], + usage_metadata=genai.types.GenerateContentResponseUsageMetadata( + prompt_token_count=1, + total_token_count=3, + ), + ), + ] + ) + + # Create model with injected client + model = GeminiModel(client=mock_injected_client, model_id=model_id) + + messages = [{"role": "user", "content": [{"text": "test"}]}] + response = model.stream(messages) + tru_events = await alist(response) + + # Verify events were generated + assert len(tru_events) > 0 + + # Verify the injected client was used + mock_injected_client.aio.models.generate_content_stream.assert_called_once() + + +@pytest.mark.asyncio +async def test_structured_output_with_injected_client(model_id, weather_output, alist): + """Test that structured_output works with an injected client and doesn't 
close it.""" + # Create a mock injected client + mock_injected_client = unittest.mock.Mock() + mock_injected_client.aio = unittest.mock.AsyncMock() + + mock_injected_client.aio.models.generate_content.return_value = unittest.mock.Mock( + parsed=weather_output.model_dump() + ) + + # Create model with injected client + model = GeminiModel(client=mock_injected_client, model_id=model_id) + + messages = [{"role": "user", "content": [{"text": "Generate weather"}]}] + stream = model.structured_output(type(weather_output), messages) + events = await alist(stream) + + # Verify output was generated + assert len(events) == 1 + assert events[0] == {"output": weather_output} + + # Verify the injected client was used + mock_injected_client.aio.models.generate_content.assert_called_once() + + +def test_init_with_both_client_and_client_args_raises_error(): + """Test that providing both client and client_args raises ValueError.""" + mock_client = unittest.mock.Mock() + + with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"): + GeminiModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model") diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 0de0c4ebc..ef173d349 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -13,7 +13,10 @@ def openai_client(): with unittest.mock.patch.object(strands.models.openai.openai, "AsyncOpenAI") as mock_client_cls: mock_client = unittest.mock.AsyncMock() - mock_client_cls.return_value.__aenter__.return_value = mock_client + # Make the mock client work as an async context manager + mock_client.__aenter__ = unittest.mock.AsyncMock(return_value=mock_client) + mock_client.__aexit__ = unittest.mock.AsyncMock(return_value=None) + mock_client_cls.return_value = mock_client yield mock_client @@ -986,3 +989,77 @@ def test_format_request_messages_drops_cache_points(): ] assert result == expected + + 
+@pytest.mark.asyncio +async def test_stream_with_injected_client(model_id, agenerator, alist): + """Test that stream works with an injected client and doesn't close it.""" + # Create a mock injected client + mock_injected_client = unittest.mock.AsyncMock() + mock_injected_client.close = unittest.mock.AsyncMock() + + mock_delta = unittest.mock.Mock(content="Hello", tool_calls=None, reasoning_content=None) + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + mock_event_3 = unittest.mock.Mock() + + mock_injected_client.chat.completions.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_event_1, mock_event_2, mock_event_3]) + ) + + # Create model with injected client + model = OpenAIModel(client=mock_injected_client, model_id=model_id, params={"max_tokens": 1}) + + messages = [{"role": "user", "content": [{"text": "test"}]}] + response = model.stream(messages) + tru_events = await alist(response) + + # Verify events were generated + assert len(tru_events) > 0 + + # Verify the injected client was used + mock_injected_client.chat.completions.create.assert_called_once() + + # Verify the injected client was NOT closed + mock_injected_client.close.assert_not_called() + + +@pytest.mark.asyncio +async def test_structured_output_with_injected_client(model_id, test_output_model_cls, alist): + """Test that structured_output works with an injected client and doesn't close it.""" + # Create a mock injected client + mock_injected_client = unittest.mock.AsyncMock() + mock_injected_client.close = unittest.mock.AsyncMock() + + mock_parsed_instance = test_output_model_cls(name="John", age=30) + mock_choice = unittest.mock.Mock() + mock_choice.message.parsed = mock_parsed_instance + mock_response = unittest.mock.Mock() + mock_response.choices = [mock_choice] + + mock_injected_client.beta.chat.completions.parse = 
unittest.mock.AsyncMock(return_value=mock_response) + + # Create model with injected client + model = OpenAIModel(client=mock_injected_client, model_id=model_id, params={"max_tokens": 1}) + + messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + stream = model.structured_output(test_output_model_cls, messages) + events = await alist(stream) + + # Verify output was generated + assert len(events) == 1 + assert events[0] == {"output": test_output_model_cls(name="John", age=30)} + + # Verify the injected client was used + mock_injected_client.beta.chat.completions.parse.assert_called_once() + + # Verify the injected client was NOT closed + mock_injected_client.close.assert_not_called() + + +def test_init_with_both_client_and_client_args_raises_error(): + """Test that providing both client and client_args raises ValueError.""" + mock_client = unittest.mock.AsyncMock() + + with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"): + OpenAIModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model") From 1907a16281d613e31f3f6a5e3cbbd5ba50c6e024 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Mon, 22 Dec 2025 15:06:26 -0500 Subject: [PATCH 020/279] fix: Pass CODECOV_TOKENS through for code-coverage stats (#1385) We need to plumb this through so that protected branches (like main) correctly report code coverage to codecov. Prior to this PRs did not have incremental diffs because main builds didn't have this token; only `main` needs the token because it's a protected branch [0]. This token is already a secret in this repository. 
[0]: https://docs.codecov.com/docs/codecov-tokens#when-do-i-need-a-token Co-authored-by: Mackenzie Zastrow --- .github/workflows/pr-and-push.yml | 2 ++ .github/workflows/pypi-publish-on-release.yml | 2 ++ .github/workflows/test-lint.yml | 3 +++ 3 files changed, 7 insertions(+) diff --git a/.github/workflows/pr-and-push.yml b/.github/workflows/pr-and-push.yml index b558943dd..dcebbd7da 100644 --- a/.github/workflows/pr-and-push.yml +++ b/.github/workflows/pr-and-push.yml @@ -17,3 +17,5 @@ jobs: contents: read with: ref: ${{ github.event.pull_request.head.sha }} + secrets: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index ff19e46b1..2d748c8c6 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -12,6 +12,8 @@ jobs: contents: read with: ref: ${{ github.event.release.target_commitish }} + secrets: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} build: name: Build distribution 📦 diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml index 4986acf1f..549d0b21d 100644 --- a/.github/workflows/test-lint.yml +++ b/.github/workflows/test-lint.yml @@ -6,6 +6,9 @@ on: ref: required: true type: string + secrets: + CODECOV_TOKEN: + required: false jobs: unit-test: From 138f5abcfdcb74d6d0e402ca1dd7ceff9bdc79fb Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Dec 2025 15:52:15 -0500 Subject: [PATCH 021/279] ci: bump actions/checkout from 5 to 6 (#1222) Bumps [actions/checkout](https://github.com/actions/checkout) from 5 to 6. 
- [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v5...v6) --- updated-dependencies: - dependency-name: actions/checkout dependency-version: '6' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/integration-test.yml | 2 +- .github/workflows/pypi-publish-on-release.yml | 2 +- .github/workflows/test-lint.yml | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 7496e45ef..397f4300d 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -52,7 +52,7 @@ jobs: aws-region: us-east-1 mask-aws-account-id: true - name: Checkout head commit - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: ref: ${{ github.event.pull_request.head.sha }} # Pull the commit from the forked repo persist-credentials: false # Don't persist credentials for subsequent actions diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index 2d748c8c6..bad8da5af 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -24,7 +24,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: persist-credentials: false diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml index 549d0b21d..8f393d5de 100644 --- a/.github/workflows/test-lint.yml +++ b/.github/workflows/test-lint.yml @@ -54,7 +54,7 @@ jobs: LOG_LEVEL: DEBUG steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: ref: ${{ inputs.ref }} # Explicitly define which commit to check 
out persist-credentials: false # Don't persist credentials for subsequent actions @@ -95,7 +95,7 @@ jobs: contents: read steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: ref: ${{ inputs.ref }} persist-credentials: false From 20ae18c218fa0fc61c3c04c70cd98d95ca1a7189 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 22 Dec 2025 15:53:31 -0500 Subject: [PATCH 022/279] ci: update pytest-asyncio requirement (#1166) Updates the requirements on [pytest-asyncio](https://github.com/pytest-dev/pytest-asyncio) to permit the latest version. - [Release notes](https://github.com/pytest-dev/pytest-asyncio/releases) - [Commits](https://github.com/pytest-dev/pytest-asyncio/compare/v1.0.0...v1.3.0) --- updated-dependencies: - dependency-name: pytest-asyncio dependency-version: 1.3.0 dependency-type: direct:development ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2c2a6b260..ea5f2f166 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,7 @@ dev = [ "pre-commit>=3.2.0,<4.4.0", "pytest>=8.0.0,<9.0.0", "pytest-cov>=7.0.0,<8.0.0", - "pytest-asyncio>=1.0.0,<1.3.0", + "pytest-asyncio>=1.0.0,<1.4.0", "pytest-xdist>=3.0.0,<4.0.0", "ruff>=0.13.0,<0.14.0", ] @@ -144,7 +144,7 @@ extra-args = ["-n", "auto", "-vv"] dependencies = [ "pytest>=8.0.0,<9.0.0", "pytest-cov>=7.0.0,<8.0.0", - "pytest-asyncio>=1.0.0,<1.3.0", + "pytest-asyncio>=1.0.0,<1.4.0", "pytest-xdist>=3.0.0,<4.0.0", "moto>=5.1.0,<6.0.0", ] From 87caf1c37f471f7b9b19f3aef86f9b354329cd42 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 09:21:09 -0500 Subject: [PATCH 023/279] ci: bump actions/upload-artifact from 4 to 6 (#1332) Bumps 
[actions/upload-artifact](https://github.com/actions/upload-artifact) from 4 to 6. - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/v4...v6) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-version: '6' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/pypi-publish-on-release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index bad8da5af..506f8023c 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -54,7 +54,7 @@ jobs: hatch build - name: Store the distribution packages - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: python-package-distributions path: dist/ From 2c0aab0df6c5c878bc9dcf8c9be9a0c50bd15692 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 09:21:38 -0500 Subject: [PATCH 024/279] ci: bump actions/download-artifact from 5 to 7 (#1333) Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 5 to 7. - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v5...v7) --- updated-dependencies: - dependency-name: actions/download-artifact dependency-version: '7' dependency-type: direct:production update-type: version-update:semver-major ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/pypi-publish-on-release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index 506f8023c..bf2c9f21d 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -76,7 +76,7 @@ jobs: steps: - name: Download all the dists - uses: actions/download-artifact@v5 + uses: actions/download-artifact@v7 with: name: python-package-distributions path: dist/ From c1c24efce3fe3c23c7b3562f81ff5cb507e165f2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 09:24:42 -0500 Subject: [PATCH 025/279] ci: update pre-commit requirement from <4.4.0,>=3.2.0 to >=3.2.0,<4.6.0 (#1242) Updates the requirements on [pre-commit](https://github.com/pre-commit/pre-commit) to permit the latest version. - [Release notes](https://github.com/pre-commit/pre-commit/releases) - [Changelog](https://github.com/pre-commit/pre-commit/blob/main/CHANGELOG.md) - [Commits](https://github.com/pre-commit/pre-commit/compare/v3.2.0...v4.5.0) --- updated-dependencies: - dependency-name: pre-commit dependency-version: 4.5.0 dependency-type: direct:development ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ea5f2f166..9232131d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,7 @@ dev = [ "hatch>=1.0.0,<2.0.0", "moto>=5.1.0,<6.0.0", "mypy>=1.15.0,<2.0.0", - "pre-commit>=3.2.0,<4.4.0", + "pre-commit>=3.2.0,<4.6.0", "pytest>=8.0.0,<9.0.0", "pytest-cov>=7.0.0,<8.0.0", "pytest-asyncio>=1.0.0,<1.4.0", @@ -166,7 +166,7 @@ features = ["all"] dependencies = [ "commitizen>=4.4.0,<5.0.0", "hatch>=1.0.0,<2.0.0", - "pre-commit>=3.2.0,<4.4.0", + "pre-commit>=3.2.0,<4.6.0", ] From ad2f201447f2021415c526da55ec8d01cdc0822f Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Tue, 23 Dec 2025 07:45:28 -0800 Subject: [PATCH 026/279] feat: add api check to github workflow (#1348) * feat: add api check to github workflow * Update pr-and-push.yml fix package name. * fix: Compare against the actual base of the PR or the previous commit on push, rather than the latest tag * fix: add more explicit error message * feat: add name to each step --- .github/workflows/pr-and-push.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/.github/workflows/pr-and-push.yml b/.github/workflows/pr-and-push.yml index dcebbd7da..1045c2c9e 100644 --- a/.github/workflows/pr-and-push.yml +++ b/.github/workflows/pr-and-push.yml @@ -19,3 +19,22 @@ jobs: ref: ${{ github.event.pull_request.head.sha }} secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + + check-api: + runs-on: ubuntu-latest + permissions: + contents: read + steps: + - name: Checkout code + uses: actions/checkout@v5 + with: + fetch-depth: 0 # We the need the full Git history. + - name: Setup uv + uses: astral-sh/setup-uv@v6 + - name: Check API breaking changes + run: | + if ! 
uvx griffe check --search src --format github strands --against "${{ github.event.pull_request.base.sha || github.event.before || 'HEAD' }}"; then + echo "Breaking API changes detected" + exit 1 + fi + From b2cc4c207928419c6cd64d6071c95a10590c6ffd Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 10:46:44 -0500 Subject: [PATCH 027/279] ci: bump aws-actions/configure-aws-credentials from 4 to 5 (#1352) Bumps [aws-actions/configure-aws-credentials](https://github.com/aws-actions/configure-aws-credentials) from 4 to 5. - [Release notes](https://github.com/aws-actions/configure-aws-credentials/releases) - [Changelog](https://github.com/aws-actions/configure-aws-credentials/blob/main/CHANGELOG.md) - [Commits](https://github.com/aws-actions/configure-aws-credentials/compare/v4...v5) --- updated-dependencies: - dependency-name: aws-actions/configure-aws-credentials dependency-version: '5' dependency-type: direct:production update-type: version-update:semver-major ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/issue-responder.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/issue-responder.yml b/.github/workflows/issue-responder.yml index 318b74361..c6cba59ab 100644 --- a/.github/workflows/issue-responder.yml +++ b/.github/workflows/issue-responder.yml @@ -14,7 +14,7 @@ jobs: steps: - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@v5 with: role-to-assume: ${{ secrets.STRANDS_AGENTCORE_ACTIONS_ROLE }} aws-region: us-west-2 From 033574bf7b77a13387fb8a2fcff8202e52453289 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 10:49:52 -0500 Subject: [PATCH 028/279] ci: update ruff requirement from <0.14.0,>=0.13.0 to >=0.13.0,<0.15.0 (#1004) Updates the requirements on [ruff](https://github.com/astral-sh/ruff) to permit the latest version. - [Release notes](https://github.com/astral-sh/ruff/releases) - [Changelog](https://github.com/astral-sh/ruff/blob/main/CHANGELOG.md) - [Commits](https://github.com/astral-sh/ruff/compare/0.13.0...0.14.0) --- updated-dependencies: - dependency-name: ruff dependency-version: 0.14.0 dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9232131d6..040babe67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,7 +92,7 @@ dev = [ "pytest-cov>=7.0.0,<8.0.0", "pytest-asyncio>=1.0.0,<1.4.0", "pytest-xdist>=3.0.0,<4.0.0", - "ruff>=0.13.0,<0.14.0", + "ruff>=0.13.0,<0.15.0", ] [project.urls] @@ -114,7 +114,7 @@ installer = "uv" features = ["all"] dependencies = [ "mypy>=1.15.0,<2.0.0", - "ruff>=0.13.0,<0.14.0", + "ruff>=0.13.0,<0.15.0", # Include required package dependencies for mypy "strands-agents @ {root:uri}", ] From bf1b7aace26ab7caf827103d94b48f7efd127ef8 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Fri, 26 Dec 2025 13:24:28 -0500 Subject: [PATCH 029/279] feat: add per_turn parameter to SlidingWindowConversationManager (#1374) Allow conversation managers to act as hook providers and add an option to built-in conversation managers to proactively apply message management during the agent loop execution. 
Use that functionality to add an option to SlidingWindowConversationManager to allow per_turn management application Fixes #509 --------- Co-authored-by: Mackenzie Zastrow --- src/strands/agent/agent.py | 3 + .../conversation_manager.py | 34 +++- .../sliding_window_conversation_manager.py | 86 ++++++++- src/strands/hooks/registry.py | 3 +- .../agent/test_conversation_manager.py | 174 ++++++++++++++++++ 5 files changed, 297 insertions(+), 3 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 8fc5be6ca..256c74415 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -250,6 +250,9 @@ def __init__( if self._session_manager: self.hooks.add_hook(self._session_manager) + # Allow conversation_managers to subscribe to hooks + self.hooks.add_hook(self.conversation_manager) + self.tool_executor = tool_executor or ConcurrentToolExecutor() if hooks: diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index 2c1ee7847..47b761abc 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -3,13 +3,14 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Optional +from ...hooks.registry import HookProvider, HookRegistry from ...types.content import Message if TYPE_CHECKING: from ...agent.agent import Agent -class ConversationManager(ABC): +class ConversationManager(ABC, HookProvider): """Abstract base class for managing conversation history. This class provides an interface for implementing conversation management strategies to control the size of message @@ -18,6 +19,18 @@ class ConversationManager(ABC): - Manage memory usage - Control context length - Maintain relevant conversation state + + ConversationManager implements the HookProvider protocol, allowing derived classes to register hooks for agent + lifecycle events. 
Derived classes that override register_hooks must call the base implementation to ensure proper + hook registration. + + Example: + ```python + class MyConversationManager(ConversationManager): + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + super().register_hooks(registry, **kwargs) + # Register additional hooks here + ``` """ def __init__(self) -> None: @@ -30,6 +43,25 @@ def __init__(self) -> None: """ self.removed_message_count = 0 + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + """Register hooks for agent lifecycle events. + + Derived classes that override this method must call the base implementation to ensure proper hook + registration chain. + + Args: + registry: The hook registry to register callbacks with. + **kwargs: Additional keyword arguments for future extensibility. + + Example: + ```python + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + super().register_hooks(registry, **kwargs) + registry.add_callback(SomeEvent, self.on_some_event) + ``` + """ + pass + def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]: """Restore the Conversation Manager's state from a session. 
diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index e082abe8e..a063e55eb 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -6,6 +6,7 @@ if TYPE_CHECKING: from ...agent.agent import Agent +from ...hooks import BeforeModelCallEvent, HookRegistry from ...types.content import Messages from ...types.exceptions import ContextWindowOverflowException from .conversation_manager import ConversationManager @@ -18,19 +19,102 @@ class SlidingWindowConversationManager(ConversationManager): This class handles the logic of maintaining a conversation window that preserves tool usage pairs and avoids invalid window states. + + Supports proactive management during agent loop execution via the per_turn parameter. """ - def __init__(self, window_size: int = 40, should_truncate_results: bool = True): + def __init__(self, window_size: int = 40, should_truncate_results: bool = True, *, per_turn: bool | int = False): """Initialize the sliding window conversation manager. Args: window_size: Maximum number of messages to keep in the agent's history. Defaults to 40 messages. should_truncate_results: Truncate tool results when a message is too large for the model's context window + per_turn: Controls when to apply message management during agent execution. + - False (default): Only apply management at the end (default behavior) + - True: Apply management before every model call + - int (e.g., 3): Apply management before every N model calls + + When to use per_turn: If your agent performs many tool operations in loops + (e.g., web browsing with frequent screenshots), enable per_turn to proactively + manage message history and prevent the agent loop from slowing down. 
Start with + per_turn=True and adjust to a specific frequency (e.g., per_turn=5) if needed + for performance tuning. + + Raises: + ValueError: If per_turn is 0 or a negative integer. """ super().__init__() + self.window_size = window_size self.should_truncate_results = should_truncate_results + self.per_turn = per_turn + self._model_call_count = 0 + + def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None: + """Register hook callbacks for per-turn conversation management. + + Args: + registry: The hook registry to register callbacks with. + **kwargs: Additional keyword arguments for future extensibility. + """ + super().register_hooks(registry, **kwargs) + + # Always register the callback - per_turn check happens in the callback + registry.add_callback(BeforeModelCallEvent, self._on_before_model_call) + + def _on_before_model_call(self, event: BeforeModelCallEvent) -> None: + """Handle before model call event for per-turn management. + + This callback is invoked before each model call. It tracks the model call count and applies message management + based on the per_turn configuration. + + Args: + event: The before model call event containing the agent and model execution details. + """ + # Check if per_turn is enabled + if self.per_turn is False: + return + + self._model_call_count += 1 + + # Determine if we should apply management + should_apply = False + if self.per_turn is True: + should_apply = True + elif isinstance(self.per_turn, int) and self.per_turn > 0: + should_apply = self._model_call_count % self.per_turn == 0 + + if should_apply: + logger.debug( + "model_call_count=<%d>, per_turn=<%s> | applying per-turn conversation management", + self._model_call_count, + self.per_turn, + ) + self.apply_management(event.agent) + + def get_state(self) -> dict[str, Any]: + """Get the current state of the conversation manager. + + Returns: + Dictionary containing the manager's state, including model call count for per-turn tracking. 
+ """ + state = super().get_state() + state["model_call_count"] = self._model_call_count + return state + + def restore_from_session(self, state: dict[str, Any]) -> Optional[list]: + """Restore the conversation manager's state from a session. + + Args: + state: Previous state of the conversation manager + + Returns: + Optional list of messages to prepend to the agent's messages. + """ + result = super().restore_from_session(state) + self._model_call_count = state.get("model_call_count", 0) + return result def apply_management(self, agent: "Agent", **kwargs: Any) -> None: """Apply the sliding window to the agent's messages array to maintain a manageable history size. diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 1efc0bf5b..9edf7ffa7 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -10,7 +10,7 @@ import inspect import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar +from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar, runtime_checkable from ..interrupt import Interrupt, InterruptException @@ -84,6 +84,7 @@ class HookEvent(BaseHookEvent): """Generic for invoking events - non-contravariant to enable returning events.""" +@runtime_checkable class HookProvider(Protocol): """Protocol for objects that provide hook callbacks to an agent. 
diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index 77d7dcce8..ae18a9131 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -1,9 +1,15 @@ +from unittest.mock import MagicMock, patch + import pytest +from strands import tool from strands.agent.agent import Agent from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager +from strands.hooks.events import BeforeModelCallEvent +from strands.hooks.registry import HookProvider, HookRegistry from strands.types.exceptions import ContextWindowOverflowException +from tests.fixtures.mocked_model_provider import MockedModelProvider @pytest.fixture @@ -246,3 +252,171 @@ def test_null_conversation_does_not_restore_with_incorrect_state(): with pytest.raises(ValueError): manager.restore_from_session({}) + + +# ============================================================================== +# Per-Turn Management Tests +# ============================================================================== + + +def test_per_turn_parameter_validation(): + """Test per_turn parameter validation.""" + # Valid values + assert SlidingWindowConversationManager(per_turn=False).per_turn is False + assert SlidingWindowConversationManager(per_turn=True).per_turn is True + assert SlidingWindowConversationManager(per_turn=3).per_turn == 3 + + +def test_conversation_manager_is_hook_provider(): + """Test that ConversationManager implements HookProvider protocol.""" + manager = NullConversationManager() + assert isinstance(manager, HookProvider) + + +def test_derived_class_does_not_need_to_implement_register_hooks(): + """Test that derived classes don't need to override register_hooks for backwards compatibility.""" + from 
strands.agent.conversation_manager.conversation_manager import ConversationManager + + class MinimalConversationManager(ConversationManager): + """A minimal implementation that only implements abstract methods.""" + + def apply_management(self, agent, **kwargs): + pass + + def reduce_context(self, agent, e=None, **kwargs): + pass + + # Should be able to instantiate without implementing register_hooks + manager = MinimalConversationManager() + registry = HookRegistry() + + # Should work without error + manager.register_hooks(registry) + assert not registry.has_callbacks() + + +def test_per_turn_hooks_registration(): + """Test that hooks are registered when conversation_manager implements HookProvider.""" + manager = SlidingWindowConversationManager(per_turn=True) + assert isinstance(manager, HookProvider) + + registry = HookRegistry() + manager.register_hooks(registry) + assert registry.has_callbacks() + + +def test_per_turn_false_no_management_during_loop(): + """Test that per_turn=False only manages in finally block.""" + manager = SlidingWindowConversationManager(per_turn=False, window_size=100) + responses = [{"role": "assistant", "content": [{"text": "Response"}]}] * 3 + model = MockedModelProvider(responses) + agent = Agent(model=model, conversation_manager=manager) + + with patch.object(manager, "apply_management", wraps=manager.apply_management) as mock: + agent("Test") + # Should only be called once in finally block (per_turn disabled) + assert mock.call_count == 1 + + +def test_per_turn_true_manages_each_model_call(): + """Test that per_turn=True applies management before each model call.""" + manager = SlidingWindowConversationManager(per_turn=True, window_size=100) + responses = [{"role": "assistant", "content": [{"text": "Response"}]}] * 3 + model = MockedModelProvider(responses) + agent = Agent(model=model, conversation_manager=manager) + + with patch.object(manager, "apply_management", wraps=manager.apply_management) as mock: + agent("Test") + # 
Should be called for each model call + finally block + # With simple text responses, agent makes 1 model call then stops + assert mock.call_count >= 1 + + +def test_per_turn_integer_manages_every_n_calls(): + """Test that per_turn=N applies management every N model calls.""" + manager = SlidingWindowConversationManager(per_turn=2, window_size=100) + # Create responses that trigger multiple model calls + responses = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": f"{i}", "name": "test", "input": {}}}]} + for i in range(5) + ] + [{"role": "assistant", "content": [{"text": "Done"}]}] + model = MockedModelProvider(responses) + + @tool(name="test") + def test_tool(query: str = "") -> str: + return "result" + + agent = Agent(model=model, conversation_manager=manager, tools=[test_tool]) + + with patch.object(manager, "apply_management", wraps=manager.apply_management) as mock: + agent("Test") + # With 6 model calls and per_turn=2: called on 2nd, 4th, 6th + finally + assert mock.call_count == 4 + + +def test_per_turn_dynamic_change(): + """Test that per_turn can be changed dynamically.""" + manager = SlidingWindowConversationManager(per_turn=False) + registry = HookRegistry() + manager.register_hooks(registry) + + mock_agent = MagicMock() + mock_agent.messages = [] + event = BeforeModelCallEvent(agent=mock_agent) + + # Initially disabled + with patch.object(manager, "apply_management") as mock_apply: + registry.invoke_callbacks(event) + assert mock_apply.call_count == 0 + + # Enable dynamically + manager.per_turn = True + with patch.object(manager, "apply_management") as mock_apply: + registry.invoke_callbacks(event) + assert mock_apply.call_count == 1 + + +def test_per_turn_reduces_message_count(): + """Test that per_turn actually reduces message count during execution.""" + manager = SlidingWindowConversationManager(per_turn=1, window_size=4) + responses = [{"role": "assistant", "content": [{"text": f"Response {i}"}]} for i in range(10)] + model = 
MockedModelProvider(responses) + agent = Agent(model=model, conversation_manager=manager) + + message_counts = [] + original_apply = manager.apply_management + + def track_apply(agent_instance): + message_counts.append(len(agent_instance.messages)) + return original_apply(agent_instance) + + with patch.object(manager, "apply_management", side_effect=track_apply): + agent("Test") + + # Verify message count stayed around window_size + assert any(count <= manager.window_size for count in message_counts) + + +def test_per_turn_state_persistence(): + """Test that model_call_count is persisted in state.""" + manager = SlidingWindowConversationManager(per_turn=3) + manager._model_call_count = 7 + + state = manager.get_state() + assert state["model_call_count"] == 7 + + new_manager = SlidingWindowConversationManager(per_turn=3) + new_manager.restore_from_session(state) + assert new_manager._model_call_count == 7 + + +def test_per_turn_backward_compatibility(): + """Test that existing code without per_turn still works.""" + manager = SlidingWindowConversationManager(window_size=40) + assert manager.per_turn is False + + responses = [{"role": "assistant", "content": [{"text": "Hello"}]}] + model = MockedModelProvider(responses) + agent = Agent(model=model, conversation_manager=manager) + result = agent("Hello") + assert result is not None From c73e9e523580126ad7b6e80b2e712ef6fbee92ea Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Mon, 29 Dec 2025 08:15:00 -0800 Subject: [PATCH 030/279] fix: check api breaking change against main (#1397) --- .github/workflows/pr-and-push.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pr-and-push.yml b/.github/workflows/pr-and-push.yml index 1045c2c9e..62c3bfc02 100644 --- a/.github/workflows/pr-and-push.yml +++ b/.github/workflows/pr-and-push.yml @@ -33,8 +33,8 @@ jobs: uses: astral-sh/setup-uv@v6 - name: Check API breaking changes run: | - if ! 
uvx griffe check --search src --format github strands --against "${{ github.event.pull_request.base.sha || github.event.before || 'HEAD' }}"; then - echo "Breaking API changes detected" + if ! uvx griffe check --search src --format github strands --against "main"; then + echo "Potential API changes detected (review if actually breaking)" exit 1 fi From 3b424d0164ddb35c1ed77f58855c016db8bcfc3d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 29 Dec 2025 11:21:18 -0500 Subject: [PATCH 031/279] ci: bump astral-sh/setup-uv from 6 to 7 (#1390) --- .github/workflows/pr-and-push.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr-and-push.yml b/.github/workflows/pr-and-push.yml index 62c3bfc02..6269da725 100644 --- a/.github/workflows/pr-and-push.yml +++ b/.github/workflows/pr-and-push.yml @@ -30,7 +30,7 @@ jobs: with: fetch-depth: 0 # We the need the full Git history. - name: Setup uv - uses: astral-sh/setup-uv@v6 + uses: astral-sh/setup-uv@v7 - name: Check API breaking changes run: | if ! 
uvx griffe check --search src --format github strands --against "main"; then From 067d2595ce7dfc0ae12b5b1fb212bf2d3bfc5516 Mon Sep 17 00:00:00 2001 From: Ratish P <114130421+Ratish1@users.noreply.github.com> Date: Tue, 30 Dec 2025 02:09:26 +0400 Subject: [PATCH 032/279] fix(openai): support tools returning image content (#1079) --- src/strands/models/openai.py | 73 +++++++++- tests/strands/models/test_openai.py | 183 ++++++++++++++++++++++++ tests_integ/models/test_model_openai.py | 1 - 3 files changed, 255 insertions(+), 2 deletions(-) diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 07246c5d6..c381201e4 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -200,6 +200,70 @@ def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) -> "content": [cls.format_request_message_content(content) for content in contents], } + @classmethod + def _split_tool_message_images( + cls, tool_message: dict[str, Any] + ) -> tuple[dict[str, Any], Optional[dict[str, Any]]]: + """Split a tool message into text-only tool message and optional user message with images. + + OpenAI API restricts images to user role messages only. This method extracts any image + content from a tool message and returns it separately as a user message. + + Args: + tool_message: A formatted tool message that may contain images. + + Returns: + A tuple of (tool_message_without_images, user_message_with_images_or_None). 
+ """ + if tool_message.get("role") != "tool": + return tool_message, None + + content = tool_message.get("content", []) + if not isinstance(content, list): + return tool_message, None + + # Separate image and non-image content + text_content = [] + image_content = [] + + for item in content: + if isinstance(item, dict) and item.get("type") == "image_url": + image_content.append(item) + else: + text_content.append(item) + + # If no images found, return original message + if not image_content: + return tool_message, None + + # Let the user know that we are modifying the messages for OpenAI compatibility + logger.warning( + "tool_call_id=<%s> | Moving image from tool message to a new user message for OpenAI compatibility", + tool_message["tool_call_id"], + ) + + # Append a message to the text content to inform the model about the upcoming image + text_content.append( + { + "type": "text", + "text": ( + "Tool successfully returned an image. The image is being provided in the following user message." + ), + } + ) + + # Create the clean tool message with the updated text content + tool_message_clean = { + "role": "tool", + "tool_call_id": tool_message["tool_call_id"], + "content": text_content, + } + + # Create user message with only images + user_message_with_images = {"role": "user", "content": image_content} + + return tool_message_clean, user_message_with_images + @classmethod def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str, Any]: """Format a tool choice for OpenAI compatibility. 
@@ -295,7 +359,14 @@ def _format_regular_messages(cls, messages: Messages, **kwargs: Any) -> list[dic **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}), } formatted_messages.append(formatted_message) - formatted_messages.extend(formatted_tool_messages) + + # Process tool messages to extract images into separate user messages + # OpenAI API requires images to be in user role messages only + for tool_msg in formatted_tool_messages: + tool_msg_clean, user_msg_with_images = cls._split_tool_message_images(tool_msg) + formatted_messages.append(tool_msg_clean) + if user_msg_with_images: + formatted_messages.append(user_msg_with_images) return formatted_messages diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index ef173d349..7c1d18998 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -179,6 +179,189 @@ def test_format_request_tool_message(): assert tru_result == exp_result +def test_split_tool_message_images_with_image(): + """Test that images are extracted from tool messages.""" + tool_message = { + "role": "tool", + "tool_call_id": "c1", + "content": [ + {"type": "text", "text": "Result"}, + { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,iVBORw0KGgo=", "detail": "auto", "format": "image/png"}, + }, + ], + } + + tool_clean, user_with_image = OpenAIModel._split_tool_message_images(tool_message) + + # Tool message should now have the original text plus the appended informational text + assert tool_clean["role"] == "tool" + assert tool_clean["tool_call_id"] == "c1" + assert len(tool_clean["content"]) == 2 + assert tool_clean["content"][0]["type"] == "text" + assert tool_clean["content"][0]["text"] == "Result" + assert "Tool successfully returned an image" in tool_clean["content"][1]["text"] + + # User message should have the image + assert user_with_image is not None + assert user_with_image["role"] == "user" + assert 
len(user_with_image["content"]) == 1 + assert user_with_image["content"][0]["type"] == "image_url" + + +def test_split_tool_message_images_without_image(): + """Test that tool messages without images are unchanged.""" + tool_message = {"role": "tool", "tool_call_id": "c1", "content": [{"type": "text", "text": "Result"}]} + + tool_clean, user_with_image = OpenAIModel._split_tool_message_images(tool_message) + + assert tool_clean == tool_message + assert user_with_image is None + + +def test_split_tool_message_images_only_image(): + """Test tool message with only image content.""" + tool_message = { + "role": "tool", + "tool_call_id": "c1", + "content": [{"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}}], + } + + tool_clean, user_with_image = OpenAIModel._split_tool_message_images(tool_message) + + # Tool message should have default text + assert tool_clean["role"] == "tool" + assert len(tool_clean["content"]) == 1 + assert "successfully" in tool_clean["content"][0]["text"].lower() + + # User message should have the image + assert user_with_image is not None + assert user_with_image["role"] == "user" + assert len(user_with_image["content"]) == 1 + + +def test_split_tool_message_images_non_tool_role(): + """Test that messages with roles other than 'tool' are ignored.""" + user_msg = {"role": "user", "content": [{"type": "text", "text": "hello"}]} + clean, extra = OpenAIModel._split_tool_message_images(user_msg) + assert clean == user_msg + assert extra is None + + +def test_split_tool_message_images_invalid_content_type(): + """Test that messages with non-list content are ignored.""" + invalid_msg = {"role": "tool", "content": "not a list"} + clean, extra = OpenAIModel._split_tool_message_images(invalid_msg) + assert clean == invalid_msg + assert extra is None + + +def test_format_request_messages_with_tool_result_containing_image(): + """Test that tool results with images are properly split into tool and user messages.""" + messages = 
[ + { + "content": [{"text": "Run the tool"}], + "role": "user", + }, + { + "content": [ + { + "toolUse": { + "input": {}, + "name": "image_tool", + "toolUseId": "t1", + }, + }, + ], + "role": "assistant", + }, + { + "content": [ + { + "toolResult": { + "toolUseId": "t1", + "status": "success", + "content": [ + {"text": "Image generated"}, + { + "image": { + "format": "png", + "source": {"bytes": b"fake_image_data"}, + } + }, + ], + } + } + ], + "role": "user", + }, + ] + + formatted = OpenAIModel.format_request_messages(messages) + + # Find the tool message + tool_messages = [msg for msg in formatted if msg.get("role") == "tool"] + assert len(tool_messages) == 1 + + # Tool message should only have text content + tool_msg = tool_messages[0] + assert all(c.get("type") != "image_url" for c in tool_msg["content"]) + + # There should be a user message right after the tool message with the image + tool_msg_idx = formatted.index(tool_msg) + assert tool_msg_idx + 1 < len(formatted) + user_msg = formatted[tool_msg_idx + 1] + assert user_msg["role"] == "user" + assert any(c.get("type") == "image_url" for c in user_msg["content"]) + + +def test_format_request_messages_with_multiple_images_in_tool_result(): + """Test tool result with multiple images.""" + messages = [ + { + "content": [ + { + "toolResult": { + "toolUseId": "t1", + "status": "success", + "content": [ + {"text": "Two images generated"}, + { + "image": { + "format": "png", + "source": {"bytes": b"image1"}, + } + }, + { + "image": { + "format": "jpg", + "source": {"bytes": b"image2"}, + } + }, + ], + } + } + ], + "role": "user", + }, + ] + + formatted = OpenAIModel.format_request_messages(messages) + + # Find user message with images + user_image_msgs = [ + msg + for msg in formatted + if msg.get("role") == "user" and any(c.get("type") == "image_url" for c in msg.get("content", [])) + ] + assert len(user_image_msgs) == 1 + + # Should have both images + image_contents = [c for c in user_image_msgs[0]["content"] if 
c.get("type") == "image_url"] + assert len(image_contents) == 2 + + def test_format_request_tool_choice_auto(): tool_choice = {"auto": {}} diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py index feb591d1a..503fca898 100644 --- a/tests_integ/models/test_model_openai.py +++ b/tests_integ/models/test_model_openai.py @@ -148,7 +148,6 @@ def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color): assert tru_color == exp_color -@pytest.mark.skip("https://github.com/strands-agents/sdk-python/issues/320") def test_tool_returning_images(model, yellow_img): @tool def tool_with_image_return(): From e4f27c66b33f6bfa609c8da26e7a5d4efdd9beab Mon Sep 17 00:00:00 2001 From: poshinchen Date: Mon, 29 Dec 2025 23:59:43 -0500 Subject: [PATCH 033/279] feat: added agent_invocations (#1387) --- src/strands/agent/agent.py | 2 + src/strands/telemetry/metrics.py | 120 +++++++++++++++--- tests/strands/event_loop/test_event_loop.py | 1 + .../test_event_loop_structured_output.py | 1 + tests/strands/telemetry/test_metrics.py | 94 +++++++++++++- 5 files changed, 195 insertions(+), 23 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 256c74415..9e726ca0b 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -565,6 +565,8 @@ async def stream_async( """ self._interrupt_state.resume(prompt) + self.event_loop_metrics.reset_usage_metrics() + merged_state = {} if kwargs: warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2) diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py index abfbbffae..8f3ee1ea1 100644 --- a/src/strands/telemetry/metrics.py +++ b/src/strands/telemetry/metrics.py @@ -151,6 +151,34 @@ def add_call( metrics_client.tool_error_count.add(1, attributes=attributes) +@dataclass +class EventLoopCycleMetric: + """Aggregated metrics for a single event loop cycle. 
+ + Attributes: + event_loop_cycle_id: Current eventLoop cycle id. + usage: Total token usage for the entire cycle (succeeded model invocation, excluding tool invocations). + """ + + event_loop_cycle_id: str + usage: Usage + + +@dataclass +class AgentInvocation: + """Metrics for a single agent invocation. + + AgentInvocation contains all the event loop cycles and accumulated token usage for that invocation. + + Attributes: + cycles: List of event loop cycles that occurred during this invocation. + usage: Accumulated token usage for this invocation across all cycles. + """ + + cycles: list[EventLoopCycleMetric] = field(default_factory=list) + usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0)) + + @dataclass class EventLoopMetrics: """Aggregated metrics for an event loop's execution. @@ -159,15 +187,17 @@ class EventLoopMetrics: cycle_count: Number of event loop cycles executed. tool_metrics: Metrics for each tool used, keyed by tool name. cycle_durations: List of durations for each cycle in seconds. + agent_invocations: Agent invocation metrics containing cycles and usage data. traces: List of execution traces. - accumulated_usage: Accumulated token usage across all model invocations. + accumulated_usage: Accumulated token usage across all model invocations (across all requests). accumulated_metrics: Accumulated performance metrics across all model invocations. 
""" cycle_count: int = 0 - tool_metrics: Dict[str, ToolMetrics] = field(default_factory=dict) - cycle_durations: List[float] = field(default_factory=list) - traces: List[Trace] = field(default_factory=list) + tool_metrics: dict[str, ToolMetrics] = field(default_factory=dict) + cycle_durations: list[float] = field(default_factory=list) + agent_invocations: list[AgentInvocation] = field(default_factory=list) + traces: list[Trace] = field(default_factory=list) accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0)) accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) @@ -176,14 +206,23 @@ def _metrics_client(self) -> "MetricsClient": """Get the singleton MetricsClient instance.""" return MetricsClient() + @property + def latest_agent_invocation(self) -> Optional[AgentInvocation]: + """Get the most recent agent invocation. + + Returns: + The most recent AgentInvocation, or None if no invocations exist. + """ + return self.agent_invocations[-1] if self.agent_invocations else None + def start_cycle( self, - attributes: Optional[Dict[str, Any]] = None, + attributes: Dict[str, Any], ) -> Tuple[float, Trace]: """Start a new event loop cycle and create a trace for it. Args: - attributes: attributes of the metrics. + attributes: attributes of the metrics, including event_loop_cycle_id. Returns: A tuple containing the start time and the cycle trace object. 
@@ -194,6 +233,14 @@ def start_cycle( start_time = time.time() cycle_trace = Trace(f"Cycle {self.cycle_count}", start_time=start_time) self.traces.append(cycle_trace) + + self.agent_invocations[-1].cycles.append( + EventLoopCycleMetric( + event_loop_cycle_id=attributes["event_loop_cycle_id"], + usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), + ) + ) + return start_time, cycle_trace def end_cycle(self, start_time: float, cycle_trace: Trace, attributes: Optional[Dict[str, Any]] = None) -> None: @@ -252,32 +299,53 @@ def add_tool_usage( ) tool_trace.end() + def _accumulate_usage(self, target: Usage, source: Usage) -> None: + """Helper method to accumulate usage from source to target. + + Args: + target: The Usage object to accumulate into. + source: The Usage object to accumulate from. + """ + target["inputTokens"] += source["inputTokens"] + target["outputTokens"] += source["outputTokens"] + target["totalTokens"] += source["totalTokens"] + + if "cacheReadInputTokens" in source: + target["cacheReadInputTokens"] = target.get("cacheReadInputTokens", 0) + source["cacheReadInputTokens"] + + if "cacheWriteInputTokens" in source: + target["cacheWriteInputTokens"] = target.get("cacheWriteInputTokens", 0) + source["cacheWriteInputTokens"] + def update_usage(self, usage: Usage) -> None: """Update the accumulated token usage with new usage data. Args: usage: The usage data to add to the accumulated totals. 
""" + # Record metrics to OpenTelemetry self._metrics_client.event_loop_input_tokens.record(usage["inputTokens"]) self._metrics_client.event_loop_output_tokens.record(usage["outputTokens"]) - self.accumulated_usage["inputTokens"] += usage["inputTokens"] - self.accumulated_usage["outputTokens"] += usage["outputTokens"] - self.accumulated_usage["totalTokens"] += usage["totalTokens"] - # Handle optional cached token metrics + # Handle optional cached token metrics for OpenTelemetry if "cacheReadInputTokens" in usage: - cache_read_tokens = usage["cacheReadInputTokens"] - self._metrics_client.event_loop_cache_read_input_tokens.record(cache_read_tokens) - self.accumulated_usage["cacheReadInputTokens"] = ( - self.accumulated_usage.get("cacheReadInputTokens", 0) + cache_read_tokens - ) - + self._metrics_client.event_loop_cache_read_input_tokens.record(usage["cacheReadInputTokens"]) if "cacheWriteInputTokens" in usage: - cache_write_tokens = usage["cacheWriteInputTokens"] - self._metrics_client.event_loop_cache_write_input_tokens.record(cache_write_tokens) - self.accumulated_usage["cacheWriteInputTokens"] = ( - self.accumulated_usage.get("cacheWriteInputTokens", 0) + cache_write_tokens - ) + self._metrics_client.event_loop_cache_write_input_tokens.record(usage["cacheWriteInputTokens"]) + + self._accumulate_usage(self.accumulated_usage, usage) + self._accumulate_usage(self.agent_invocations[-1].usage, usage) + + if self.agent_invocations[-1].cycles: + current_cycle = self.agent_invocations[-1].cycles[-1] + self._accumulate_usage(current_cycle.usage, usage) + + def reset_usage_metrics(self) -> None: + """Start a new agent invocation by creating a new AgentInvocation. + + This should be called at the start of a new request to begin tracking + a new agent invocation with fresh usage and cycle data. + """ + self.agent_invocations.append(AgentInvocation()) def update_metrics(self, metrics: Metrics) -> None: """Update the accumulated performance metrics with new metrics data. 
@@ -322,6 +390,16 @@ def get_summary(self) -> Dict[str, Any]: "traces": [trace.to_dict() for trace in self.traces], "accumulated_usage": self.accumulated_usage, "accumulated_metrics": self.accumulated_metrics, + "agent_invocations": [ + { + "usage": invocation.usage, + "cycles": [ + {"event_loop_cycle_id": cycle.event_loop_cycle_id, "usage": cycle.usage} + for cycle in invocation.cycles + ], + } + for invocation in self.agent_invocations + ], } return summary diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 52980729c..6b23bd592 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -142,6 +142,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock.tool_registry = tool_registry mock.thread_pool = thread_pool mock.event_loop_metrics = EventLoopMetrics() + mock.event_loop_metrics.reset_usage_metrics() mock.hooks = hook_registry mock.tool_executor = tool_executor mock._interrupt_state = _InterruptState() diff --git a/tests/strands/event_loop/test_event_loop_structured_output.py b/tests/strands/event_loop/test_event_loop_structured_output.py index 508042af0..23b7f3433 100644 --- a/tests/strands/event_loop/test_event_loop_structured_output.py +++ b/tests/strands/event_loop/test_event_loop_structured_output.py @@ -37,6 +37,7 @@ def mock_agent(): agent.messages = [] agent.tool_registry = ToolRegistry() agent.event_loop_metrics = EventLoopMetrics() + agent.event_loop_metrics.reset_usage_metrics() agent.hooks = Mock() agent.hooks.invoke_callbacks_async = AsyncMock() agent.trace_span = None diff --git a/tests/strands/telemetry/test_metrics.py b/tests/strands/telemetry/test_metrics.py index e87277eed..800bcebc4 100644 --- a/tests/strands/telemetry/test_metrics.py +++ b/tests/strands/telemetry/test_metrics.py @@ -240,9 +240,15 @@ def test_tool_metrics_add_call(success, tool, tool_metrics, mock_get_meter_provi 
@unittest.mock.patch.object(strands.telemetry.metrics.uuid, "uuid4") def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metrics, mock_get_meter_provider): mock_time.return_value = 1 - mock_uuid4.return_value = "i1" + mock_event_loop_cycle_id = "i1" + mock_uuid4.return_value = mock_event_loop_cycle_id - tru_start_time, tru_cycle_trace = event_loop_metrics.start_cycle() + # Reset must be called first + event_loop_metrics.reset_usage_metrics() + + tru_start_time, tru_cycle_trace = event_loop_metrics.start_cycle( + attributes={"event_loop_cycle_id": mock_event_loop_cycle_id} + ) exp_start_time, exp_cycle_trace = 1, strands.telemetry.metrics.Trace("Cycle 1") tru_attrs = {"cycle_count": event_loop_metrics.cycle_count, "traces": event_loop_metrics.traces} @@ -256,6 +262,13 @@ def test_event_loop_metrics_start_cycle(mock_uuid4, mock_time, event_loop_metric and tru_attrs == exp_attrs ) + assert len(event_loop_metrics.agent_invocations) == 1 + assert len(event_loop_metrics.agent_invocations[0].cycles) == 1 + assert event_loop_metrics.agent_invocations[0].cycles[0].event_loop_cycle_id == "i1" + assert event_loop_metrics.agent_invocations[0].cycles[0].usage["inputTokens"] == 0 + assert event_loop_metrics.agent_invocations[0].cycles[0].usage["outputTokens"] == 0 + assert event_loop_metrics.agent_invocations[0].cycles[0].usage["totalTokens"] == 0 + @unittest.mock.patch.object(strands.telemetry.metrics.time, "time") def test_event_loop_metrics_end_cycle(mock_time, trace, event_loop_metrics, mock_get_meter_provider): @@ -324,6 +337,9 @@ def test_event_loop_metrics_add_tool_usage(mock_time, trace, tool, event_loop_me def test_event_loop_metrics_update_usage(usage, event_loop_metrics, mock_get_meter_provider): + event_loop_metrics.reset_usage_metrics() + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "test-cycle"}) + for _ in range(3): event_loop_metrics.update_usage(usage) @@ -331,6 +347,14 @@ def test_event_loop_metrics_update_usage(usage, 
event_loop_metrics, mock_get_met exp_usage = Usage(inputTokens=3, outputTokens=6, totalTokens=9, cacheWriteInputTokens=6) assert tru_usage == exp_usage + + assert event_loop_metrics.latest_agent_invocation.usage == exp_usage + + assert len(event_loop_metrics.agent_invocations) == 1 + assert len(event_loop_metrics.agent_invocations[0].cycles) == 1 + assert event_loop_metrics.agent_invocations[0].cycles[0].event_loop_cycle_id == "test-cycle" + assert event_loop_metrics.agent_invocations[0].cycles[0].usage == exp_usage + mock_get_meter_provider.return_value.get_meter.assert_called() metrics_client = event_loop_metrics._metrics_client metrics_client.event_loop_input_tokens.record.assert_called() @@ -370,6 +394,7 @@ def test_event_loop_metrics_get_summary(trace, tool, event_loop_metrics, mock_ge "outputTokens": 0, "totalTokens": 0, }, + "agent_invocations": [], "average_cycle_time": 0, "tool_usage": { "tool1": { @@ -476,3 +501,68 @@ def test_use_ProxyMeter_if_no_global_meter_provider(): # Verify it's using a _ProxyMeter assert isinstance(metrics_client.meter, _ProxyMeter) + + +def test_latest_agent_invocation_property(usage, event_loop_metrics, mock_get_meter_provider): + """Test the latest_agent_invocation property getter""" + # Initially, no invocations exist + assert event_loop_metrics.latest_agent_invocation is None + + event_loop_metrics.reset_usage_metrics() + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "cycle-1"}) + event_loop_metrics.update_usage(usage) + + # latest_agent_invocation should return the first invocation + current = event_loop_metrics.latest_agent_invocation + assert current is not None + assert current.usage["inputTokens"] == 1 + assert len(current.cycles) == 1 + + event_loop_metrics.reset_usage_metrics() + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "cycle-2"}) + usage2 = Usage(inputTokens=10, outputTokens=20, totalTokens=30) + event_loop_metrics.update_usage(usage2) + + # Should return the second 
invocation + current = event_loop_metrics.latest_agent_invocation + assert current is not None + assert current.usage["inputTokens"] == 10 + assert len(current.cycles) == 1 + + assert len(event_loop_metrics.agent_invocations) == 2 + + assert current is event_loop_metrics.agent_invocations[-1] + + +def test_reset_usage_metrics(usage, event_loop_metrics, mock_get_meter_provider): + """Test that reset_usage_metrics creates a new agent invocation but preserves accumulated_usage""" + # Add some usage across multiple cycles in first invocation + event_loop_metrics.reset_usage_metrics() + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "cycle-1"}) + event_loop_metrics.update_usage(usage) + + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "cycle-2"}) + usage2 = Usage(inputTokens=10, outputTokens=20, totalTokens=30) + event_loop_metrics.update_usage(usage2) + + assert len(event_loop_metrics.agent_invocations) == 1 + assert event_loop_metrics.latest_agent_invocation.usage["inputTokens"] == 11 + assert len(event_loop_metrics.latest_agent_invocation.cycles) == 2 + assert event_loop_metrics.accumulated_usage["inputTokens"] == 11 + + # Reset - creates a new invocation + event_loop_metrics.reset_usage_metrics() + + assert len(event_loop_metrics.agent_invocations) == 2 + + assert event_loop_metrics.latest_agent_invocation.usage["inputTokens"] == 0 + assert event_loop_metrics.latest_agent_invocation.usage["outputTokens"] == 0 + assert event_loop_metrics.latest_agent_invocation.usage["totalTokens"] == 0 + assert len(event_loop_metrics.latest_agent_invocation.cycles) == 0 + + # Verify first invocation data is preserved + assert event_loop_metrics.agent_invocations[0].usage["inputTokens"] == 11 + assert len(event_loop_metrics.agent_invocations[0].cycles) == 2 + + # Verify accumulated_usage is NOT cleared + assert event_loop_metrics.accumulated_usage["inputTokens"] == 11 From e980d9870b07cc32bfcb35aa03d0292b00a2fbb5 Mon Sep 17 00:00:00 2001 From: 
"dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 30 Dec 2025 14:57:42 -0500 Subject: [PATCH 034/279] ci: bump actions/checkout from 5 to 6 (#1389) Bumps [actions/checkout](https://github.com/actions/checkout) from 5 to 6. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v5...v6) --- updated-dependencies: - dependency-name: actions/checkout dependency-version: '6' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/pr-and-push.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr-and-push.yml b/.github/workflows/pr-and-push.yml index 6269da725..d2af9f956 100644 --- a/.github/workflows/pr-and-push.yml +++ b/.github/workflows/pr-and-push.yml @@ -26,7 +26,7 @@ jobs: contents: read steps: - name: Checkout code - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: fetch-depth: 0 # We the need the full Git history. 
- name: Setup uv From c6a56ad92597bcbffc0dac023833a155c1fd7156 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Tue, 30 Dec 2025 15:07:35 -0500 Subject: [PATCH 035/279] Re-add agents to this PR (#1403) Co-authored-by: Mackenzie Zastrow --- .github/actions/README.md | 285 ++++++ .../actions/strands-agent-runner/action.yml | 179 ++++ .../actions/strands-write-executor/action.yml | 147 +++ .github/agent-sops/task-implementer.sop.md | 493 ++++++++++ .github/agent-sops/task-refiner.sop.md | 298 +++++++ .github/agent-sops/task-release-notes.sop.md | 586 ++++++++++++ .github/scripts/javascript/process-input.cjs | 125 +++ .github/scripts/python/agent_runner.py | 164 ++++ .github/scripts/python/github_tools.py | 843 ++++++++++++++++++ .github/scripts/python/handoff_to_user.py | 34 + .github/scripts/python/notebook.py | 337 +++++++ .github/scripts/python/requirements.txt | 8 + .../python/str_replace_based_edit_tool.py | 230 +++++ .github/scripts/python/write_executor.py | 152 ++++ .github/workflows/strands-command.yml | 184 ++++ 15 files changed, 4065 insertions(+) create mode 100644 .github/actions/README.md create mode 100644 .github/actions/strands-agent-runner/action.yml create mode 100644 .github/actions/strands-write-executor/action.yml create mode 100644 .github/agent-sops/task-implementer.sop.md create mode 100644 .github/agent-sops/task-refiner.sop.md create mode 100644 .github/agent-sops/task-release-notes.sop.md create mode 100644 .github/scripts/javascript/process-input.cjs create mode 100644 .github/scripts/python/agent_runner.py create mode 100644 .github/scripts/python/github_tools.py create mode 100644 .github/scripts/python/handoff_to_user.py create mode 100644 .github/scripts/python/notebook.py create mode 100644 .github/scripts/python/requirements.txt create mode 100644 .github/scripts/python/str_replace_based_edit_tool.py create mode 100755 .github/scripts/python/write_executor.py create mode 100644 
.github/workflows/strands-command.yml diff --git a/.github/actions/README.md b/.github/actions/README.md new file mode 100644 index 000000000..a3ec3fa2d --- /dev/null +++ b/.github/actions/README.md @@ -0,0 +1,285 @@ +# Strands Command GitHub Actions + +A comprehensive AI agent execution system for GitHub repositories that processes `/strands` commands in issues and pull requests. + +## Overview + +The Strands Command system enables AI-powered automation in GitHub repositories through: + +- **Issue Comment Processing**: Responds to `/strands` commands in issues and PRs +- **Controlled AI Execution**: Runs AI agents with read-only and write-separated permissions +- **AWS Integration**: Secure OIDC-based authentication with Bedrock AI models +- **Security-First Design**: Manual approval gates and permission isolation + +### Architecture + +```mermaid +graph LR + A["strands Command"] --> B[Authorization] + B --> C[Read-Only Agent] + C --> D[Write Operations] + D --> E[Cleanup] + + B -.-> B1[Permission Check] + C -.-> C1[AWS + AI Execution] + D -.-> D1[Repository Updates] +``` + +## Quick Start + +1. **Set up AWS IAM Role** (see [IAM Role Policy](#iam-role-policy)) +2. **Configure GitHub Secrets**: + - `AWS_ROLE_ARN`: Your IAM role ARN + - `STRANDS_SESSION_BUCKET`: S3 bucket for session storage +3. **Copy required files** to your repository: + - `.github/workflows/strands-command.yml` + - `.github/actions/` directory + - `.github/scripts/` directory + - `.github/agent-sops/` directory +4. **Comment `/strands [your task]`** on any issue or PR + - **On Issues**: + - Use `/strands ` to have an agent help you refine an issue within the context of the current github repo + - Use `/strands implement ` to create a new PR based on the description of an issue + - **On PRs**: `/strands ` will instruct an Agent to review PR comments and make updates to the issue + +## Actions + +### strands-agent-runner + +Executes AI agents with AWS integration and controlled permissions. 
+ +**Inputs:** +- `ref` (required): Git reference to checkout +- `system_prompt` (required): System instructions for the agent +- `session_id` (required): Session identifier for persistence +- `task_prompt` (required): Task description for the agent +- `aws_role_arn` (required): AWS IAM role ARN for authentication +- `sessions_bucket` (required): S3 bucket for session storage +- `write_permission` (required): Permission level flag for Read-only Sandbox mode (`true`/`false`) + +**Features:** +- Strands Agent running with Agent SOPs specifically designed to instruct an Agent on how to develop in GitHub +- Python 3.13 and Node.js 20 environment setup (Node.js setup and npm install are optional and can be removed - only included for this repo's development) +- Read-only Sandbox support: Agent write actions can be deferred to the `strands-write-executor` action if you want your agent to execute with read-only GitHub permissions + +### strands-write-executor + +Executes write operations from agent-generated artifacts if `strands-agent-runner` was run with `write_permission: false`. + +**Inputs:** +- `ref` (required): Target branch for changes +- `issue_id` (optional): Associated issue number + +**Features:** +- Reads Agent modified repository state from artifacts, and pushes changes to PR branch +- Reads deferred write operations from artifact and executes them + +## Workflows + +### strands-command.yml + +Main workflow that orchestrates the complete Strands command execution: + +1. **Authorization Check**: Validates user permissions and applies approval gates +2. **Setup and Processing**: Parses input and prepares execution context +3. **Read-Only Execution**: Runs Agent in Read-only sandbox +4. **Write Operations**: Executes repository modifications in job isolated from agent +5. 
**Cleanup**: Removes temporary labels and artifacts + +**Triggers:** +- Issue comments starting with `/strands` +- Manual workflow dispatch with parameters + +## Agent SOPs + +### Task Implementer (`task-implementer.sop.md`) + +Implements features using test-driven development principles. + +**Workflow**: Setup → Explore → Plan → Code → Commit → Pull Request + +**Capabilities:** +- Feature implementation with TDD approach +- Comprehensive testing and documentation +- Pull request creation and iteration +- Code pattern following and best practices + +### Task Refiner (`task-refiner.sop.md`) + +Refines and clarifies task requirements before implementation. + +**Workflow**: Read Issue → Analyze → Research → Clarify → Iterate + +**Capabilities:** +- Requirement analysis and gap identification +- Clarifying question generation +- Implementation planning and preparation +- Ambiguity resolution through user interaction + +## IAM Role Policy + +### Required IAM Role + +Create an IAM role with the following trust policy for GitHub OIDC: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": { + "Federated": "arn:aws:iam::YOUR_ACCOUNT_ID:oidc-provider/token.actions.githubusercontent.com" + }, + "Action": "sts:AssumeRoleWithWebIdentity", + "Condition": { + "StringEquals": { + "token.actions.githubusercontent.com:aud": "sts.amazonaws.com" + }, + "StringLike": { + "token.actions.githubusercontent.com:sub": "repo:YOUR_ORG/YOUR_REPO:*" + } + } + } + ] +} +``` + +### IAM Role Policy + +Your IAM role must have these permissions in order to execute: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "Bedrock Access", + "Effect": "Allow", + "Action": [ + "bedrock:InvokeModelWithResponseStream", + "bedrock:InvokeModel" + ], + "Resource": "*" + }, + { + "Effect": "Allow", + "Action": [ + "s3:PutObject", + "s3:GetObject", + "s3:DeleteObject" + ], + "Resource": [ + "arn:aws:s3:::YOUR_STRANDS_SESSION_BUCKET/*" + ] + }, + { + 
"Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": [ + "arn:aws:s3:::YOUR_STRANDS_SESSION_BUCKET" + ] + } + ] +} +``` + +### Setup Steps + +1. **Create OIDC Provider** (if not exists): + ```bash + aws iam create-open-id-connect-provider \ + --url https://token.actions.githubusercontent.com \ + --thumbprint-list 6938fd4d98bab03faadb97b34396831e3780aea1 \ + --client-id-list sts.amazonaws.com + ``` + +2. **Create IAM Role** with the trust policy above +3. **Create S3 Bucket** for session storage +4. **Add GitHub Secrets**: + - `AWS_ROLE_ARN`: The created role ARN + - `STRANDS_SESSION_BUCKET`: The S3 bucket name + +## Security + +### ⚠️ Important Security Considerations + +**This workflow should only be used with trusted sources and should use AWS guardrails to help avoid prompt injection risks.** + +### Security Features + +#### Authorization Controls +- **Collaborator Verification**: Only users with write access get auto-approval +- **Manual Approval Gates**: Unknown users require manual approval via GitHub environments +- **Permission Separation**: Read and write operations isolated in separate jobs + +#### AWS Security +- **OIDC Authentication**: No long-lived credentials stored in GitHub +- **Minimal Permissions**: Inline session policy limits access to required resources only +- **Temporary Credentials**: Each execution gets fresh, time-limited AWS credentials. 
You can further limit these by updating the `strands-agent-runner` "Configure AWS credentials" step, and set the `role-duration-seconds` value +- **Resource Scoping**: S3 access limited to specific session bucket + +#### Prompt Injection Mitigation +- **Trusted Sources Only**: Implement strict user authorization +- **AWS Guardrails**: Use AWS Bedrock guardrails to filter malicious prompts +- **Input Validation**: Validate and sanitize all user inputs +- **Execution Isolation**: Separate read and write phases prevent unauthorized modifications + +## Configuration + +### GitHub Secrets + +| Secret | Description | Example | +|--------|-------------|---------| +| `AWS_ROLE_ARN` | IAM role for AWS access | `arn:aws:iam::123456789012:role/GitHubActionsRole` | +| `STRANDS_SESSION_BUCKET` | S3 bucket for sessions | `my-strands-sessions-bucket` | + +### Environment Variables + +The actions use these environment variables during execution: + +| Variable | Purpose | Set By | +|----------|---------|--------| +| `GITHUB_WRITE` | Permission level indicator | Action | +| `SESSION_ID` | Agent session identifier | Workflow | +| `S3_SESSION_BUCKET` | Session storage location | Input | +| `STRANDS_TOOL_CONSOLE_MODE` | Tool execution mode | Action | +| `BYPASS_TOOL_CONSENT` | Automated tool approval | Action | + +## Usage Examples + +### Basic Task Implementation + +Comment on an issue: +``` +/strands Implement a new user authentication feature with JWT tokens +``` + +### Task Refinement + +Comment on an issue with unclear requirements: +``` +/strands refine Please help clarify the requirements for this feature +``` + +### Manual Execution + +Use workflow dispatch with: +- **issue_id**: `123` +- **command**: `Implement the requested feature` +- **session_id**: `optional-session-id` + +### Advanced Usage + +``` +/strands implement Create a REST API endpoint for user management with the following requirements: +1. CRUD operations for users +2. JWT authentication +3. Input validation +4. 
Unit tests with 90% coverage +5. OpenAPI documentation +``` + +--- + +**Note**: This system is designed for trusted environments. Always review security implications before deployment and implement appropriate guardrails for your use case. diff --git a/.github/actions/strands-agent-runner/action.yml b/.github/actions/strands-agent-runner/action.yml new file mode 100644 index 000000000..6d4c2d7fb --- /dev/null +++ b/.github/actions/strands-agent-runner/action.yml @@ -0,0 +1,179 @@ +name: 'Strands Agent Runner' +description: 'Execute a Strands agent with the given prompts and configuration' +inputs: + ref: + description: 'ref to checkout' + required: true + system_prompt: + description: 'System prompt for the agent' + required: true + session_id: + description: 'Session ID for the agent execution' + required: true + task_prompt: + description: 'Task prompt for the agent' + required: true + aws_role_arn: + description: 'AWS IAM role ARN for authentication' + required: true + sessions_bucket: + description: 'S3 bucket for session storage' + required: true + write_permission: + description: 'If this action runs with write permission. If this is false, you should run the `strands-write-executor` action after this one with write permission.' 
+ required: true + default: 'false' + +runs: + using: 'composite' + steps: + # Checkout main repo .github directory + - name: Checkout repository + uses: actions/checkout@v5 + with: + sparse-checkout: | + .github + + # Copy the .github directory to the runner temp directory so the branch content cant overwrite the scripts executed here + - name: Copy .github to safe directory + shell: bash + run: | + mkdir -p ${{ runner.temp }}/strands-agent-runner + cp -r .github ${{ runner.temp }}/strands-agent-runner + + # Checkout the branch repo to stage the directory for the agent + - name: Checkout repository + uses: actions/checkout@v5 + with: + ref: ${{ inputs.ref }} + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: '20' + + - name: Install dependencies + # If we have package.json then install the dependencies - this is for compatibility in multiple repos + if: hashFiles('package.json') != '' + shell: bash + run: npm install + continue-on-error: true # This step's failure will not stop the workflow + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.13' + + - name: Install uv + uses: astral-sh/setup-uv@v3 + with: + enable-cache: true + cache-dependency-glob: '${{ runner.temp }}/strands-agent-runner/.github/scripts/python/requirements.txt' + + - name: Install Strands Agents + shell: bash + run: | + echo "📦 Installing from requirements.txt" + uv pip install --system -r ${{ runner.temp }}/strands-agent-runner/.github/scripts/python/requirements.txt --quiet + + - name: Configure Git + shell: bash + run: | + git config --global user.name "Strands Agent" + git config --global user.email "217235299+strands-agent@users.noreply.github.com" + git config --global core.pager cat + PAGER=cat + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ inputs.aws_role_arn }} + role-session-name: GitHubActions-StrandsAgent-${{ github.run_id }} + aws-region: us-west-2 + 
mask-aws-account-id: true + # Session policy must be strict JSON (no trailing commas, alphanumeric Sid) or STS rejects the AssumeRole call; the input name is hyphenated + inline-session-policy: >- + { + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "BedrockAccess", + "Effect": "Allow", + "Action": [ + "bedrock:InvokeModelWithResponseStream", + "bedrock:InvokeModel" + ], + "Resource": "*" + }, { + "Effect": "Allow", + "Action": [ + "s3:PutObject", + "s3:GetObject", + "s3:DeleteObject" + ], + "Resource": [ + "arn:aws:s3:::${{ inputs.sessions_bucket }}/*" + ] + }, { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": [ + "arn:aws:s3:::${{ inputs.sessions_bucket }}" + ] + } + ] + } + + + - name: Execute strands command + shell: bash + env: + # Write Permission + GITHUB_WRITE: ${{ inputs.write_permission }} + + # GitHub Configuration + GITHUB_TOKEN: ${{ github.token }} + GITHUB_REPOSITORY: ${{ github.repository }} + + # Task Configuration + INPUT_TASK: ${{ inputs.task_prompt }} + INPUT_SYSTEM_PROMPT: ${{ inputs.system_prompt }} + + # AWS Configuration + AWS_REGION: 'us-west-2' + + # Session Manager + S3_SESSION_BUCKET: ${{ inputs.sessions_bucket }} + SESSION_ID: ${{ inputs.session_id }} + + # Strands Env Vars + STRANDS_TOOL_CONSOLE_MODE: 'enabled' + BYPASS_TOOL_CONSENT: 'true' + run: | + uv run --no-project ${{ runner.temp }}/strands-agent-runner/.github/scripts/python/agent_runner.py "$INPUT_TASK" + + - name: Capture repository state + shell: bash + run: | + mkdir -p .artifact + if git diff --quiet HEAD@{upstream} && git diff --quiet --cached; then + echo "📭 No changes to capture" + else + echo "📝 Capturing entire repository state" + tar -czf .artifact/repository_state.tar.gz --exclude='.artifact' .
+ fi + + - name: Upload repository state artifact + uses: actions/upload-artifact@v4 + with: + name: repository-state + path: .artifact/repository_state.tar.gz + retention-days: 1 + if-no-files-found: ignore + + - name: Upload artifact for write operations + uses: actions/upload-artifact@v4 + with: + name: write-operations + path: .artifact/write_operations.jsonl + retention-days: 1 + if-no-files-found: ignore \ No newline at end of file diff --git a/.github/actions/strands-write-executor/action.yml b/.github/actions/strands-write-executor/action.yml new file mode 100644 index 000000000..3417c3140 --- /dev/null +++ b/.github/actions/strands-write-executor/action.yml @@ -0,0 +1,147 @@ +name: 'Strands Write Executor' +description: 'Execute write GitHub operations from JSONL artifact files during workflow execution' +inputs: + ref: + description: 'Ref to push changes to' + required: true + issue_id: + description: 'Issue ID for fallback operations' + required: false + +runs: + using: 'composite' + steps: + + # Push code changes before running write commands in case we need to create a pull request + # Pull requests cannot be created if a branch has no diff with main, so push changes first, then create pr + - name: Log if ref equals main + shell: bash + run: | + if [ "${{ inputs.ref }}" = "${{ github.event.repository.default_branch }}" ]; then + echo "🚫 Ref is default - skipping push operations to prevent direct commits to default branch" + else + echo "✅ Ref is '${{ inputs.ref }}' - push operations will proceed" + fi + + - name: Download repository state artifact + if: inputs.ref != github.event.repository.default_branch + uses: actions/download-artifact@v4 + with: + name: repository-state + path: ${{ runner.temp }} + continue-on-error: true + + - name: Apply Artifact and Push changes + if: inputs.ref != github.event.repository.default_branch + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + + if [ -f "$RUNNER_TEMP/repository_state.tar.gz" ]; then + 
echo "📝 Applying repository state" + mkdir -p "$RUNNER_TEMP/temp_git_repo" + tar -xzf "$RUNNER_TEMP/repository_state.tar.gz" -C "$RUNNER_TEMP/temp_git_repo" + rm "$RUNNER_TEMP/repository_state.tar.gz" + + echo "📁 Changing to repository directory" + ORIGINAL_DIRECTORY=$(pwd) + cd "$RUNNER_TEMP/temp_git_repo" + + # Configure Git + git config --local user.name "Strands Agent" + git config --local user.email "217235299+strands-agent@users.noreply.github.com" + git config --local core.pager cat + # We need to overwrite this since this is currently set by the previous readonly workflow artifact + # Overwrite this value with the current token that allows us to push the commit + git config --local http."https://github.com/".extraheader "AUTHORIZATION: basic $(echo -n x-access-token:${{ github.token }}| base64)" + + # Fetch the remote repository + git fetch origin ${{ inputs.ref }} + + # Stage and commit any changes first + if [ -n "$(git status --porcelain)" ]; then + echo "📝 Changes detected, staging all files" + git add -A + echo "📝 Committing changes" + git commit -m "Additional changes from write operations" -n + fi + + # Push if there are differences from remote + if ! git diff --quiet HEAD origin/${{ inputs.ref }}; then + echo "📝 Differences from remote:" + git diff HEAD origin/${{ inputs.ref }} + echo "📤 Pushing changes to ${{ inputs.ref }}" + git push --force origin ${{ inputs.ref }} + else + echo "📭 No changes to push" + fi + + # Change back and clean up + cd $ORIGINAL_DIRECTORY + rm -rf "$RUNNER_TEMP/temp_git_repo" + fi + + - name: Download artifact with write operations + uses: actions/download-artifact@v4 + with: + name: write-operations + continue-on-error: true + + - name: Check if write operations artifact exists + id: check-write-ops + shell: bash + run: | + if [ -f "write_operations.jsonl" ]; then + echo "✅ Write operations artifact exists! Continuing to execute commands!" 
+ cp -r write_operations.jsonl ${{ runner.temp }} + echo "exists=true" >> $GITHUB_OUTPUT + else + echo "❌ Write operations artifact does not exist. Stopping execution." + echo "exists=false" >> $GITHUB_OUTPUT + fi + + - name: Checkout repo to temp dir + if: steps.check-write-ops.outputs.exists == 'true' + uses: actions/checkout@v5 + with: + sparse-checkout: | + .github + + - name: Set up Python + if: steps.check-write-ops.outputs.exists == 'true' + uses: actions/setup-python@v4 + with: + python-version: '3.13' + + - name: Install uv + if: steps.check-write-ops.outputs.exists == 'true' + uses: astral-sh/setup-uv@v3 + with: + enable-cache: true + cache-dependency-glob: ./.github/scripts/python/requirements.txt + + - name: Install dependencies + if: steps.check-write-ops.outputs.exists == 'true' + shell: bash + run: | + echo "📦 Installing from requirements.txt" + uv pip install --system -r ./.github/scripts/python/requirements.txt --quiet + + - name: Execute write operations + if: steps.check-write-ops.outputs.exists == 'true' + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + GITHUB_REPOSITORY: ${{ github.repository }} + + # Strands Env Vars + STRANDS_TOOL_CONSOLE_MODE: 'enabled' + BYPASS_TOOL_CONSENT: 'true' + run: | + echo "🚀 Strands Write Executor - Processing write operations" + if [ -n "${{ inputs.issue_id }}" ]; then + python ./.github/scripts/python/write_executor.py "${{ runner.temp }}/write_operations.jsonl" --issue-id "${{ inputs.issue_id }}" + else + python ./.github/scripts/python/write_executor.py "${{ runner.temp }}/write_operations.jsonl" + fi diff --git a/.github/agent-sops/task-implementer.sop.md b/.github/agent-sops/task-implementer.sop.md new file mode 100644 index 000000000..cc7aa3330 --- /dev/null +++ b/.github/agent-sops/task-implementer.sop.md @@ -0,0 +1,493 @@ +# Task Implementer SOP + +## Role + +You are a Task Implementer, and your goal is to implement a task defined in a github issue. 
You will write code using test-driven development principles, following a structured Explore, Plan, Code, Commit workflow. During your implementation, you will write code that follows existing patterns, create comprehensive documentation, generate test cases, create a pull request for review, and iterate on the provided feedback until the pull request is accepted. + +## Steps + +### 1. Setup Task Environment + +Initialize the task environment and discover repository instruction files. + +**Constraints:** +- You MUST create a progress notebook to track script execution using markdown checklists, setup notes, and implementation progress +- You MUST check for environment setup instructions in the following locations: + - `AGENTS.md` + - `DEVELOPMENT.md` + - `CONTRIBUTING.md` + - `README.md` +- You MAY explore more files in the repository if you did not find instructions +- You MUST check the `GITHUB_WRITE` environment variable value to determine if you have github write permission + - If the value is `true`, then you can run git write commands like `add_comment` or run `git push` + - If the value is not `true`, you are running in a read-restricted sandbox. Any write commands you do run will be deferred to run outside the sandbox + - Any staged or unstaged changes will be pushed after you finish executing to the feature branch +- You MUST make a note of environment setup and testing instructions +- You MUST make note of the task number from the issue title +- You MUST make note of the issue number +- You MUST run unit tests to ensure the repository and environment are functional +- You MAY run integration tests if your feature requires new tests to be added +- You MUST comment on the github issue if the tests fail, and use the handoff_to_user tool to get feedback on how to continue.
+- You MUST check the current branch using `git branch --show-current` +- You MUST create a new feature branch if currently on main branch: + - You MUST use `git checkout -b {BRANCH_NAME}` to create and switch to a new feature branch + - You SHOULD use the BRANCH_NAME pattern `agent-tasks/{ISSUE_NUMBER}` unless this branch already exists + - You MUST make note of the newly created branch name + - You MUST use `git push origin {BRANCH_NAME}` to create the feature branch in remote + - If the push operation is deferred, continue with the workflow and note the deferred status +- You MAY continue on the current branch if not on main branch + + +### 2. Explore Phase + +#### 2.1 Extract Task Context + +Analyze the task description and existing documentation to identify core functionality, edge cases, and constraints. + +**Constraints:** +- You MUST read the issue description +- You MUST investigate any links provided in the feature request + - You MUST note how the information from this link can influence the implementation +- You MUST review any implementation documentation provided by the repository: + - `AGENTS.md` + - `DEVELOPMENT.md` + - `CONTRIBUTING.md` + - `README.md` +- You MAY read existing comments, but focus mostly on the description +- You MUST capture issue metadata (title, labels, status, etc.) + +#### 2.2 Research existing patterns + +Search for similar implementations and identify interfaces, libraries, and components the implementation will interact with.
+ +**Constraints:** +- You MUST analyze the task and identify core functionality, edge cases, and constraints +- You MUST search the repository for relevant code, patterns, and information related to the coding task and note your findings +- You MUST create a dependency map showing how new code will integrate +- You MUST record the identified implementation paths in your notebook +- You SHOULD make note of any ambiguity you have in implementing the task + +#### 2.3 Create Code Context Document + +Compile all findings into a comprehensive code context notebook. + +**Constraints:** +- You MUST update your notebook with requirements, implementation details, patterns, and dependencies +- You MUST ensure your notes are well-structured with clear headings +- You MUST focus on high-level concepts and patterns rather than detailed implementation code +- You MUST NOT include complete code implementations in your notes because documentation should guide implementation, not provide it +- You MUST keep your notes concise and focused on guiding implementation rather than providing the implementation itself +- You SHOULD include a summary section and highlight areas of uncertainty +- You SHOULD use pseudocode or simplified representations when illustrating concepts +- You MAY include targeted code snippets when: + - Demonstrating usage of a specific library or API that's critical to the implementation + - Illustrating a complex pattern or technique that's difficult to describe in words alone + - Showing examples from existing codebase that demonstrate relevant patterns + - Providing reference implementations from official documentation +- You MUST clearly label any included code snippets as examples or references, not as the actual implementation +- You MUST keep any included code snippets brief and focused on the specific concept being illustrated + + +### 3. 
Plan Phase + +#### 3.1 Design Test Strategy + +Create a comprehensive list of test scenarios covering normal operation, edge cases, and error conditions. + +**Constraints:** +- You MUST check for existing testing strategies documented in the repository documentation or your notes +- You MUST cover all acceptance criteria with at least one test scenario +- You MUST define explicit input/output pairs for each test case +- You MUST make note of these test scenarios +- You MUST design tests that will initially fail when run against non-existent implementations +- You MUST NOT create mock implementations during the test design phase because tests should be written based solely on expected behavior, not influenced by implementation details +- You MUST focus on test scenarios and expected behaviors rather than detailed test code in documentation +- You MUST use high-level descriptions of test cases rather than complete test code snippets +- You MAY include targeted test code snippets when: + - Demonstrating a specific testing technique or pattern that's critical to understand + - Illustrating how to use a particular testing framework or library + - Showing examples of similar tests from the existing codebase +- You MUST clearly label any included test code snippets as examples or references +- You SHOULD explain the reasoning behind the proposed test structure + + +#### 3.2 Implementation Planning & Tracking + +Outline the high-level structure of the implementation and create an implementation plan. 
+ +**Constraints:** +- You MUST create an implementation plan notebook +- You MUST include all key implementation tasks in the plan +- You SHOULD consider performance, security, and maintainability implications +- You MUST keep implementation planning notes concise and focused on architecture and patterns +- You MUST NOT include detailed code implementations in planning notes because planning should focus on architecture and approach, not specific code +- You MUST use high-level descriptions, UML diagrams, or simplified pseudocode rather than actual implementation code +- You MAY include targeted code snippets when: + - Illustrating a specific design pattern or architectural approach + - Demonstrating API usage that's central to the implementation + - Showing relevant examples from existing codebase or reference implementations + - Clarifying complex interactions between components +- You MUST clearly label any included code snippets as examples or references, not as the actual implementation +- You SHOULD make note of the reasoning behind the proposed implementation structure +- You MUST display the current checklist status after each major implementation step +- You MUST verify all checklist items are complete before finalizing the implementation +- You MUST maintain the implementation checklist in your progress notes using markdown checkbox format + +### 4. Code Phase + +#### 4.1 Implement Test Cases + +Write test cases based on the outlines, following strict TDD principles. 
+ +**Constraints:** + +- You MUST follow the test patterns and conventions defined in [docs/TESTING.md](../../docs/TESTING.md) +- You MUST validate that the task environment is set up properly + - If you already created a commit, ensure the latest commit matches the expected hash + - If not, ensure the correct branch is checked out + - As a last resort, you MUST commit your current work to the current branch, then leave a comment on the Task issue or Pull Request for feedback on how to proceed +- You MUST save test implementations to the appropriate test directories in repo_root +- You MUST implement tests for ALL requirements before writing ANY implementation code +- You MUST follow the testing framework conventions used in the existing codebase + - You MUST follow test directory structure patterns + - You MUST follow test file format patterns: + - Follow class vs method test case creating patterns + - Follow mocking patterns + - Reuse existing test helper functions + - You MUST follow test creation rules if they are documented +- You MUST update the plan notes with test implementation details +- You MUST update the implementation checklist to mark test development as complete +- You MUST keep test notes concise and focused on test strategy rather than detailed test code +- You MUST execute tests after writing them to verify they fail as expected +- You MUST document the failure reasons in the TDD notes +- You MUST only seek user input if: + - Tests fail for unexpected reasons that you cannot resolve + - There are structural issues with the test framework + - You encounter environment issues that prevent test execution +- You MAY seek user input by commenting on the issue, and informing the user you are ready for their instruction by using the handoff_to_user tool +- You MUST otherwise continue automatically after verifying expected failures +- You MUST follow the Build Output Management practices defined in the Best Practices section + +#### 4.2 Develop 
Implementation Code + +Write implementation code to pass the tests, focusing on simplicity and correctness first. + +**Constraints:** +- You MUST update your progress in your implementation plan notes +- You MUST follow the strict TDD cycle: RED → GREEN → REFACTOR +- You MUST document each TDD cycle in your progress notes +- You MUST implement only what is needed to make the current test(s) pass +- You MUST follow the coding style and conventions of the existing codebase +- You MUST keep code comments concise and focused on key decisions rather than code details +- You MUST follow YAGNI, KISS, and SOLID principles +- You MAY make note of key implementation decisions including: + - Demonstrating usage of a specific library or API that's critical to the implementation + - Illustrating a complex pattern or technique that's difficult to describe in words alone + - Showing examples from existing codebase that demonstrate relevant patterns + - Explaining a particularly complex algorithm or data structure + - Providing reference implementations from official documentation +- You MUST make note of the reasoning behind implementation choices +- You SHOULD make note of any security considerations in the implementation +- You MUST execute tests after each implementation step to verify they now pass +- You MUST only seek user input if: + - Tests continue to fail after implementation for reasons you cannot resolve + - You encounter a design decision that cannot be inferred from requirements + - Multiple valid implementation approaches exist with significant trade-offs +- You MUST commit your work before seeking user feedback + - You MUST push your work if the `GITHUB_WRITE` environment variable is set to `true` +- You MAY seek user input by commenting on the issue, and informing the user you are ready for their instruction by using the handoff_to_user tool +- You MUST otherwise continue automatically after verifying test results +- You MUST follow the Build Output Management
practices defined in the Best Practices section + +#### 4.3 Review and Refactor Implementation + +If the implementation is complete, proceed with a self-review of the implementation code to identify opportunities for simplification or improvement. + +**Constraints:** + +- You MUST check that all tasks are complete before proceeding + - if tests fail, you MUST identify the issue and implement a fix + - if builds fail, you MUST identify the issue and implement a fix +- You MUST prioritize readability and maintainability over clever optimizations +- You MUST maintain test passing status throughout refactoring +- You SHOULD make note of simplification in your progress notes +- You SHOULD record significant refactorings in your progress notes +- You MUST return to step 4.2 if refactoring reveals additional implementation needs + +#### 4.4 Review and Refactor Tests + +After reviewing the implementation, review the test code to ensure it follows established patterns and provides adequate coverage. + +**Constraints:** + +- You MUST review your test code according to the guidelines in [docs/TESTING.md](../../docs/TESTING.md). +- You MUST verify tests conform to the testing documentation standards +- You MUST verify tests are readable and maintainable +- You SHOULD refactor tests that are overly complex or duplicative +- You MUST return to step 4.1 if tests need significant restructuring + +**Testing Checklist Verification (REQUIRED):** + +You MUST copy the checklist from [docs/TESTING.md](../../docs/TESTING.md) into your progress notes and explicitly verify each item. For each checklist item, you MUST: + +1. Copy the checklist item verbatim +2. Mark it as `[x]` (pass) or `[-]` (fail) +3.
If failed, provide a brief explanation and fix the issue before proceeding + +Example format in your notes: + +```markdown +## Testing Checklist Verification + +- [x] Do the tests use relevant helpers from `__fixtures__` as noted in the "Test Fixtures Reference" section +- [-] Are tests asserting on the entire object instead of specific fields? → FAILED: test on line 45 asserts individual properties, refactoring now +``` + +You MUST NOT proceed to step 4.5 until ALL checklist items pass. + +#### 4.5 Validate Implementation + +If the implementation meets all requirements and follows established patterns, proceed with this step. Otherwise, return to step 4.2 to fix any issues. + +**Constraints:** +- You MUST address any discrepancies between requirements and implementation +- You MUST execute the relevant test command and verify all implemented tests pass successfully +- You MUST execute the relevant build command and verify builds succeed +- You MUST ensure code coverage meets the requirements for the repository +- You MUST verify all items in the implementation plan have been completed +- You MUST provide the complete test execution output +- You MUST NOT claim implementation is complete if any tests are failing because failing tests indicate the implementation doesn't meet requirements + +**Build Validation:** +- You MUST run appropriate build commands based on the guidance in the repository +- You MUST verify that all dependencies are satisfied +- You MUST follow the Build Output Management practices defined in the Best Practices section + +#### 4.6 Respond to Review Feedback + +If you have received feedback from user reviews or PR comments, address them before proceeding to the commit phase.
+ +**Constraints:** + +- You MAY skip this step if no user feedback has been received yet +- You MUST reply to user review threads with a concise response + - You MUST keep your response to fewer than 3 sentences +- You MUST categorize each piece of feedback as: + - Actionable code changes that can be implemented immediately + - Clarifying questions that require user input + - Suggestions to consider for future iterations +- You MUST implement actionable code changes before proceeding +- You MUST re-run tests after addressing feedback to ensure nothing is broken +- You MUST return to step 4.3 after implementing changes to review the updated code +- You MUST use the handoff_to_user tool if clarification is needed before you can proceed + +### 5. Commit and Pull Request Phase + +If all tests are passing, draft a conventional commit message, perform the git commit, and create/update the pull request. + +**PR Checklist Verification (REQUIRED):** + +Before creating or updating a PR, you MUST copy the checklist from [docs/PR.md](../../docs/PR.md) into your progress notes and explicitly verify each item. For each checklist item, you MUST: + +1. Copy the checklist item verbatim +2. Mark it as `[x]` (pass) or `[-]` (fail) +3. If failed, revise the PR description until the item passes + +Example format in your notes: + +```markdown +## PR Description Checklist Verification + +- [x] Does the PR description target a Senior Engineer familiar with the project? +- [-] Does the PR include a "Resolves #" in the body? → FAILED: missing issue reference, adding now +``` + +You MUST NOT create or update the PR until ALL checklist items pass.
+ +**Constraints:** + +- You MUST read and follow the PR description guidelines in [docs/PR.md](../../docs/PR.md) when creating pull requests & commits +- You MUST check that all tasks are complete before proceeding +- You MUST reference your notes for the issue you are creating a pull request for +- You MUST NOT commit changes until builds AND tests have been verified because committing broken code can disrupt the development workflow and introduce bugs into the codebase +- You MUST follow the Conventional Commits specification +- You MUST use `git status` to check which files have been modified +- You MUST use `git add` to stage all relevant files +- You MUST execute the `git commit -m {COMMIT_MESSAGE}` command with the prepared commit message +- You MAY use `git push origin {BRANCH_NAME}` to push the local branch to the remote if the `GITHUB_WRITE` environment variable is set to `true` + - If the push operation is deferred, continue with PR creation and note the deferred status +- You MUST attempt to create the pull request using the `create_pull_request` tool if it does not exist yet + - If the PR creation is deferred, continue with the workflow and note the deferred status + - You MUST use the task id recorded in your notes, not the issue id +- If the `create_pull_request` tool fails (excluding deferred responses): + - The tool automatically handles fallback by posting a properly URL-encoded manual PR creation link as a comment on the specified fallback issue + - You MUST verify the fallback comment was posted successfully by checking the tool's return message + - You MUST NOT manually construct PR creation URLs since the tool handles URL encoding automatically +- If PR creation succeeds or is deferred: + - You MUST review your notes for any updates to provide on the pull request + - You MAY use the `update_pull_request` tool to update the pull request body or title + - If the update operation is deferred, continue with the workflow and note the deferred status +- You MUST use your
notebook to record the new commit hash and PR status (created or link provided) + +### 6. Feedback Phase + +#### 6.1 Report Ready for Review + +Ask the user for feedback on the implementation using the handoff_to_user tool. + +**Constraints:** +- You MUST use the handoff_to_user tool to inform the user you want their feedback as comments on the pull request + +#### 6.2 Read User Responses + +Retrieve and analyze the user's responses from the pull request reviews and comments. + +**Constraints:** +- You MUST make note of the pull request number +- You MUST fetch the review and the review comments from the PR using available tools + - You MUST use list_pr_reviews to list all PR reviews + - You MUST use get_pr_review_comments to list the comments from the review + - You MUST use get_issue_comments to list the comments on the pull request + - You MAY filter the comments to only view the newly updated comments +- You MUST analyze each comment to determine if the request is clear and actionable +- You MUST categorize comments as: + - Clear actionable requests that can be implemented + - Unclear requests that need clarification + - General feedback that doesn't require code changes +- You MUST reply to unclear comments asking for specific clarification + - If comment posting is deferred, continue with the workflow and note the deferred status +- You MUST record your progress and update the implementation plan based on the feedback +- You MUST return to step 6.1 if you need further clarification + +#### 6.3 Review Implementation Plan + +Based on the user's feedback, you will review and update your implementation plan. + +**Constraints:** +- You MUST make note of the requested changes from the user +- You MUST update your implementation plan based on the feedback from the user +- You MUST return to step 3 if you need to re-plan your implementation +- You MUST return to step 4 if you only need to make minor fixes +- You MUST NOT close the parent issue - only the
user should close it after the pull request is merged +- You MUST NOT attempt to merge the pull request +- You MUST use the handoff_to_user tool to inform the user you are ready for clarifying information on the pull request +- You MUST include additional checklist items from [docs/PR.md](../../docs/PR.md) to validate the pull request description is correct after making additional changes + +## Desired Outcome + +* A complete, well-tested code implementation that meets the specified requirements +* A comprehensive test suite that validates the implementation +* Clean, documented code that: + * Follows existing package patterns and conventions + * Prioritizes readability and extensibility + * Avoids over-engineering and over-abstraction + * Is idiomatic and modern in the implementation language +* A well-organized set of implementation artifacts in the pull request description or comments +* Documentation or comments of key design decisions and implementation notes +* Properly committed changes with conventional commit messages + +## Examples + +## Troubleshooting + +### Branch Creation Issues +If feature branch creation fails: +- Move any changes in the `.github` directory to the `.github_temp` directory +- Check for an existing branch with the same name +- Generate an alternative branch name with a timestamp +- Ensure the git repository is properly initialized +- As a last resort, leave a comment on the Task Issue mentioning the issue you are facing + +### Pull Request Creation Issues +If PR creation fails (excluding deferred responses): +- Verify GitHub authentication and permissions +- Check if remote repository exists and is accessible +- You MUST commit your current work to the branch +- As a last resort, leave a comment on the Task Issue mentioning the issue you are facing + +### Deferred Operations +When GitHub tools or git operations are deferred: +- Continue with the workflow as if the operation succeeded +- Note the deferred status in your progress tracking +- The operations will be 
executed after agent completion +- Do not retry or attempt alternative approaches for deferred operations + +### Build Issues +If builds fail during implementation: +- You SHOULD follow build instructions from DEVELOPMENT.md if available +- You SHOULD verify you're in the correct directory for the build system +- You SHOULD try clean builds before rebuilding when encountering issues +- You SHOULD check for missing dependencies and resolve them +- You SHOULD restart build caches if connection issues occur + +## Best Practices + +### Repository-Specific Instructions +- Always check for DEVELOPMENT.md, AGENTS.md, and README.md in the current repository and follow any instructions provided +- If these don't exist, suggest creating them +- Always follow build commands, testing frameworks, and coding standards as specified + +### Project Structure Detection +- Detect project type by examining files (pyproject.toml, build.gradle, package.json, etc.) +- Check for DEVELOPMENT.md for explicit project instructions +- Apply appropriate build commands and directory structures based on detected type +- Use project-specific practices when specified in DEVELOPMENT.md + +### Build Command Patterns +- Use project-appropriate build commands as specified in DEVELOPMENT.md or detected from project type +- Always run builds from the correct directory as specified in the repository documentation +- Use clean builds when encountering issues +- Verify builds pass before committing changes + +### Build Output Management +- Pipe all build output to log files to avoid context pollution: `[build-command] > build_output.log 2>&1` +- Use targeted search patterns to verify build results instead of displaying full output +- Search for specific success/failure indicators based on build system +- Only display relevant excerpts from build logs when issues are detected +- You MUST NOT include build logs in your commit or pull request + +### Dependency Management +- Handle dependencies appropriately 
based on project type and DEVELOPMENT.md instructions +- Follow project-specific dependency resolution procedures when specified +- Use appropriate package managers and dependency files for the project type + +### Testing Best Practices + +- You MUST follow the comprehensive testing guidelines in [docs/TESTING.md](../../docs/TESTING.md) +- Follow TDD principles: RED → GREEN → REFACTOR +- Write tests that fail initially, then implement to make them pass +- Use appropriate testing frameworks for the project type or as specified in DEVELOPMENT.md +- Ensure test coverage meets the repository requirements +- Run tests after each implementation step + +### Documentation Organization +- Use consolidated documentation files: context.md, plan.md, progress.md +- Keep documentation separate from implementation code +- Focus on high-level concepts rather than detailed code in documentation +- Use progress tracking with markdown checklists +- Document decisions, assumptions, and challenges + +### Checklist Verification Pattern + +When documentation files contain checklists (e.g., `docs/TESTING.md`, `docs/PR.md`), you MUST: + +1. Copy the entire checklist into your progress notes +2. Explicitly verify each item by marking `[x]` or `[ ]` +3. For any failed items, document the issue and fix it before proceeding +4. Re-verify failed items after fixes until all pass + +This pattern ensures quality gates are not skipped and provides an audit trail of verification. 
+ +### Pull Request Best Practices + +- You MUST follow the PR description guidelines in [docs/PR.md](../../docs/PR.md) +- Focus on WHY the change is needed, not HOW it's implemented +- Document public API changes with before/after code examples +- Write for senior engineers familiar with the project +- Skip implementation details, test coverage notes, and line-by-line change lists + +### Git Best Practices +- Commit early and often with descriptive messages +- Follow Conventional Commits specification +- You must create a new commit for each feedback iteration +- You must only push to your feature branch, never main diff --git a/.github/agent-sops/task-refiner.sop.md b/.github/agent-sops/task-refiner.sop.md new file mode 100644 index 000000000..a07c7887e --- /dev/null +++ b/.github/agent-sops/task-refiner.sop.md @@ -0,0 +1,298 @@ +# Task Refine SOP + +## Role + +You are a Task Refiner, and your goal is to review the feature request for a task and prepare it for implementation. This task feature request is defined as a github issue. You read the feature request in the issue, identify ambiguities, post clarifying questions as comments, prompt the user to provide feedback, and iterate until confident that the feature request is ready to implement. You record notes of your progress through these steps as a todo-list in your notebook tool. + +## Steps + +### 1. Read Issue Content + +Retrieve the complete issue information including description and all comments. + +**Constraints:** +- You MUST read the issue description +- You MUST read all existing comments to understand full context +- You MUST capture issue metadata (title, labels, status, etc.) + +### 2. Explore Phase +#### 2.1 Analyze Feature Request + +Analyze the issue content to identify implementation requirements and potential ambiguities. 
+ +**Constraints:** +- You MUST check for existing documentation in: + - `AGENTS.md` + - `CONTRIBUTING.md` + - `README.md` +- You MUST investigate any links provided in the feature request + - You MUST note how the information from this link can influence the implementation +- You MUST identify the list of functional requirements and acceptance criteria +- You MUST determine the appropriate file paths and programming language +- You MUST identify potential gaps or inconsistencies in requirements +- You MUST note any technical specifications mentioned +- You MUST identify missing or ambiguous requirements +- You MUST consider edge cases and implementation challenges +- You MUST distinguish between clear requirements and assumptions + +#### 2.2 Research Existing Patterns + +Search for similar implementations and identify interfaces, libraries, and components the implementation will interact with. + +**Constraints:** +- You MUST identify the main programming languages and frameworks used +- You MUST search the current repository for relevant code, patterns, and information related to the task +- You MUST locate relevant existing code that relates to the feature request +- You MUST understand the current architecture and design patterns +- You MUST note any existing similar features or related functionality +- You MUST create a dependency map in your notes showing how the new feature will integrate +- You MUST note the identified implementation paths +- You SHOULD understand the build system and deployment process + +#### 2.3 Review Investigation + +After performing the investigation of the feature request and understanding the repository, you will think about the work needed to implement this feature. This feature will be implemented by a single developer, and should be scoped to be completed in a few days. 
You should note any concerns that this task is too large in scope. + +**Constraints:** +- You MUST identify the work required to implement this feature +- You MUST review the current state of the repository, and identify any potential issues that might occur during implementation +- You MUST determine if this task is small enough to be implemented in a single Pull Request + - You SHOULD consider whether a single developer can implement this feature in about a week +- You MUST consider test implementation complexities as part of this feature request +- You MUST note if any github workflows are needed, or any changes to existing workflows are needed +- You MUST note any concerns in your notebook + +### 3. Clarification Phase + +#### 3.1. Evaluate Completeness + +Determine if you should ask clarifying questions, or if the task is already in an implementable state given your research. + +**Constraints:** +- You MAY skip to step 4 if you do not have any clarifying questions +- You SHOULD continue to the next step if you have identified questions to ask + +#### 3.2 Generate Clarifying Questions + +Create a numbered list of questions to resolve ambiguities and gather missing information. Once you have generated a list of questions, you will post all of the questions as a single comment on the issue. 
+ +**Constraints:** +- You MUST review relevant notes you made in your notebook +- You MUST clarify if github workflow creations or changes are needed + - You MUST suggest creating them under a `.github_temp` directory since you do not have permission to push to the `.github` directory +- You MAY ask about any ambiguous functionality +- You MAY clarify technical implementation details +- You MAY ask about user experience expectations +- You MAY ask for user input on edge cases that might not be obvious from the requirements +- You MAY ask clarifying questions regarding information from provided links +- You MAY ask about non-functional requirements that might not be explicitly stated +- You SHOULD group related questions logically +- You MAY include questions about integration with existing systems +- You MAY ask the user if the issue should be broken down into smaller issues + - You SHOULD provide justification for why it should be broken down + - You SHOULD suggest how the issue should be broken down into smaller feature requests +- You SHOULD ask about performance and scalability requirements +- You MUST create a comment with all of your questions on the issue. + - If the comment posting is deferred, continue with the workflow and note the deferred status +- You MUST wrap the comment body in a `
<details>` element so it is collapsed by default + - Use a brief, descriptive summary (e.g., "Repository Analysis & Clarifying Questions") + - Place all detailed content inside the `<details>
` block + +#### 3.3 Handoff to User for Response + +Use the handoff_to_user tool to inform the user they can reply to the clarifying questions on the issue. + +**Constraints:** +- You MUST use the handoff_to_user tool after posting your questions +- You MUST ask your clarifying questions when handing off to user +- You MUST tell the user to reply to your questions on the issue + +#### 3.4. Read User Responses + +Retrieve and analyze the user's responses from the issue comments. + +**Constraints:** +- You MUST read all new comments since the last check +- You MUST identify which comments contain responses to your questions +- You MUST extract answers and map them to the original questions +- You MUST handle cases where responses are incomplete or unclear +- You SHOULD take notes on how the repository can be updated (e.g. update AGENTS.md, CONTRIBUTING.md, README.md, etc) to clarify ambiguity in the future + +#### 3.5 (Optional) Break Down Task + +Determine from the user's responses if the task should be broken down into sub-tasks. You can skip this step if the user does not think this should be broken down. + +**Constraints:** +- You MUST note any clarifying questions that are needed when breaking down this issue into smaller tasks +- You MUST create a notebook for each new sub-issue you plan to create +- You MUST identify any dependencies that are required for the new sub-tasks +- You MUST determine the order of implementation for these new sub-tasks +- You MUST determine a name for each new task +- You MUST number the new sub-tasks based on their parent task number. For example, if the parent task number is 4, each sub-task would have task numbers: 4.1, 4.2, 4.3, ... 
+ +#### 3.6 Re-Evaluate Completeness + +Determine if the responses provide sufficient information for implementation + +**Constraints:** +- You MUST assess if all critical questions have been answered +- You MUST identify any remaining ambiguities +- You MUST determine if additional clarification is needed +- You MUST be thorough in your assessment before proceeding +- You SHOULD consider the repository context in your evaluation +- You MUST make note of your decision +- You MAY continue to the next step if you have no more clarifying questions +- You SHOULD make note of your decision to continue +- You MAY return to step 2 if you need to do more research based on the answers the user provided +- You MAY return to step 3.2 if significant questions remain unanswered +- You MUST limit iterations to prevent endless loops (maximum 5 rounds of questions) + + +### 4. Update Task +#### 4.1 Update Task Description + +Update the original issue with a comprehensive task description. + +**Constraints:** +- You MUST edit the original issue description directly + - If the edit operation is deferred, continue with the workflow and note the deferred status +- You MUST preserve the original request context +- You MUST add a clear "Implementation Requirements" section +- You MUST include all clarified specifications +- You MUST document any assumptions made +- You MUST mention any ways to improve clarification in the repository going forward +- You SHOULD include acceptance criteria +- You MUST remove any github workflow requirements if they must be created under the `.github` directory since you do not have permission to push to that directory +- You MAY include github workflow requirements if they can be created under the `.github_temp` directory +- You MUST maintain professional formatting and clarity +- You SHOULD include implementation approach based on repository analysis +- You MAY include sub-tasks as requirements to the parent task description if there are any sub-tasks + 
+#### 4.2 (Optional) Create Sub-Issues + +Create new sub-tasks if you and the user have determined that this task is too complex. + +**Constraints:** +- You MUST create a new issue for each sub-task + - If issue creation is deferred, continue with the workflow and note the deferred status +- You MUST create a description with a comprehensive overview of the work required, following the same description format as the parent task +- You MUST add the sub-tasks as sub-issues to the parent task's issue using the `add_sub_issue` tool. + - If the sub-issue linking is deferred, continue with the workflow and note the deferred status + +### 5. Record Completion as Comment + +Record that the task review is complete and ready as a comment on the issue. + +**Constraints:** +- You MUST only add a comment on the parent issue if any sub-issues were created + - If comment posting is deferred, continue with the workflow and note the deferred status +- You MUST summarize what was accomplished in your comment +- You MUST confirm in your comment that the issue is ready for implementation, or explain why it is not +- You SHOULD mention any final recommendations or considerations +- You MUST wrap the comment body in a `
<details>` element so it is collapsed by default + - Use a brief, descriptive summary (e.g., "Task Refinement Complete") + +## Examples + +### Example Repository Analysis Comment +```markdown +<details>
+<summary>Repository Analysis & Clarifying Questions</summary> + +I've analyzed the repository structure and have some questions to ensure proper implementation: + +### Repository Context +- **Framework**: React with TypeScript frontend, Node.js/Express backend +- **Authentication**: Currently using JWT tokens (found in `/src/auth/`) +- **Database**: PostgreSQL with Prisma ORM +- **Existing Features**: Basic user registration exists in `/src/components/auth/` + +### Clarifying Questions + +#### Integration with Existing Auth System +1. Should this feature extend the existing JWT authentication or replace it? +2. How should this integrate with the current user registration flow? + +#### Database Schema +3. Should we modify the existing `users` table or create new tables? +4. What user data fields are required for this feature? + +#### Frontend Components +5. Should we update existing auth components or create new ones? +6. What should the user interface look like for this feature? + +Please respond when you have a chance. Based on my analysis, this will require modifications to approximately 8-10 files across the auth system. + +</details>
+``` + +### Example Final Issue Description Update +```markdown +# Overview +Add user authentication system to allow users to log in and access protected features. + +## Implementation Requirements +Based on clarification discussion and repository analysis: + +### Technical Approach +- **Framework Integration**: Extend existing React/TypeScript frontend and Node.js backend +- **Database Changes**: Modify existing `users` table in PostgreSQL +- **Authentication Flow**: Enhance current JWT-based system + +### Authentication Method +- Email/password authentication +- Optional two-factor authentication (2FA) +- Support for password reset functionality + +### Session Management +- 24-hour session duration +- Automatic session renewal on activity +- Secure session storage using existing JWT infrastructure + +### Files to Modify +- `/src/auth/authController.js` - Add 2FA logic +- `/src/components/auth/LoginForm.tsx` - Update UI +- `/src/models/User.js` - Add 2FA fields +- `/prisma/schema.prisma` - Database schema updates +- `/src/middleware/auth.js` - Session management + +### Acceptance Criteria +- [ ] Users can register with email/password +- [ ] Users can log in and log out +- [ ] Sessions expire after 24 hours of inactivity +- [ ] Password reset functionality works +- [ ] 2FA can be enabled/disabled by user +- [ ] Integration tests pass +- [ ] Existing auth functionality remains intact +``` + +## Troubleshooting + +### Missing Issue: +If the issue does not exist: +1. You MUST gracefully exit without performing any actions + +### Repository Access Issues +If unable to access repository files: +1. Verify repository permissions and authentication +2. Check if the repository is private or has restricted access +3. Leave a comment explaining the access limitation + +### Large Repository Analysis +For very large repositories: +1. Focus on key directories related to the feature +2. Use search functionality to find relevant code patterns +3. 
Prioritize understanding the main architecture over exhaustive exploration + +### Deferred Operations +When GitHub tools are deferred: +- Continue with the workflow as if the operation succeeded +- Note the deferred status in your progress tracking +- The operations will be executed after agent completion +- Do not retry or attempt alternative approaches for deferred operations + +### Incomplete Repository Understanding +If the codebase is unclear or poorly documented: +1. Ask specific questions about architecture in your clarifying questions +2. Request documentation or guidance from the repository maintainers +3. Make reasonable assumptions and document them clearly diff --git a/.github/agent-sops/task-release-notes.sop.md b/.github/agent-sops/task-release-notes.sop.md new file mode 100644 index 000000000..5f024da82 --- /dev/null +++ b/.github/agent-sops/task-release-notes.sop.md @@ -0,0 +1,586 @@ +# Release Notes Generator SOP + +## Role + +You are a Release Notes Generator, and your goal is to create high-quality release notes highlighting Major Features and Major Bug Fixes for a software project. Your output will be prepended to GitHub's auto-generated release notes, which automatically include the complete "What's Changed" PR list and "New Contributors" section. + +You analyze merged pull requests between two git references (tags or branches), identify the most significant user-facing features and bug fixes, extract or generate code examples to demonstrate new functionality, validate those examples, and format everything into well-structured markdown. Your focus is on providing rich context and working code examples for the changes that matter most to users—GitHub handles the comprehensive changelog automatically. + +**Important**: You are executing in an ephemeral environment. Any files you create (test files, notes, etc.) will be discarded after execution. 
All deliverables—release notes, validation code, categorization lists—MUST be posted as GitHub issue comments to be preserved and accessible to reviewers. + +## Steps + +### 1. Setup and Input Processing + +#### 1.1 Accept Git References + +Parse the input to identify the two git references (tags or branches) to compare. + +**Constraints:** +- You MUST accept two git references as input (e.g., `v1.0.0` and `v1.1.0`, or `release/1.0` and `release/1.1`) +- You MUST validate that both references are provided +- You MUST track the base reference (older) and head reference (newer) for use throughout the workflow +- You SHOULD use semantic version tags when available (e.g., `v1.14.0`, `v1.15.0`) +- You MAY accept branch names if tags are not available + +#### 1.2 Check for Existing GitHub Release + +Check if a release (draft or non-draft) already exists with auto-generated PR information. + +**Constraints:** +- You MUST first check if a release exists for the target version using the GitHub API: `GET /repos/:owner/:repo/releases` +- You MUST check if the release body contains GitHub's auto-generated "What's Changed" section +- If a release with PR list exists: + - You MUST parse the PR list from the existing release body + - You MUST extract PR numbers, titles, authors, and links from the markdown + - You SHOULD skip Step 1.3 (Query GitHub API for PRs) since the PR list is already available +- If no release exists or it lacks PR information: + - You MUST proceed to Step 1.3 to query for PRs manually +- You SHOULD note in the categorization comment whether you used existing release data or queried manually + +#### 1.3 Query GitHub API for PRs (if needed) + +Retrieve merged pull requests between the two git references when no release exists. 
+ +**Constraints:** +- You SHOULD skip this step if PR information was obtained from an existing release in Step 1.2 +- You MUST query the GitHub API to get commits between the two references: `GET /repos/:owner/:repo/compare/:base...:head` +- You MUST extract the list of merged pull requests from the commit history +- You MUST retrieve the full list even if there are many PRs (handle pagination) +- You SHOULD track the total number of PRs found for reporting in the categorization comment +- You MAY need to filter for only merged PRs if the comparison includes unmerged commits + +#### 1.4 Retrieve PR Metadata + +For each PR identified (from release or API query), fetch additional metadata needed for categorization. + +**Constraints:** +- If PR information came from a release, you already have: + - PR number and title + - Author username + - Link to the PR +- You MUST retrieve additional metadata for PRs being considered for Major Features or Major Bug Fixes: + - PR description/body (essential for understanding the change) + - PR labels (if any) +- You SHOULD retrieve for Major Feature candidates: + - Files changed in the PR (to find code examples) +- You MAY retrieve: + - PR review comments if helpful for understanding the change +- You SHOULD minimize API calls by only fetching detailed metadata for PRs that appear significant based on title/prefix +- You MUST track this data for use in categorization and release notes generation + +### 2. PR Analysis and Categorization + +#### 2.1 Analyze PR Titles and Prefixes + +Extract categorization signals from PR titles using conventional commit prefixes. 
+ +**Constraints:** +- You MUST check each PR title for conventional commit prefixes: + - `feat:` or `feature:` - Feature additions + - `fix:` - Bug fixes + - `refactor:` - Code refactoring + - `docs:` - Documentation changes + - `test:` - Test additions/changes + - `chore:` - Maintenance tasks + - `ci:` - CI/CD changes + - `perf:` - Performance improvements +- You MUST use these prefixes as initial categorization signals +- You SHOULD record the prefix-based category for each PR +- You MAY encounter PRs without conventional commit prefixes + +#### 2.2 Analyze PR Descriptions + +Use LLM analysis to understand the significance and user impact of each change. + +**Constraints:** +- You MUST read and analyze the PR description for each PR +- You MUST assess the user-facing impact of the change: + - Does it introduce new functionality users will interact with? + - Does it fix a bug that users experienced? + - Is it purely internal with no user-visible changes? +- You MUST identify if the change introduces breaking changes +- You SHOULD identify if the PR includes code examples in its description +- You SHOULD note any links to documentation or related issues +- You MAY consider the size and complexity of the change + +#### 2.3 Categorize PRs + +Combine prefix analysis and LLM analysis to categorize each PR appropriately. 
+ +**Constraints:** +- You MUST categorize each PR into one of these categories: + - **Major Features**: Significant new functionality or enhancements that users should know about + - New APIs, methods, or classes + - New capabilities or workflows + - Significant feature enhancements + - User-facing changes with clear value + - **Major Bug Fixes**: Critical bug fixes that impact user experience + - Fixes for broken functionality + - Security fixes + - Data corruption fixes + - Performance issue resolutions + - **Minor Changes**: Everything else + - Internal refactoring without user-visible changes + - Documentation-only changes + - Test-only changes + - Minor fixes or typos + - Dependency updates without feature impact + - CI/CD changes + - Code style changes +- You MUST prioritize user impact over technical classification +- You MUST use BOTH prefix signals AND description analysis to make the final decision +- You SHOULD be conservative - when in doubt, classify as "Minor Changes" +- You SHOULD limit "Major Features" to approximately 3-8 items per release +- You SHOULD limit "Major Bug Fixes" to approximately 0-5 items per release +- You MUST record your categorization decisions (these will be posted as a GitHub comment in Step 2.4) + +#### 2.4 Confirm Categorization with User + +Present the categorized PRs to the user for review and confirmation. 
+ +**Constraints:** +- You MUST present the categorization to the user for review before proceeding +- You MUST format the categorization as a numbered list organized by category: + - **Major Features** (with PR numbers and titles) + - **Major Bug Fixes** (with PR numbers and titles) + - **Minor Changes** (with PR numbers and titles, or just count if >20) +- You MUST make it easy for the user to recategorize items by providing clear instructions +- You SHOULD present the list in a format that allows easy reordering (e.g., "To move PR#123 to Major Features, reply with: 'Move #123 to Major Features'") +- You MUST post this categorization as a comment on the GitHub issue +- You MUST use the handoff_to_user tool to request review +- You MUST wait for user confirmation or recategorization before proceeding +- You SHOULD update your categorization based on user feedback +- You MAY iterate on categorization if the user requests changes + +### 3. Code Snippet Extraction and Generation + +**Note**: This phase applies only to PRs categorized as "Major Features". Bug fixes typically do not require code examples. + +#### 3.1 Search for Existing Code Examples + +Search merged PRs for existing code that demonstrates the new feature. + +**Constraints:** +- You MUST search each Major Feature PR for existing code examples in: + - Test files (especially integration tests or example tests) + - Example applications or scripts in `examples/` directory + - Code snippets in the PR description + - Documentation updates that include code examples + - README updates with usage examples +- You MUST prioritize test files that show real usage of the feature +- You SHOULD look for the simplest, most focused examples +- You SHOULD prefer examples that are already validated (from test files) +- You MAY examine multiple PRs if a feature spans several PRs + +#### 3.2 Extract Code from PRs + +When suitable examples are found, extract them for use in release notes. 
+ +**Constraints:** +- You MUST extract the most relevant and focused code snippet +- You MUST simplify extracted code for release notes: + - Remove unnecessary imports + - Remove test scaffolding and setup code + - Remove assertions and test-specific code + - Keep only the core usage demonstration +- You MUST ensure extracted code is syntactically complete (balanced braces, valid syntax) +- You SHOULD keep examples under 20 lines when possible +- You SHOULD focus on the "happy path" usage +- You MAY need to extract from multiple locations and combine them + +#### 3.3 Generate New Snippets When Needed + +When existing examples are insufficient, generate new code snippets. + +**Constraints:** +- You MUST generate new snippets when: + - No suitable examples exist in the PR + - Existing code is too complex or specific + - Existing code doesn't clearly demonstrate the feature +- You MUST keep generated snippets minimal and focused +- You MUST use the appropriate programming language for the project +- You MUST ensure generated code follows the project's coding patterns +- You SHOULD base generated code on the actual API changes in the PR +- You SHOULD include only necessary imports +- You SHOULD demonstrate the most common use case +- You MAY include brief inline comments to clarify usage + +### 4. Code Validation + +**Note**: This phase is REQUIRED for all code snippets (extracted or generated) that will appear in Major Features sections. Validation must occur AFTER snippets have been extracted or generated in Step 3. + +#### 4.1 Create Temporary Test Files + +Create temporary test files to validate the code snippets. 
+ +**Constraints:** +- You MUST create a temporary test file for each code snippet +- You MUST place test files in an appropriate test directory based on the project structure +- You MUST include all necessary imports and setup code in the test file +- You MUST wrap the snippet in a proper test case +- You SHOULD use the project's testing framework +- You MAY need to mock dependencies or setup test fixtures +- You MAY include additional test code that doesn't appear in the release notes + +**Example test file structure** (language-specific format will vary): +``` +# Test structure depends on the project's testing framework +# Include necessary imports, setup, and the snippet being validated +# Add assertions to verify the code works correctly +``` + +#### 4.2 Run Validation Tests + +Execute tests to ensure code snippets are valid and functional. + +**Constraints:** +- You MUST run the appropriate test command for the project (e.g., `npm test`, `pytest`, `go test`) +- You MUST verify that the test passes successfully +- You MUST check that the code compiles without errors in compiled languages +- You SHOULD run type checking if applicable (e.g., `npm run type-check`, `mypy`) +- You MAY need to adjust imports or setup code if tests fail +- You MAY need to install additional dependencies if required + +**Fallback validation** (if test execution fails or is not possible): +- You MUST at minimum validate syntax using the appropriate language tools +- You MUST ensure the code is syntactically correct +- You MUST verify all referenced types and modules exist + +#### 4.3 Handle Validation Failures + +Address any validation failures before including snippets in release notes. 
+ +**Constraints:** +- You MUST NOT include unvalidated code snippets in release notes +- You MUST revise the code snippet if validation fails +- You MUST re-run validation after making changes +- You SHOULD examine the actual implementation in the PR if generated code fails +- You SHOULD simplify the example if complexity is causing validation issues +- You MAY extract a different example from the PR if the current one cannot be validated +- You MAY seek clarification if you cannot create a valid example +- You MUST preserve the test file content to include in the GitHub issue comment (Step 6.1) +- You MAY delete temporary test files after capturing their content, as the environment is ephemeral + +### 5. Release Notes Formatting + +#### 5.1 Format Major Features Section + +Create the Major Features section with concise descriptions and code examples. + +**Constraints:** +- You MUST create a section with heading: `## Major Features` +- You MUST create a subsection for each major feature using heading: `### Feature Name - [PR#123](link)` +- You MUST include the PR number and link in the feature heading +- You MUST write a concise description of 2-3 sentences that explains what the feature does and why it matters +- You MUST NOT use bullet points or lists in feature descriptions—use prose only +- You MUST NOT write lengthy multi-paragraph explanations +- You MUST include a code block demonstrating the feature using the project's programming language +- You MUST use proper syntax highlighting for the project's language +- You SHOULD keep code examples under 20 lines +- You SHOULD include inline comments in code examples only when necessary for clarity +- You MAY include multiple code examples if the feature has distinct use cases +- You MAY include a single closing sentence after the code example (e.g., documentation link or brief note) +- You MAY reference multiple PRs if a feature spans several PRs: `### Feature Name - [PR#123](link), [PR#124](link)` + +**Example 
format**: +```markdown +### Structured Output via Agentic Loop - [PR#943](https://github.com/org/repo/pull/943) + +Agents can now validate responses against predefined schemas with configurable retry behavior for non-conforming outputs. + +\`\`\`[language] +# Code example in the project's programming language +# Show the feature in action with clear, focused code +\`\`\` + +See the [Structured Output docs](https://docs.example.com/structured-output) for configuration options. +``` + +#### 5.2 Format Major Bug Fixes Section + +Create the Major Bug Fixes section highlighting critical fixes (if any exist). + +**Constraints:** +- You MUST create this section only if there are critical bug fixes +- You MUST create a section with heading: `## Major Bug Fixes` +- You MUST add a horizontal rule before this section: `---` +- You MUST format each bug fix as a bullet list item: `- **Fix Title** - [PR#123](link)` +- You MUST write a brief explanation (1-2 sentences) after each bullet that describes: + - What was broken + - What impact it had on users + - What is now fixed +- You SHOULD order fixes by severity or user impact +- You SHOULD keep descriptions concise but informative +- You MAY skip this section entirely if there are no major bug fixes + +**Example format**: +```markdown +--- + +## Major Bug Fixes + +- **Guardrails Redaction Fix** - [PR#1072](https://github.com/org/repo/pull/1072) + Fixed input/output message redaction when `guardrails_trace="enabled_full"`, ensuring sensitive data is properly protected in traces. + +- **Tool Result Block Redaction** - [PR#1080](https://github.com/org/repo/pull/1080) + Properly redact tool result blocks to prevent conversation corruption when using content filtering or PII redaction. +``` + +#### 5.3 End with Separator + +Add a horizontal rule to separate your content from GitHub's auto-generated sections. 
+ +**Constraints:** +- You MUST end your release notes with a horizontal rule: `---` +- This visually separates your curated content from GitHub's auto-generated "What's Changed" and "New Contributors" sections +- You MUST NOT include a "Full Changelog" link—GitHub adds this automatically + +**Example format**: +```markdown +## Major Bug Fixes + +- **Critical Fix** - [PR#124](https://github.com/owner/repo/pull/124) + Description of what was fixed. + +--- +``` + +### 6. Output Delivery + +**Critical**: You are running in an ephemeral environment. All files created during execution (test files, temporary notes, etc.) will be deleted when the workflow completes. You MUST post all deliverables as GitHub issue comments—this is the only way to preserve your work and make it accessible to reviewers. + +**Comment Structure**: Post exactly two comments on the GitHub issue: +1. **Validation Comment** (first): Contains all validation code for all features in one batched comment +2. **Release Notes Comment** (second): Contains the final formatted release notes + +This ordering allows reviewers to see the validation evidence before reviewing the release notes. + +#### 6.1 Post Validation Code Comment + +Batch all validation code into a single GitHub issue comment. + +**Constraints:** +- You MUST post ONE comment containing ALL validation code for ALL features +- You MUST NOT post separate comments for each feature's validation +- You MUST post this comment BEFORE the release notes comment +- You MUST include all test files created during validation (Step 4) in this single comment +- You MUST NOT reference local file paths—the ephemeral environment will be destroyed +- You MUST clearly label this comment as "Code Validation Tests" +- You MUST include a note explaining that this code was used to validate the snippets in the release notes +- You SHOULD use collapsible `
<details>` sections to organize validation code by feature: + ```markdown + ## Code Validation Tests + + The following test code was used to validate the code examples in the release notes. + + <details>
+ <summary>Validation: Feature Name 1</summary> + + \`\`\`typescript + [Full test file for feature 1] + \`\`\` + + </details>
+ + <details>
+ <summary>Validation: Feature Name 2</summary> + + \`\`\`typescript + [Full test file for feature 2] + \`\`\` + + </details>
+ ``` +- This allows reviewers to copy and run the validation code themselves + +#### 6.2 Post Release Notes Comment + +Post the formatted release notes as a single GitHub issue comment. + +**Constraints:** +- You MUST post ONE comment containing the complete release notes +- You MUST post this comment AFTER the validation comment +- You MUST use the `add_issue_comment` tool to post the comment +- You MUST include Major Features, Major Bug Fixes (if any), and a trailing separator (`---`) +- You MUST NOT expect users to access any local files—everything must be in the comment +- You SHOULD add a brief introductory line (e.g., "## Release Notes for v1.15.0") +- You MAY use markdown formatting in the comment +- If comment posting is deferred, continue with the workflow and note the deferred status + +## Examples + +### Example 1: Major Features Section with Code + +```markdown +## Major Features + +### Managed MCP Connections - [PR#895](https://github.com/org/repo/pull/895) + +MCP Connections via ToolProviders allow the Agent to manage connection lifecycles automatically, eliminating the need for manual context managers. This experimental interface simplifies MCP tool integration significantly. + +\`\`\`[language] +# Code example in the project's programming language +# Demonstrate the key feature usage +# Keep it focused and concise +\`\`\` + +See the [MCP docs](https://docs.example.com/mcp) for details. + +### Async Streaming for Multi-Agent Systems - [PR#961](https://github.com/org/repo/pull/961) + +Multi-agent systems now support async streaming, enabling real-time event streaming from agent teams as they collaborate. 
+ +\`\`\`[language] +# Another code example +# Show the feature in action +# Include only essential code +\`\`\` +``` + +### Example 2: Major Bug Fixes Section + +```markdown +--- + +## Major Bug Fixes + +- **Guardrails Redaction Fix** - [PR#1072](https://github.com/strands-agents/sdk-python/pull/1072) + Fixed input/output message redaction when `guardrails_trace="enabled_full"`, ensuring sensitive data is properly protected in traces. + +- **Tool Result Block Redaction** - [PR#1080](https://github.com/strands-agents/sdk-python/pull/1080) + Properly redact tool result blocks to prevent conversation corruption when using content filtering or PII redaction. + +- **Orphaned Tool Use Fix** - [PR#1123](https://github.com/strands-agents/sdk-python/pull/1123) + Fixed broken conversations caused by orphaned `toolUse` blocks, improving reliability when tools fail or are interrupted. +``` + +### Example 3: Complete Release Notes Structure + +```markdown +## Major Features + +### Feature Name - [PR#123](https://github.com/owner/repo/pull/123) + +Description of the feature and its impact. + +\`\`\`[language] +# Code example demonstrating the feature +\`\`\` + +--- + +## Major Bug Fixes + +- **Critical Fix** - [PR#124](https://github.com/owner/repo/pull/124) + Description of what was fixed and why it matters. + +--- +``` + +Note: The trailing `---` separates your content from GitHub's auto-generated "What's Changed" and "New Contributors" sections that follow. + +### Example 4: Issue Comment with Release Notes + +```markdown +Release notes for v1.15.0: + +## Major Features + +### Managed MCP Connections - [PR#895](https://github.com/strands-agents/sdk-typescript/pull/895) + +We've introduced MCP Connections via ToolProviders... + +[... rest of release notes ...] + +--- +``` + +When this content is added to the GitHub release, GitHub will automatically append the "What's Changed" and "New Contributors" sections below the separator. 
+ +## Troubleshooting + +### Missing or Invalid Git References + +If one or both git references are missing or invalid: +1. Verify the references exist in the repository using `git ls-remote --tags` or `git ls-remote --heads` +2. Check if the user provided branch names vs. tag names +3. Leave a comment on the issue explaining which reference is invalid +4. Use the handoff_to_user tool to request clarification + +### GitHub API Rate Limiting + +If you encounter GitHub API rate limit errors: +1. Check the rate limit status using the `X-RateLimit-Remaining` header +2. If rate limited, note the `X-RateLimit-Reset` timestamp +3. Consider reducing the number of API calls by batching requests +4. Leave a comment on the issue explaining the rate limit issue +5. Use the handoff_to_user tool to inform the user + +### Code Validation Failures + +If code validation fails for a snippet: +1. Review the test output to understand the failure reason +2. Check if the feature requires additional dependencies or setup +3. Examine the actual implementation in the PR to understand correct usage +4. Try simplifying the example to focus on core functionality +5. Consider using a different example from the PR +6. If unable to validate, note the issue in the release notes comment and skip the code example for that feature +7. Leave a comment on the issue noting which features couldn't include validated code examples + +### Large PR Sets (>100 PRs) + +If there are many PRs between the references: +1. Consider whether the git references are correct (e.g., not comparing main to an ancient tag) +2. Focus categorization efforts on the most significant changes +3. Be more selective about what qualifies as a "Major Feature" or "Major Bug Fix" + +### No PRs Found Between References + +If no PRs are found: +1. Verify that the base and head references are in the correct order (base should be older) +2. Check if the references are the same +3. 
Verify that there are actually commits between the references +4. Check if a release exists that might have the PR list +5. Leave a comment on the issue explaining the situation +6. Use the handoff_to_user tool to request clarification + +### Release Parsing Issues + +If the release body cannot be parsed correctly: +1. Check if the format matches GitHub's standard auto-generated format +2. Look for the "What's Changed" heading and bullet list format: `* PR title by @author in URL` +3. If parsing fails, fall back to querying the GitHub API directly (Step 1.3) +4. Note in the categorization comment that you fell back to API queries + +### Deferred Operations + +When GitHub tools or git operations are deferred (GITHUB_WRITE=false): +- Continue with the workflow as if the operation succeeded +- Note the deferred status in your progress tracking +- The operations will be executed after agent completion +- Do not retry or attempt alternative approaches for deferred operations + +### Unable to Extract Suitable Code Examples + +If no suitable code examples can be found or generated for a feature: +1. Examine the PR description more carefully for usage information +2. Look at related documentation changes +3. Consider whether the feature actually needs a code example (some features are self-explanatory) +4. Generate a minimal example based on the API changes, even if you can't fully validate it +5. Mark the example as "conceptual" if validation isn't possible +6. Consider omitting the code example if it would be misleading + +## Desired Outcome + +* Focused release notes highlighting Major Features and Major Bug Fixes with concise descriptions (2-3 sentences, no bullet points) +* Working, validated code examples for all major features +* Well-formatted markdown that renders properly on GitHub +* Release notes posted as a comment on the GitHub issue for review + +**Important**: Your generated release notes will be prepended to GitHub's auto-generated release notes. 
GitHub automatically generates: +- "What's Changed" section listing all PRs with authors and links +- "New Contributors" section acknowledging first-time contributors +- "Full Changelog" comparison link + +You should NOT include these sections—focus exclusively on Major Features and Major Bug Fixes that benefit from detailed descriptions and code examples. Minor changes (refactors, docs, tests, chores, etc.) will be covered by GitHub's automatic changelog. \ No newline at end of file diff --git a/.github/scripts/javascript/process-input.cjs b/.github/scripts/javascript/process-input.cjs new file mode 100644 index 000000000..b7ed29263 --- /dev/null +++ b/.github/scripts/javascript/process-input.cjs @@ -0,0 +1,125 @@ +// This file assumes that it's run from an environment that already has github and core imported: +// const github = require('@actions/github'); +// const core = require('@actions/core'); + +const fs = require('fs'); + +async function getIssueInfo(github, context, inputs) { + const issueId = context.eventName === 'workflow_dispatch' + ? inputs.issue_id + : context.payload.issue.number.toString(); + const command = context.eventName === 'workflow_dispatch' + ? 
inputs.command + : (context.payload.comment.body.match(/^\/strands\s*(.*?)$/m)?.[1]?.trim() || ''); + + console.log(`Event: ${context.eventName}, Issue ID: ${issueId}, Command: "${command}"`); + + const issue = await github.rest.issues.get({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: issueId + }); + + return { issueId, command, issue }; +} + +async function determineBranch(github, context, issueId, mode, isPullRequest) { + let branchName = 'main'; + + if (mode === 'implementer' && !isPullRequest) { + branchName = `agent-tasks/${issueId}`; + + const mainRef = await github.rest.git.getRef({ + owner: context.repo.owner, + repo: context.repo.repo, + ref: 'heads/main' + }); + + try { + await github.rest.git.createRef({ + owner: context.repo.owner, + repo: context.repo.repo, + ref: `refs/heads/${branchName}`, + sha: mainRef.data.object.sha + }); + console.log(`Created branch ${branchName}`); + } catch (error) { + if (error.status === 422 || error.message?.includes('already exists')) { + console.log(`Branch ${branchName} already exists`); + } else { + throw error; + } + } + } else if (isPullRequest) { + const pr = await github.rest.pulls.get({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: issueId + }); + branchName = pr.data.head.ref; + } + + return branchName; +} + +function buildPrompts(mode, issueId, isPullRequest, command, branchName, inputs) { + const sessionId = inputs.session_id || (mode === 'implementer' + ? `${mode}-${branchName}`.replace(/[\/\\]/g, '-') + : `${mode}-${issueId}`); + + const scriptFiles = { + 'implementer': '.github/agent-sops/task-implementer.sop.md', + 'refiner': '.github/agent-sops/task-refiner.sop.md', + 'release-notes': '.github/agent-sops/task-release-notes.sop.md' + }; + + const scriptFile = scriptFiles[mode] || scriptFiles['refiner']; + const systemPrompt = fs.readFileSync(scriptFile, 'utf8'); + + let prompt = (isPullRequest) + ? 
'The pull request id is:' + : 'The issue id is:'; + prompt += `${issueId}\n${command}\nreview and continue`; + + return { sessionId, systemPrompt, prompt }; +} + +module.exports = async (context, github, core, inputs) => { + try { + const { issueId, command, issue } = await getIssueInfo(github, context, inputs); + + const isPullRequest = !!issue.data.pull_request; + + // Determine mode based on explicit command first, then context + let mode; + if (command.startsWith('release-notes') || command.startsWith('release notes')) { + mode = 'release-notes'; + } else if (command.startsWith('implement')) { + mode = 'implementer'; + } else if (command.startsWith('refine')) { + mode = 'refiner'; + } else { + // Default behavior when no explicit command: PR -> implementer, Issue -> refiner + mode = isPullRequest ? 'implementer' : 'refiner'; + } + console.log(`Is PR: ${isPullRequest}, Command: "${command}", Mode: ${mode}`); + + const branchName = await determineBranch(github, context, issueId, mode, isPullRequest); + console.log(`Building prompts - mode: ${mode}, issue: ${issueId}, is PR: ${isPullRequest}`); + + const { sessionId, systemPrompt, prompt } = buildPrompts(mode, issueId, isPullRequest, command, branchName, inputs); + + console.log(`Session ID: ${sessionId}`); + console.log(`Task prompt: "${prompt}"`); + + core.setOutput('branch_name', branchName); + core.setOutput('session_id', sessionId); + core.setOutput('system_prompt', systemPrompt); + core.setOutput('prompt', prompt); + + } catch (error) { + const errorMsg = `Failed: ${error.message}`; + console.error(errorMsg); + core.setFailed(errorMsg); + } +}; diff --git a/.github/scripts/python/agent_runner.py b/.github/scripts/python/agent_runner.py new file mode 100644 index 000000000..db10ceadb --- /dev/null +++ b/.github/scripts/python/agent_runner.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +""" +Strands GitHub Agent Runner +A portable agent runner for use in GitHub Actions across different repositories. 
+""" + +import json +import os +import sys +from typing import Any + +from strands import Agent +from strands.agent.conversation_manager import SlidingWindowConversationManager +from strands.session import S3SessionManager +from strands.models.bedrock import BedrockModel +from botocore.config import Config + +from strands_tools import http_request, shell + +# Import local GitHub tools we need +from github_tools import ( + add_issue_comment, + create_issue, + create_pull_request, + get_issue, + get_issue_comments, + get_pull_request, + get_pr_review_and_comments, + list_issues, + list_pull_requests, + reply_to_review_comment, + update_issue, + update_pull_request, +) + +# Import local tools we need +from handoff_to_user import handoff_to_user +from notebook import notebook +from str_replace_based_edit_tool import str_replace_based_edit_tool + +# Strands configuration constants +STRANDS_MODEL_ID = "global.anthropic.claude-sonnet-4-5-20250929-v1:0" +STRANDS_MAX_TOKENS = 64000 +STRANDS_BUDGET_TOKENS = 8000 +STRANDS_REGION = "us-west-2" + +# Default values for environment variables used only in this file +DEFAULT_SYSTEM_PROMPT = "You are an autonomous GitHub agent powered by Strands Agents SDK." 
+ +def _get_all_tools() -> list[Any]: + return [ + # File editing + str_replace_based_edit_tool, + + # System tools + shell, + http_request, + + # GitHub issue tools + create_issue, + get_issue, + update_issue, + list_issues, + add_issue_comment, + get_issue_comments, + + # GitHub PR tools + create_pull_request, + get_pull_request, + update_pull_request, + list_pull_requests, + get_pr_review_and_comments, + reply_to_review_comment, + + # Agent tools + notebook, + handoff_to_user, + ] + + +def run_agent(query: str): + """Run the agent with the provided query.""" + try: + # Get tools and create model + tools = _get_all_tools() + + # Create Bedrock model with inlined configuration + additional_request_fields = {} + additional_request_fields["anthropic_beta"] = ["interleaved-thinking-2025-05-14"] + + additional_request_fields["thinking"] = { + "type": "enabled", + "budget_tokens": STRANDS_BUDGET_TOKENS + } + + model = BedrockModel( + model_id=STRANDS_MODEL_ID, + max_tokens=STRANDS_MAX_TOKENS, + region_name=STRANDS_REGION, + boto_client_config=Config( + read_timeout=900, + connect_timeout=900, + retries={"max_attempts": 3, "mode": "adaptive"}, + ), + additional_request_fields=additional_request_fields, + cache_prompt="default", + cache_tools="default", + ) + system_prompt = os.getenv("INPUT_SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT) + session_id = os.getenv("SESSION_ID") + s3_bucket = os.getenv("S3_SESSION_BUCKET") + s3_prefix = os.getenv("GITHUB_REPOSITORY", "") + + if s3_bucket and session_id: + print(f"🤖 Using session manager with session ID: {session_id}") + session_manager = S3SessionManager( + session_id=session_id, + bucket=s3_bucket, + prefix=s3_prefix, + ) + else: + raise ValueError("Both SESSION_ID and S3_SESSION_BUCKET must be set") + + # Create agent + agent = Agent( + model=model, + system_prompt=system_prompt, + tools=tools, + session_manager=session_manager, + ) + + print("Processing user query...") + result = agent(query) + + print(f"\n\nAgent Result 🤖\nStop 
Reason: {result.stop_reason}\nMessage: {json.dumps(result.message, indent=2)}") + except Exception as e: + error_msg = f"❌ Agent execution failed: {e}" + print(error_msg) + raise e + + +def main() -> None: + """Main entry point for the agent runner.""" + try: + # Read task from command line arguments + if len(sys.argv) < 2: + raise ValueError("Task argument is required") + + task = " ".join(sys.argv[1:]) + if not task.strip(): + raise ValueError("Task cannot be empty") + print(f"🤖 Running agent with task: {task}") + + run_agent(task) + + except Exception as e: + error_msg = f"Fatal error: {e}" + print(error_msg) + + sys.exit(1) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/.github/scripts/python/github_tools.py b/.github/scripts/python/github_tools.py new file mode 100644 index 000000000..8826b4611 --- /dev/null +++ b/.github/scripts/python/github_tools.py @@ -0,0 +1,843 @@ +"""GitHub repository management tool for Strands Agents. + +This module provides comprehensive GitHub repository operations including issues, +pull requests, comments, and repository management. Supports full GitHub API +integration with rich console output and error handling. + +Key Features: +1. List and manage issues and pull requests +2. Add comments to issues and PRs +3. Create, update, and manage issues +4. Create, update, and manage pull requests +5. Get detailed information for specific issues/PRs +6. Manage PR reviews and review comments +7. Get issue and PR comment threads +8. Check GitHub token permissions for repositories +9. Rich console output with formatted tables +10. 
Automatic fallback to GITHUB_REPOSITORY environment variable + +Usage Examples: +```python +from strands import Agent +from tools.github_tools import list_issues, add_issue_comment, create_issue, _check_token_permissions + +agent = Agent(tools=[list_issues, add_issue_comment, create_issue]) + +# Check token permissions +has_write = _check_token_permissions("ghp_token123", "owner/repo") + +# List open issues in repository +result = agent.tool.list_issues(state="open", repo="owner/repo") + +# Add comment to an issue +result = agent.tool.add_issue_comment( + issue_number=42, + comment_text="Great idea! I'll work on this.", + repo="owner/repo" +) + +# Create a new issue +result = agent.tool.create_issue( + title="Bug: Application crashes on startup", + body="Description of the issue with steps to reproduce...", + repo="owner/repo" +) + +# List pull requests +result = agent.tool.list_pull_requests(state="open", repo="owner/repo") + +# Get specific issue details +result = agent.tool.get_issue(issue_number=123, repo="owner/repo") + +# Update pull request +result = agent.tool.update_pull_request( + pr_number=456, + title="Updated PR title", + body="Updated description", + repo="owner/repo" +) +``` +""" + +import os +import traceback +from datetime import datetime +from functools import wraps +import json +from typing import Any, TypedDict +from urllib.parse import urlencode, quote + +import requests +from rich import box +from rich.markup import escape +from rich.panel import Panel +from rich.table import Table +from strands import tool +from strands_tools.utils import console_util + +console = console_util.create() + + +class GitHubOperation(TypedDict): + """Type definition for GitHub operation records in JSONL files.""" + timestamp: str + function: str + args: list[Any] + kwargs: dict[str, Any] + + +def log_inputs(func): + """Decorator to log function inputs in a blue panel.""" + @wraps(func) + def wrapper(*args, **kwargs): + # Get function name and format it nicely + func_name = 
func.__name__.replace('_', ' ').title() + + # Format parameters + params = [] + for k, v in kwargs.items(): + if isinstance(v, str) and len(v) > 50: + params.append(f"{k}='{v[:50]}...'") + else: + params.append(f"{k}='{v}'") + + console.print(Panel(", ".join(params), title=f"[bold blue]{func_name}", border_style="blue")) + return func(*args, **kwargs) + return wrapper + + +def _github_request( + method: str, endpoint: str, repo: str | None = None, data: dict | None = None, params: dict | None = None, should_raise: bool = False +) -> dict[str, Any] | str: + """Make a GitHub API request with common error handling. + + Args: + method: HTTP method (GET, POST, PATCH, etc.) + endpoint: API endpoint path (e.g., "pulls", "issues/123") + repo: Repository in "owner/repo" format + data: JSON data for request body + params: Query parameters for the request + + Returns: + Response JSON or error string + """ + if repo is None: + repo = os.environ.get("GITHUB_REPOSITORY") + if not repo: + return "Error: GITHUB_REPOSITORY environment variable not found" + + token = os.environ.get("GITHUB_TOKEN", "") + if not token: + return "Error: GITHUB_TOKEN environment variable not found" + + url = f"https://api.github.com/repos/{repo}/{endpoint}" + headers = { + "Authorization": f"Bearer {token}", + "Accept": "application/vnd.github.v3+json", + } + + try: + if method.upper() == "GET": + response = requests.get(url, headers=headers, params=params, timeout=30) + elif method.upper() == "POST": + response = requests.post(url, headers=headers, json=data, params=params, timeout=30) + else: + response = requests.request(method, url, headers=headers, json=data, params=params, timeout=30) + response.raise_for_status() + return response.json() # type: ignore[no-any-return] + except Exception as e: + if should_raise: + raise e + return f"Error {e!s}" + + +def check_should_call_write_api_or_record(func): + """Decorator that checks if a write api should be called, or if the tool should record to JSONL.""" 
+ @wraps(func) + def wrapper(*args, **kwargs): + try: + if not _should_call_write_api(): + # Record the tool request to JSONL file + record_entry: GitHubOperation = { + "timestamp": datetime.utcnow().isoformat() + "Z", + "function": func.__name__, + "args": args, + "kwargs": kwargs + } + + os.makedirs(".artifact", exist_ok=True) + with open(".artifact/write_operations.jsonl", "a") as f: + f.write(json.dumps(record_entry) + "\n") + + # Generate and return deferred message + params = dict(kwargs) + if args: + # Map positional args to parameter names from function signature + import inspect + sig = inspect.signature(func) + param_names = list(sig.parameters.keys()) + for i, arg in enumerate(args): + if i < len(param_names): + params[param_names[i]] = arg + + deferred_msg = _generate_deferred_message(func.__name__, params) + console.print(Panel(escape(deferred_msg), title="[bold yellow]Operation Deferred", border_style="yellow")) + return deferred_msg + except Exception as e: + error_msg = f"Error checking permissions: {e!s}" + console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red")) + return error_msg + + return func(*args, **kwargs) + return wrapper + + +def _generate_deferred_message(operation_name: str, params: dict[str, Any]) -> str: + """Generate a consistent deferred message for write operations. 
+ + Args: + operation_name: Name of the operation being deferred + params: Parameters that would have been used for the operation + + Returns: + Formatted deferred message string + """ + if not params: + return f"Operation deferred: {operation_name}" + + # Format parameters, truncating long values + param_strs = [] + for key, value in params.items(): + if isinstance(value, str) and len(value) > 50: + param_strs.append(f"{key}='{value[:50]}...'") + elif isinstance(value, str): + param_strs.append(f"{key}='{value}'") + else: + param_strs.append(f"{key}={value}") + + return f"Operation deferred: {operation_name} - {', '.join(param_strs)}" + + +def _should_call_write_api() -> bool: + """Checks if GITHUB_WRITE environment variable is set to true. + + Returns: + bool: True if GITHUB_WRITE is set to 'true', False otherwise + """ + return os.environ.get("GITHUB_WRITE", "").lower() == "true" + + +# ============================================================================= +# WRITE FUNCTIONS (Functions that modify GitHub resources) +# ============================================================================= + +@tool +@log_inputs +@check_should_call_write_api_or_record +def create_issue(title: str, body: str = "", repo: str | None = None) -> str: + """Creates a new issue in the specified repository. 
+ + Args: + title: The issue title + body: The issue body (optional) + repo: GitHub repository in the format "owner/repo" (optional; falls back to env var) + + Returns: + Result of the operation + """ + result = _github_request("POST", "issues", repo, {"title": title, "body": body}) + if isinstance(result, str): + console.print(Panel(escape(result), title="[bold red]Error", border_style="red")) + return result + + message = f"Issue created: #{result['number']} - {result['html_url']}" + console.print(Panel(escape(message), title="[bold green]Success", border_style="green")) + return message + + +@tool +@log_inputs +@check_should_call_write_api_or_record +def update_issue( + issue_number: int, + title: str | None = None, + body: str | None = None, + state: str | None = None, + repo: str | None = None, +) -> str: + """Updates an issue's title, body, or state. + + Args: + issue_number: The issue number + title: New title (optional) + body: New body (optional) + state: New state - "open" or "closed" (optional) + repo: GitHub repository in the format "owner/repo" (optional; falls back to env var) + + Returns: + Result of the operation + """ + data = {} + if title is not None: + data["title"] = title + if body is not None: + data["body"] = body + if state is not None: + data["state"] = state + + if not data: + error_msg = "Error: At least one field (title, body, or state) must be provided" + console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red")) + return error_msg + + result = _github_request("PATCH", f"issues/{issue_number}", repo, data) + if isinstance(result, str): + console.print(Panel(escape(result), title="[bold red]Error", border_style="red")) + return result + + message = f"Issue updated: #{result['number']} - {result['html_url']}" + console.print(Panel(escape(message), title="[bold green]Success", border_style="green")) + return message + + +@tool +@log_inputs +@check_should_call_write_api_or_record +def 
@tool
@log_inputs
@check_should_call_write_api_or_record
def create_pull_request(
    title: str,
    head: str,
    base: str,
    body: str = "",
    repo: str | None = None,
    fallback_issue_id: int | None = None,
) -> str:
    """Creates a new pull request, or optionally comments on the fallback_issue_id with a link to create one.

    If the API call raises and ``fallback_issue_id`` is given, a comment with a
    pre-filled GitHub "compare" link is posted on that issue so the pull
    request can be opened manually.

    Args:
        title: The PR title
        head: The branch containing changes
        base: The branch to merge into
        body: The PR body (optional)
        repo: GitHub repository in the format "owner/repo" (optional; falls back to env var)
        fallback_issue_id: Issue ID to comment on if PR creation fails with an error (optional)

    Returns:
        Result of the operation
    """
    try:
        result = _github_request(
            "POST",
            "pulls",
            repo,
            {"title": title, "head": head, "base": base, "body": body},
            should_raise=True,
        )

        # A string result is shown and returned as an error message.
        if isinstance(result, str):
            console.print(Panel(escape(result), title="[bold red]Error", border_style="red"))
            return result

        message = f"Pull request created: #{result['number']} - {result['html_url']}"
        console.print(Panel(escape(message), title="[bold green]Success", border_style="green"))
        return message

    except Exception as e:
        # Tool boundary: any failure is reported back as text rather than raised.
        if fallback_issue_id is None:
            error_msg = f"Error: {e!s}"
            console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red"))
            return error_msg

        agent_message = "Failed to create pull request, commenting on issue instead."
        console.print(Panel(escape(agent_message), title="[bold yellow]Fallback", border_style="yellow"))

        repo_name = repo or os.environ.get("GITHUB_REPOSITORY", "")
        query_params = urlencode({"quick_pull": "1", "title": title, "body": body}, quote_via=quote)
        # Branch names can contain characters such as '#', '?' or spaces that
        # would corrupt the URL (and the Markdown link) if inserted verbatim.
        # '/' and ':' are left alone so "feature/x" and "owner:branch" stay readable.
        compare_spec = f"{quote(base, safe='/:')}...{quote(head, safe='/:')}"
        pr_link = f"https://github.com/{repo_name}/compare/{compare_spec}?{query_params}"
        fallback_comment = (
            f"Unable to create pull request via API. You can create it manually by clicking [here]({pr_link})."
        )
        add_issue_comment(fallback_issue_id, fallback_comment, repo)
        return (
            "Unable to create pull request via API - posted a manual creation link "
            f"as a comment on issue #{fallback_issue_id}"
        )
@tool
@log_inputs
def get_issue(issue_number: int, repo: str | None = None) -> str:
    """Fetch one issue and return a short plain-text summary of it.

    Args:
        issue_number: The issue number
        repo: GitHub repository in the format "owner/repo" (optional; falls back to env var)

    Returns:
        Issue details
    """
    issue = _github_request("GET", f"issues/{issue_number}", repo)

    # A string response is shown and returned as an error message.
    if isinstance(issue, str):
        console.print(Panel(escape(issue), title="[bold red]Error", border_style="red"))
        return issue

    summary_lines = [
        f"#{issue['number']} - {issue['title']}",
        f"State: {issue['state']}",
        f"Author: {issue['user']['login']}",
        f"URL: {issue['html_url']}",
    ]
    # The body follows the header after one blank line.
    details = "\n".join(summary_lines) + f"\n\n{issue['body']}"

    panel = Panel(
        escape(details),
        title=f"[bold green]📋 Issue #{issue['number']}",
        border_style="blue",
    )
    console.print(panel)
    return details
@tool
@log_inputs
def get_pull_request(pr_number: int, repo: str | None = None) -> str:
    """Fetch one pull request and return a short plain-text summary of it.

    Args:
        pr_number: The pull request number
        repo: GitHub repository in the format "owner/repo" (optional; falls back to env var)

    Returns:
        Pull request details
    """
    pr = _github_request("GET", f"pulls/{pr_number}", repo)

    # A string response is shown and returned as an error message.
    if isinstance(pr, str):
        console.print(Panel(escape(pr), title="[bold red]Error", border_style="red"))
        return pr

    summary_lines = [
        f"#{pr['number']} - {pr['title']}",
        f"State: {pr['state']}",
        f"Author: {pr['user']['login']}",
        f"Head: {pr['head']['ref']} -> Base: {pr['base']['ref']}",
        f"URL: {pr['html_url']}",
    ]
    # The body follows the header after one blank line.
    details = "\n".join(summary_lines) + f"\n\n{pr['body']}"

    panel = Panel(
        escape(details),
        title=f"[bold green]🔀 PR #{pr['number']}",
        border_style="blue",
    )
    console.print(panel)
    return details
def _review_actor_login(actor: dict | None) -> str:
    """Return the login of a GraphQL actor, or "ghost" when the actor is null."""
    # NOTE(review): a null author is assumed to mean a deleted account - confirm against the API docs.
    return actor["login"] if actor else "ghost"


def _parse_github_timestamp(value: str) -> datetime:
    """Parse an ISO 8601 timestamp, accepting a trailing 'Z' as UTC."""
    return datetime.fromisoformat(value.replace("Z", "+00:00"))


def _format_diff_context(diff_hunk: str, start_line: int, target_line: int) -> list[str]:
    """Render the new-file lines of a diff hunk that fall within [start_line, target_line]."""
    rendered: list[str] = []
    current_new_line = 0
    for diff_line in diff_hunk.split("\n"):
        if diff_line.startswith("@@"):
            # Hunk header "@@ -a,b +c,d @@": the third token carries the new-file start line.
            parts = diff_line.split(" ")
            if len(parts) >= 3:
                current_new_line = int(parts[2].split(",")[0][1:]) - 1
        elif diff_line.startswith("-"):
            # Removed lines do not exist in the new file, so they do not advance the counter.
            continue
        elif diff_line.startswith(("+", " ")):
            current_new_line += 1
            if start_line <= current_new_line <= target_line:
                marker = "+" if diff_line.startswith("+") else ""
                rendered.append(f" {marker}{current_new_line}: {diff_line[1:]}\n")
    return rendered


@tool
@log_inputs
def get_pr_review_and_comments(
    pr_number: int,
    show_resolved: bool = False,
    repo: str | None = None,
    since: str | None = None,
) -> str:
    """Gets all review threads and comments for a PR.

    Review threads are fetched through the GitHub GraphQL API (first 100
    threads, first 100 comments each, first 100 general comments) and grouped
    by the review that the first comment of each thread belongs to.

    Args:
        pr_number: The pull request number
        show_resolved: Whether to include resolved review threads (default: False)
        repo: GitHub repository in the format "owner/repo" (optional; falls back to env var)
        since: ISO 8601 timestamp to filter comments/threads updated after this date (optional)

    Returns:
        Formatted review threads and comments, or an "Error: ..." message
    """
    if repo is None:
        repo = os.environ.get("GITHUB_REPOSITORY")
    if not repo:
        return "Error: GITHUB_REPOSITORY environment variable not found"

    token = os.environ.get("GITHUB_TOKEN", "")
    if not token:
        return "Error: GITHUB_TOKEN environment variable not found"

    # Validate the shape up front instead of letting tuple unpacking raise.
    owner, _, repo_name = repo.partition("/")
    if not owner or not repo_name or "/" in repo_name:
        return f"Error: repo must be in the format 'owner/repo', got '{repo}'"

    query = """
    query($owner: String!, $name: String!, $number: Int!) {
      repository(owner: $owner, name: $name) {
        pullRequest(number: $number) {
          reviewThreads(first: 100) {
            nodes {
              isResolved
              comments(first: 100) {
                nodes {
                  id
                  fullDatabaseId
                  author { login }
                  body
                  updatedAt
                  path
                  line
                  startLine
                  diffHunk
                  replyTo { id }
                  pullRequestReview {
                    id
                    body
                    author { login }
                    updatedAt
                  }
                }
              }
            }
          }
          comments(first: 100) {
            nodes {
              author { login }
              body
              updatedAt
            }
          }
        }
      }
    }
    """

    variables = {"owner": owner, "name": repo_name, "number": pr_number}

    try:
        response = requests.post(
            "https://api.github.com/graphql",
            headers={"Authorization": f"Bearer {token}"},
            json={"query": query, "variables": variables},
            timeout=30,
        )
        response.raise_for_status()
        data = response.json()

        if "errors" in data:
            return f"GraphQL Error: {data['errors']}"

        pr_data = data["data"]["repository"]["pullRequest"]
        threads = pr_data["reviewThreads"]["nodes"]
        general_comments = pr_data["comments"]["nodes"]

        if since:
            cutoff = _parse_github_timestamp(since)
            # A thread is kept whole as soon as any one of its comments is newer than the cutoff.
            threads = [
                thread
                for thread in threads
                if any(_parse_github_timestamp(c["updatedAt"]) > cutoff for c in thread["comments"]["nodes"])
            ]
            general_comments = [c for c in general_comments if _parse_github_timestamp(c["updatedAt"]) > cutoff]

        # Group threads by the review of their first comment, preserving encounter order.
        review_threads: dict[str, dict] = {}
        for thread in threads:
            if thread["isResolved"] and not show_resolved:
                continue
            thread_comments = thread["comments"]["nodes"]
            if not thread_comments:
                continue
            # "pullRequestReview" may be present but null, so `or {}` is needed on top of .get().
            review = thread_comments[0].get("pullRequestReview") or {}
            review_id = review.get("id", "N/A")
            group = review_threads.setdefault(review_id, {"review_data": review, "threads": []})
            group["threads"].append(thread)

        parts = [f"Review threads and comments for PR #{pr_number}:\n\n"]

        for review_id, review_info in review_threads.items():
            review_data = review_info["review_data"]
            parts.append(f"📝 Review [Review ID: {review_id}]\n")

            if review_data.get("author"):
                parts.append(
                    f" 👤 Review by {review_data['author']['login']} (updated: {review_data['updatedAt']})\n"
                )

            if review_data.get("body"):
                parts.append(" 📋 Review Comment:\n")
                parts.append(f" {review_data['body']}\n")
            parts.append("\n")

            for thread in review_info["threads"]:
                comments = thread["comments"]["nodes"]
                first_comment = comments[0]
                line_info = f":{first_comment['line']}" if first_comment.get("line") else " (Comment on file)"
                status = "✅ RESOLVED" if thread["isResolved"] else "🔄 OPEN"
                parts.append(f" 📍 Thread ({status}): {first_comment['path']}{line_info}\n")

                # Show the commented-on code right after the thread header.
                if first_comment.get("diffHunk") and first_comment.get("line"):
                    target_line = first_comment["line"]
                    start_line = first_comment.get("startLine") or target_line
                    parts.append(f" Code context (lines {start_line}-{target_line}):\n")
                    parts.extend(_format_diff_context(first_comment["diffHunk"], start_line, target_line))
                    parts.append("\n")

                # Root comments first, each followed by its direct replies.
                for root_comment in (c for c in comments if not c.get("replyTo")):
                    parts.append(
                        f" 💬 {_review_actor_login(root_comment.get('author'))} "
                        f"(updated: {root_comment['updatedAt']}) "
                        f"[Comment ID: {root_comment['fullDatabaseId']}]:\n"
                    )
                    parts.append(f" {root_comment['body']}\n")

                    for reply in comments:
                        reply_to = reply.get("replyTo")
                        if reply_to and reply_to.get("id") == root_comment["id"]:
                            parts.append(
                                f" ↳ {_review_actor_login(reply.get('author'))} (updated: {reply['updatedAt']}):\n"
                            )
                            parts.append(f" {reply['body']}\n")

                    parts.append("\n")
                parts.append("\n")

        # General (non-review) PR comments.
        for comment in general_comments:
            parts.append("💬 Comment\n")
            parts.append(
                f" 👤 Comment by {_review_actor_login(comment.get('author'))} (updated: {comment['updatedAt']})\n"
            )
            parts.append(" 📝 Comment:\n")
            parts.append(f" {comment['body']}\n\n")

        output = "".join(parts)
        console.print(Panel(escape(output), title=f"[bold green]PR #{pr_number} Review Data", border_style="blue"))
        return output

    except Exception as e:
        # Tool boundary: report the failure (with traceback) as text instead of raising.
        error_msg = f"Error: {e!s}\n\nStack trace:\n{traceback.format_exc()}"
        console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red"))
        return error_msg
+ + Args: + message: The message to give to the user + + Returns: + The users response after handing back control + """ + console = console_util.create() + + console.print( + Panel( + escape(message), + title="[bold yellow]🤝 Handoff to User", + border_style="yellow", + ) + ) + + request_state = { + "stop_event_loop": True + } + tool_context.invocation_state["request_state"] = request_state + + # Return an empty string as this will break out of the event loop + return "" \ No newline at end of file diff --git a/.github/scripts/python/notebook.py b/.github/scripts/python/notebook.py new file mode 100644 index 000000000..0b5ba2ace --- /dev/null +++ b/.github/scripts/python/notebook.py @@ -0,0 +1,337 @@ +"""Notebook management tool for Strands Agents. + +This module provides comprehensive notebook operations for managing text-based notebooks +within agent workflows. Enables persistent note-taking, documentation, and context +preservation across agent sessions. + +Key Features: +1. Create and manage multiple named notebooks +2. Write content using string replacement or line insertion +3. Read entire notebooks or specific line ranges +4. List all available notebooks with metadata +5. Clear notebook contents when needed +6. Rich console output with formatted panels and tables +7. Agent state persistence for session continuity + +Usage Examples: +```python +from strands import Agent +from tools.notebook import notebook + +agent = Agent(tools=[notebook]) + +# Create a new notebook with initial content +result = agent.tool.notebook( + mode="create", + name="research_notes", + new_str="# Research Notes\n\nKey findings and observations..." 
@tool(context=True)
def notebook(
    mode: Literal["create", "list", "read", "write", "clear"],
    name: str = "default",
    read_range: list[int] | None = None,
    old_str: str | None = None,
    new_str: str | None = None,
    insert_line: str | int | None = None,
    tool_context: ToolContext | None = None,
) -> str:
    """
    Notebook tool for managing text notebooks.

    This tool provides a comprehensive interface for creating, reading, writing, listing,
    and clearing text notebooks. Start writing notes in the default notebook, which is available
    from the start, or create new notebooks to record notes on additional topics or tasks.

    Command Details:
    --------------
    1. write:
       • Supports two types of write operations:
         - String replacement: Uses old_str and new_str parameters
         - Line insertion: Uses insert_line and new_str parameters

    2. read:
       • Reads contents of a notebook
       • Supports reading specific line numbers with read_range parameter

    3. create:
       • Creates a new notebook with the specified name
       • Optionally initializes with content using new_str parameter
       • Defaults to empty content if new_str not provided

    4. list:
       • Lists all available notebook names
       • Returns comma-separated list of notebook names

    5. clear:
       • Clears the contents of a notebook

    Args:
        mode: The operation to perform: `create`, `list`, `read`, `write`, `clear`.
        name: Name of the notebook to operate on. Defaults to "default".
        read_range: Optional parameter of `read` mode. Line range to show [start, end]. Supports negative indices.
        old_str: String to replace in write mode when doing text replacement.
        new_str: New string for replacement or insertion operations.
        insert_line: Line number (int) or search text (str) to insert after in write mode.
            Supports negative indices; 0 inserts before the first line.

    Returns:
        A message describing the operation performed, or the requested notebook content.

    Raises:
        ValueError: If the tool context is missing, the notebook or search text is not found,
            the line number is out of range, or no valid write operation is specified.

    Examples:
        1. Create a notebook:
           notebook(mode="create", name="notes")

        2. List all notebooks:
           notebook(mode="list")

        3. Read entire notebook:
           notebook(mode="read", name="notes")

        4. Read specific lines:
           notebook(mode="read", name="notes", read_range=[1, 5])

        5. Replace text:
           notebook(mode="write", name="notes", old_str="[] Update the calendar", new_str="[x] Update the calendar")

        6. Insert text after line 5:
           notebook(mode="write", name="notes", insert_line=5, new_str="inserted text")

        7. Insert text at end of notebook:
           notebook(mode="write", name="notes", insert_line=-1, new_str="Appended text")

        8. Insert text after finding a line:
           notebook(mode="write", name="notes", insert_line="def function", new_str="# comment")

        9. Clear notebook:
           notebook(mode="clear", name="notes")
    """
    console = console_util.create()
    if tool_context is None:
        raise ValueError("Tool context is required")
    agent = tool_context.agent

    # Seed the agent state with an empty "default" notebook on first use.
    if agent.state.get("notebooks") is None:
        agent.state.set("notebooks", {"default": ""})

    notebooks: dict[str, Any] = agent.state.get("notebooks")

    if mode == "create":
        notebooks[name] = new_str if new_str else ""
        message = f"Created notebook '{name}'" + (" with specified content" if new_str else " (empty)")
        # Parenthesised so the message is still shown when no initial content is given.
        panel_text = message + (f":\n{new_str}" if new_str else "")
        console.print(
            Panel(
                escape(panel_text),
                title="[bold green]Success",
                border_style="green",
            )
        )
        agent.state.set("notebooks", notebooks)
        return message

    elif mode == "list":
        table = Table(title="📚 Available Notebooks", box=box.DOUBLE)
        table.add_column("Name", style="cyan")
        table.add_column("Lines", style="yellow")
        table.add_column("Status", style="green")

        for nb_name, nb_content in notebooks.items():
            line_count = len(nb_content.split("\n")) if nb_content else 0
            status = "Empty" if line_count == 0 else "Has content"
            table.add_row(nb_name, str(line_count), status)

        console.print(table)
        return f"Notebooks: {', '.join(notebooks.keys())}"

    elif mode == "read":
        if name not in notebooks:
            error_msg = f"Notebook '{name}' not found"
            console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red"))
            raise ValueError(error_msg)

        content = notebooks[name]
        if read_range:
            lines = content.split("\n")
            start, end = read_range
            # Negative indices count back from the last line (-1 is the last line).
            if start < 0:
                start = len(lines) + start + 1
            if end < 0:
                end = len(lines) + end + 1

            selected_lines = [
                f"{line_num}: {lines[line_num - 1]}"
                for line_num in range(start, end + 1)
                if 1 <= line_num <= len(lines)
            ]

            result = "\n".join(selected_lines) if selected_lines else "No valid lines found"
            console.print(
                Panel(
                    escape(result),
                    title=f"[bold green]📖 {name} (lines {start}-{end})",
                    border_style="blue",
                )
            )
            return result

        result = content if content else f"Notebook '{name}' is empty"
        console.print(Panel(escape(result), title=f"[bold green]📖 {name}", border_style="blue"))
        return result

    elif mode == "write":
        if name not in notebooks:
            error_msg = f"Notebook '{name}' not found"
            console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red"))
            raise ValueError(error_msg)

        # String replacement
        if old_str is not None and new_str is not None:
            if old_str not in notebooks[name]:
                error_msg = f"String '{old_str}' not found in notebook '{name}'"
                console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red"))
                raise ValueError(error_msg)

            notebooks[name] = notebooks[name].replace(old_str, new_str)
            agent.state.set("notebooks", notebooks)

            # Create git-style diff; each line is escaped so its content is not parsed as rich markup.
            diff_lines = [f"[red]-{escape(line)}[/red]" for line in old_str.split("\n")]
            diff_lines.extend(f"[green]+{escape(line)}[/green]" for line in new_str.split("\n"))

            diff_content = "\n".join(diff_lines)
            console.print(Panel(diff_content, title="[bold yellow]📝 Diff", border_style="yellow"))

            message = f"Replaced text in notebook '{name}'"
            console.print(Panel(escape(message), title="[bold green]Success", border_style="green"))
            return message

        # Line insertion
        elif insert_line is not None and new_str is not None:
            # An empty notebook has no lines; "".split("\n") would yield one phantom
            # empty line and leave a leading blank line after the insert.
            lines = notebooks[name].split("\n") if notebooks[name] else []

            # Check if string represents a number first
            if isinstance(insert_line, str):
                try:
                    insert_line = int(insert_line)
                except ValueError:
                    pass  # Keep as string for text search

            if isinstance(insert_line, str):
                # Text search: insert after the first line containing the text.
                line_num = next((i for i, line in enumerate(lines) if insert_line in line), -1)
                if line_num == -1:
                    error_msg = f"Text '{insert_line}' not found in notebook '{name}'"
                    console.print(
                        Panel(
                            escape(error_msg),
                            title="[bold red]Error",
                            border_style="red",
                        )
                    )
                    raise ValueError(error_msg)
            elif insert_line < 0:
                # Negative indices count back from the last line (-1 inserts after the last line).
                line_num = len(lines) + insert_line
            else:
                line_num = insert_line - 1

            # line_num is the 0-based index of the line to insert after; -1 means "before the first line".
            if -1 <= line_num <= len(lines):
                lines.insert(line_num + 1, new_str)
                notebooks[name] = "\n".join(lines)
                agent.state.set("notebooks", notebooks)
                message = f"Inserted text at line {line_num + 2} in notebook '{name}'"
                console.print(
                    Panel(
                        escape(message),
                        title="[bold green]Success",
                        border_style="green",
                    )
                )
                console.print(
                    Panel(
                        escape(notebooks[name]),
                        title=f"[bold blue]📝 {name} Content",
                        border_style="blue",
                    )
                )
                return message
            else:
                error_msg = f"Line number {insert_line} out of range"
                console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red"))
                raise ValueError(error_msg)

        # No valid operation provided
        else:
            error_msg = "No valid write operation specified"
            console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red"))
            raise ValueError(error_msg)

    elif mode == "clear":
        if name not in notebooks:
            error_msg = f"Notebook '{name}' not found"
            console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red"))
            raise ValueError(error_msg)
        notebooks[name] = ""
        agent.state.set("notebooks", notebooks)
        message = f"Cleared notebook '{name}'"
        console.print(Panel(escape(message), title="[bold green]Success", border_style="green"))
        return message
a/.github/scripts/python/str_replace_based_edit_tool.py b/.github/scripts/python/str_replace_based_edit_tool.py new file mode 100644 index 000000000..69c92c206 --- /dev/null +++ b/.github/scripts/python/str_replace_based_edit_tool.py @@ -0,0 +1,230 @@ +"""Text editor tool for Strands Agents. + +A minimal implementation of Claude's text editor tool that supports: +- view: Read file contents or list directory contents +- str_replace: Replace text in files +- create: Create new files +- insert: Insert text at specific line numbers + +Based on Claude's text_editor_20250728 specification. +""" + +from pathlib import Path +from typing import List, Optional + +from rich.markup import escape +from rich.panel import Panel +from strands import tool +from strands_tools.utils import console_util + +console = console_util.create() + + +@tool +def str_replace_based_edit_tool( + command: str, + path: str, + old_str: str | None = None, + new_str: str | None = None, + file_text: str | None = None, + insert_line: str | None = None, + view_range: list[int] | None = None, +) -> str: + """Text editor tool for viewing and modifying files. 
def _handle_view(path: str, view_range: Optional[List[int]] = None) -> str:
    """Handle view command to read files or list directories.

    Args:
        path: Path to the file or directory to view.
        view_range: Optional [start_line, end_line] (1-based, inclusive) limiting
            which lines of a file are returned. An end_line of -1 means "to the
            end of the file". Ignored for directories.

    Returns:
        For a directory, its entries sorted by path, one per line, with a trailing
        "/" on sub-directories. For a file, its lines prefixed with "<line number>: "
        and with trailing whitespace stripped. On failure, a message starting
        with "Error:".
    """
    path_obj = Path(path)

    if not path_obj.exists():
        return f"Error: Path '{path}' does not exist"

    if path_obj.is_dir():
        try:
            return "\n".join(
                f"{item.name}/" if item.is_dir() else item.name
                for item in sorted(path_obj.iterdir())
            )
        except PermissionError:
            return f"Error: Permission denied accessing directory '{path}'"

    if not path_obj.is_file():
        return f"Error: '{path}' is not a regular file or directory"

    # Reject a malformed range with an error string, like every other failure
    # path here, instead of letting the tuple unpacking below raise.
    if view_range and (len(view_range) != 2 or not all(isinstance(v, int) for v in view_range)):
        return "Error: view_range must be a list of two integers [start_line, end_line]"

    try:
        with path_obj.open("r", encoding="utf-8") as f:
            lines = f.readlines()
    except UnicodeDecodeError:
        return f"Error: Cannot read '{path}' - file appears to be binary"
    except PermissionError:
        return f"Error: Permission denied reading file '{path}'"

    first_line_num = 1
    if view_range:
        start_line, end_line = view_range
        # Convert the 1-based start to a 0-based slice index, clamping at the top of the file.
        start_idx = max(0, start_line - 1)
        end_idx = len(lines) if end_line == -1 else min(len(lines), end_line)
        lines = lines[start_idx:end_idx]
        first_line_num = start_idx + 1

    return "\n".join(
        f"{line_num}: {line.rstrip()}"
        for line_num, line in enumerate(lines, start=first_line_num)
    )
> 1: + return f"Error: Text '{old_str}' appears {count} times in file. Please be more specific." + + # Replace text + new_content = content.replace(old_str, new_str) + + # Write back to file + with open(path_obj, 'w', encoding='utf-8') as f: + f.write(new_content) + + success_msg = f"Successfully replaced text in '{path}'" + console.print(Panel(escape(success_msg), title="[bold green]Success", border_style="green")) + return success_msg + + except UnicodeDecodeError: + return f"Error: Cannot modify '{path}' - file appears to be binary" + except PermissionError: + return f"Error: Permission denied modifying file '{path}'" + + +def _handle_create(path: str, file_text: str) -> str: + """Handle create command to create a new file.""" + path_obj = Path(path) + + # Create parent directories if they don't exist + path_obj.parent.mkdir(parents=True, exist_ok=True) + + try: + with open(path_obj, 'w', encoding='utf-8') as f: + f.write(file_text) + + success_msg = f"Successfully created file '{path}'" + console.print(Panel(escape(success_msg), title="[bold green]Success", border_style="green")) + return success_msg + + except PermissionError: + return f"Error: Permission denied creating file '{path}'" + + +def _handle_insert(path: str, new_str: str, insert_line: int) -> str: + """Handle insert command to insert text at a specific line.""" + path_obj = Path(path) + + if not path_obj.exists(): + return f"Error: File '{path}' does not exist" + + if not path_obj.is_file(): + return f"Error: '{path}' is not a file" + + try: + # Read file lines + with open(path_obj, 'r', encoding='utf-8') as f: + lines = f.readlines() + + # Insert new text + if insert_line == 0: + # Insert at beginning + lines.insert(0, new_str + '\n') + elif insert_line >= len(lines): + # Insert at end + lines.append(new_str + '\n') + else: + # Insert after specified line (1-based indexing) + lines.insert(insert_line, new_str + '\n') + + # Write back to file + with open(path_obj, 'w', encoding='utf-8') as f: + 
f.writelines(lines) + + success_msg = f"Successfully inserted text in '{path}' at line {insert_line + 1}" + console.print(Panel(escape(success_msg), title="[bold green]Success", border_style="green")) + return success_msg + + except UnicodeDecodeError: + return f"Error: Cannot modify '{path}' - file appears to be binary" + except PermissionError: + return f"Error: Permission denied modifying file '{path}'" \ No newline at end of file diff --git a/.github/scripts/python/write_executor.py b/.github/scripts/python/write_executor.py new file mode 100755 index 000000000..6d3b6b84d --- /dev/null +++ b/.github/scripts/python/write_executor.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +"""Write Executor Script for GitHub Operations. + +This script reads JSONL artifact files containing deferred GitHub operations +and executes them using functions from github_tools.py. It's designed to run +after the strands-agent-runner to publish any write commands or commits. +""" + +import argparse +import json +import logging +import os +from pathlib import Path +from typing import Any, Dict + +from github_tools import GitHubOperation + +# Import write only github_tools functions for dynamic execution +from github_tools import ( + create_issue, + update_issue, + add_issue_comment, + create_pull_request, + update_pull_request, + reply_to_review_comment, +) + +# Configure structured logging +logging.basicConfig( + format="%(levelname)s | %(name)s | %(message)s", + handlers=[logging.StreamHandler()], + level=logging.INFO +) +logger = logging.getLogger("write_executor") + + +def get_function_mapping() -> Dict[str, Any]: + """Get mapping of function names to actual functions.""" + return { + create_issue.tool_name: create_issue, + update_issue.tool_name: update_issue, + add_issue_comment.tool_name: add_issue_comment, + create_pull_request.tool_name: create_pull_request, + update_pull_request.tool_name: update_pull_request, + reply_to_review_comment.tool_name: reply_to_review_comment, + } + + 
+def process_jsonl_file(file_path: Path, default_issue_id: int | None = None): + """Process JSONL file and execute operations. + + Args: + file_path: Path to the JSONL artifact file + default_issue_id: Default issue ID to use for fallback operations + + Returns: + Tuple of (total_operations, successful_operations, failed_operations) + """ + function_map = get_function_mapping() + + logger.info(f"Starting JSONL processing: {file_path}") + total_ops = 0 + with open(file_path, 'r') as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + + total_ops += 1 + logger.info(f"Processing operation {total_ops} (line {line_num})") + + try: + # Parse JSONL entry + operation: GitHubOperation = json.loads(line) + func_name = operation.get("function") + args = operation.get('args', []) + kwargs = operation.get('kwargs', {}) + + if not func_name: + logger.error(f"Line {line_num}: Missing function name") + continue + + # Get function from mapping + if func_name not in function_map: + logger.error(f"Line {line_num}: Unknown function '{func_name}'") + continue + + func = function_map[func_name] + + # Set default issue ID for create_pull_request if not already set + if func_name == "create_pull_request" and default_issue_id and not kwargs.get("fallback_issue_id"): + kwargs["fallback_issue_id"] = default_issue_id + + # Execute function + logger.info(f"Executing {func_name} with args={args}, kwargs={kwargs}") + result = func(*args, **kwargs) + + logger.info(f"Line {line_num}: Operation {func_name} completed successfully") + logger.info(f"Function output: {str(result)}") + + except Exception as e: + logger.error(f"Line {line_num}: Execution error - {e}") + + + logger.info(f"JSONL processing completed.") + + +def main(): + """Main entry point for the write executor script.""" + parser = argparse.ArgumentParser( + description="Execute deferred GitHub operations from JSONL artifact files" + ) + parser.add_argument( + "artifact_file", + help="Path to 
JSONL artifact file containing deferred operations" + ) + parser.add_argument( + "--issue-id", + type=int, + help="Default issue ID to use for fallback operations" + ) + + args = parser.parse_args() + artifact_path = Path(args.artifact_file) + + logger.info(f"Write executor started with artifact file: {artifact_path}") + if args.issue_id: + logger.info(f"Default issue ID set to: {args.issue_id}") + + # Check if file exists + if not artifact_path.exists(): + logger.warning(f"Artifact file not found: {artifact_path}") + logger.warning("No deferred operations to execute") + return + + # Check if file is empty + if artifact_path.stat().st_size == 0: + logger.info("Artifact file is empty") + logger.info("No deferred operations to execute") + return + + # Set environment to enable write operations + os.environ['GITHUB_WRITE'] = 'true' + logger.info("GitHub write mode enabled") + + logger.info(f"Processing deferred operations from: {artifact_path}") + + # Process the JSONL file + process_jsonl_file(artifact_path, args.issue_id) + +if __name__ == "__main__": + main() diff --git a/.github/workflows/strands-command.yml b/.github/workflows/strands-command.yml new file mode 100644 index 000000000..803f19e48 --- /dev/null +++ b/.github/workflows/strands-command.yml @@ -0,0 +1,184 @@ +name: Strands Command Handler + +on: + issue_comment: + types: [created] + workflow_dispatch: + inputs: + issue_id: + description: 'Issue ID to process (can be issue or PR number)' + required: true + type: string + command: + description: 'Strands command to execute' + required: false + type: string + default: '' + session_id: + description: 'Optional session ID to use' + required: false + type: string + default: '' + +jobs: + authorization-check: + if: startsWith(github.event.comment.body, '/strands') || github.event_name == 'workflow_dispatch' + permissions: read-all + runs-on: ubuntu-latest + outputs: + approval-env: ${{ steps.collab-check.outputs.result || steps.auto-approve.outputs.result }} + 
steps: + - name: Collaborator Check + if: github.event_name != 'workflow_dispatch' + uses: actions/github-script@v8 + id: collab-check + with: + result-encoding: string + script: | + try { + const permissionResponse = await github.rest.repos.getCollaboratorPermissionLevel({ + owner: context.repo.owner, + repo: context.repo.repo, + username: context.payload.comment.user.login, + }); + const permission = permissionResponse.data.permission; + const hasWriteAccess = ['write', 'admin'].includes(permission); + if (!hasWriteAccess) { + console.log(`User ${context.payload.comment.user.login} does not have write access to the repository (permission: ${permission})`); + return "manual-approval" + } else { + console.log(`Verified ${context.payload.comment.user.login} has write access. Auto Approving strands command.`) + return "auto-approve" + } + } catch (error) { + console.log(`${context.payload.comment.user.login} does not have write access. Requiring Manual Approval to run strands command.`) + return "manual-approval" + } + + - name: Auto-approve for workflow dispatch + if: github.event_name == 'workflow_dispatch' + id: auto-approve + uses: actions/github-script@v8 + with: + result-encoding: string + script: | + return "auto-approve" + + setup-and-process: + needs: [authorization-check] + environment: ${{ needs.authorization-check.outputs.approval-env }} + permissions: + contents: write + issues: write + pull-requests: write + runs-on: ubuntu-latest + outputs: + branch: ${{ steps.process.outputs.branch_name }} + session_id: ${{ steps.process.outputs.session_id }} + system_prompt: ${{ steps.process.outputs.system_prompt }} + prompt: ${{ steps.process.outputs.prompt }} + steps: + - name: Add strands-running label + uses: actions/github-script@v8 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + await github.rest.issues.addLabels({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: ${{ inputs.issue_id || github.event.issue.number }}, + 
labels: ['strands-running'] + }); + + - name: Checkout repository + uses: actions/checkout@v6 + with: + sparse-checkout: | + .github + + # Outputs: branch_name, session_id, system_prompt, prompt + - name: Process input + id: process + uses: actions/github-script@v8 + with: + script: | + const processInput = require('./.github/scripts/javascript/process-input.cjs'); + await processInput(context, github, core, { + issue_id: '${{ inputs.issue_id }}', + command: '${{ inputs.command }}', + session_id: '${{ inputs.session_id }}' + }); + + execute-readonly: + needs: [setup-and-process] + permissions: + contents: read + issues: read + pull-requests: read + id-token: write # Required for OIDC + runs-on: ubuntu-latest + timeout-minutes: 60 + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + sparse-checkout: | + .github + + - name: Run Strands Agent + id: agent-runner + uses: ./.github/actions/strands-agent-runner + with: + system_prompt: ${{ needs.setup-and-process.outputs.system_prompt }} + session_id: ${{ needs.setup-and-process.outputs.session_id }} + task_prompt: ${{ needs.setup-and-process.outputs.prompt }} + aws_role_arn: ${{ secrets.AWS_ROLE_ARN }} + sessions_bucket: ${{ secrets.AGENT_SESSIONS_BUCKET }} + write_permission: 'false' + ref: ${{ needs.setup-and-process.outputs.branch }} + + execute-write: + needs: [setup-and-process, execute-readonly] + permissions: + contents: write + issues: write + pull-requests: write + id-token: write # Required for OIDC + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + sparse-checkout: | + .github + + - name: Execute write operations + uses: ./.github/actions/strands-write-executor + with: + ref: ${{ needs.setup-and-process.outputs.branch }} + issue_id: ${{ inputs.issue_id || github.event.issue.number }} + + + cleanup: + needs: [authorization-check, setup-and-process, execute-readonly, execute-write] + if: always() + permissions: + 
issues: write + pull-requests: write + runs-on: ubuntu-latest + steps: + - name: Remove strands-running label + uses: actions/github-script@v8 + with: + script: | + try { + await github.rest.issues.removeLabel({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: ${{ inputs.issue_id || github.event.issue.number }}, + name: 'strands-running' + }); + } catch (error) { + console.log('Label removal failed (may not exist):', error.message); + } From db01eeea300f31afe4fd9f590ee2a5fad440c0db Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Fri, 2 Jan 2026 10:16:30 -0500 Subject: [PATCH 036/279] feat: allow hooks to retry model invocations on exceptions (#1405) Users need the ability to retry model calls on arbitrary exceptions beyond just ModelThrottledException, and also retry based on response validation. This feature adds a low-level mechanism that enables that and more by letting hooks implement custom retry logic for both exceptions and successful responses. 
--------- Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> Co-authored-by: Mackenzie Zastrow --- src/strands/event_loop/event_loop.py | 43 ++-- src/strands/hooks/events.py | 19 ++ tests/strands/agent/test_agent_hooks.py | 299 +++++++++++++++++++++++- 3 files changed, 347 insertions(+), 14 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index f25057e4d..fcb530a0d 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -351,16 +351,25 @@ async def _handle_model_execution( stop_reason, message, usage, metrics = event["stop"] invocation_state.setdefault("request_state", {}) - await agent.hooks.invoke_callbacks_async( - AfterModelCallEvent( - agent=agent, - stop_response=AfterModelCallEvent.ModelStopResponse( - stop_reason=stop_reason, - message=message, - ), - ) + after_model_call_event = AfterModelCallEvent( + agent=agent, + stop_response=AfterModelCallEvent.ModelStopResponse( + stop_reason=stop_reason, + message=message, + ), ) + await agent.hooks.invoke_callbacks_async(after_model_call_event) + + # Check if hooks want to retry the model call + if after_model_call_event.retry: + logger.debug( + "stop_reason=<%s>, retry_requested=, attempt=<%d> | hook requested model retry", + stop_reason, + attempt + 1, + ) + continue # Retry the model call + if stop_reason == "max_tokens": message = recover_message_on_max_tokens_reached(message) @@ -372,12 +381,20 @@ async def _handle_model_execution( if model_invoke_span: tracer.end_span_with_error(model_invoke_span, str(e), e) - await agent.hooks.invoke_callbacks_async( - AfterModelCallEvent( - agent=agent, - exception=e, - ) + after_model_call_event = AfterModelCallEvent( + agent=agent, + exception=e, ) + await agent.hooks.invoke_callbacks_async(after_model_call_event) + + # Check if hooks want to retry the model call + if after_model_call_event.retry: + logger.debug( + "exception=<%s>, retry_requested=, 
attempt=<%d> | hook requested model retry", + type(e).__name__, + attempt + 1, + ) + continue # Retry the model call if isinstance(e, ModelThrottledException): if attempt + 1 == MAX_ATTEMPTS: diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index ebc508f24..5e11524d1 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -200,9 +200,24 @@ class AfterModelCallEvent(HookEvent): Note: This event is not fired for invocations to structured_output. + Model Retrying: + When ``retry_model`` is set to True by a hook callback, the agent will discard + the current model response and invoke the model again. This has important + implications for streaming consumers: + + - Streaming events from the discarded response will have already been emitted + to callers before the retry occurs. Agent invokers consuming streamed events + should be prepared to handle this scenario, potentially by tracking retry state + or implementing idempotent event processing + - The original model message is thrown away internally and not added to the + conversation history + Attributes: stop_response: The model response data if invocation was successful, None if failed. exception: Exception if the model invocation failed, None if successful. + retry: Whether to retry the model invocation. Can be set by hook callbacks + to trigger a retry. When True, the current response is discarded and the + model is called again. Defaults to False. 
""" @dataclass @@ -219,6 +234,10 @@ class ModelStopResponse: stop_response: Optional[ModelStopResponse] = None exception: Optional[Exception] = None + retry: bool = False + + def _can_write(self, name: str) -> bool: + return name == "retry" @property def should_reverse_callbacks(self) -> bool: diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index d82329e95..00b9d368a 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -1,4 +1,4 @@ -from unittest.mock import ANY, Mock +from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch import pytest from pydantic import BaseModel @@ -16,6 +16,7 @@ MessageAddedEvent, ) from strands.types.content import Messages +from strands.types.exceptions import ModelThrottledException from strands.types.tools import ToolResult, ToolUse from tests.fixtures.mock_hook_provider import MockHookProvider from tests.fixtures.mocked_model_provider import MockedModelProvider @@ -101,6 +102,12 @@ class User(BaseModel): return User(name="Jane Doe", age=30) +@pytest.fixture +def mock_sleep(): + with patch.object(strands.event_loop.event_loop.asyncio, "sleep", new_callable=AsyncMock) as mock: + yield mock + + def test_agent__init__hooks(): """Verify that the AgentInitializedEvent is emitted on Agent construction.""" hook_provider = MockHookProvider(event_types=[AgentInitializedEvent]) @@ -299,3 +306,293 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a assert next(events) == AfterInvocationEvent(agent=agent) assert len(agent.messages) == 0 # no new messages added + + +@pytest.mark.asyncio +async def test_hook_retry_on_successful_call(): + """Test that hooks can retry even on successful model calls based on response content.""" + + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "Short"}], + }, + { + "role": "assistant", + "content": [{"text": "This is a much longer and more 
detailed response"}], + }, + ] + ) + + # Hook that retries if response is too short + class MinLengthRetryHook: + def __init__(self, min_length=10): + self.min_length = min_length + self.call_count = 0 + + def register_hooks(self, registry): + registry.add_callback(strands.hooks.AfterModelCallEvent, self.handle_after_model_call) + + async def handle_after_model_call(self, event): + self.call_count += 1 + + # Check successful responses for minimum length + if event.stop_response: + message = event.stop_response.message + text_content = "".join(block.get("text", "") for block in message.get("content", [])) + + if len(text_content) < self.min_length: + event.retry = True + + retry_hook = MinLengthRetryHook(min_length=10) + agent = Agent(model=mock_provider, hooks=[retry_hook]) + + result = agent("Generate a response") + + # Verify hook was called twice (once for short response, once for long) + assert retry_hook.call_count == 2 + + # Verify final result is the longer response + assert result.message["content"][0]["text"] == "This is a much longer and more detailed response" + + +@pytest.mark.asyncio +async def test_hook_retry_on_exception_basic(alist, mock_sleep): + """Test that hooks can retry model calls on exceptions.""" + + class CustomException(Exception): + pass + + model = MagicMock() + model.stream.side_effect = [ + CustomException("First attempt fails"), + MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "Success after retry"}], + }, + ] + ).stream([]), + ] + + # Hook that enables retry on CustomException + class RetryHook: + def __init__(self): + self.after_model_call_count = 0 + + def register_hooks(self, registry): + registry.add_callback(strands.hooks.AfterModelCallEvent, self.handle_after_model_call) + + async def handle_after_model_call(self, event): + self.after_model_call_count += 1 + if event.exception and isinstance(event.exception, CustomException): + event.retry = True + + retry_hook = RetryHook() + agent = 
Agent(model=model, hooks=[retry_hook]) + + result = agent("Test retry") + + # Verify the hook was called twice (once for failure, once for success) + assert retry_hook.after_model_call_count == 2 + assert result.stop_reason == "end_turn" + assert result.message["content"][0]["text"] == "Success after retry" + + +@pytest.mark.asyncio +async def test_hook_retry_not_set_on_success(alist): + """Test that model is not retried when hook doesn't set retry_model on success.""" + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "First successful response"}], + }, + ] + ) + + # Hook that tries to set retry_model=True even on success + class NoRetryHook: + def __init__(self): + self.call_count = 0 + + def register_hooks(self, registry): + registry.add_callback(strands.hooks.AfterModelCallEvent, self.handle_after_model_call) + + async def handle_after_model_call(self, event): + self.call_count += 1 + # Try to set retry even on success + # Don't set retry_model (leave it as False) + + retry_hook = NoRetryHook() + agent = Agent(model=mock_provider, hooks=[retry_hook]) + + result = agent("Test no retry when not set") + + # Should only be called once since retry_model was not set + assert retry_hook.call_count == 1 + assert result.message["content"][0]["text"] == "First successful response" + + +@pytest.mark.asyncio +async def test_hook_retry_with_limit(alist, mock_sleep): + """Test that hooks can control retry limits.""" + + class CustomException(Exception): + pass + + model = MagicMock() + model.stream.side_effect = [ + CustomException("Attempt 1 fails"), + CustomException("Attempt 2 fails"), + CustomException("Attempt 3 fails"), + ] + + # Hook that allows max 2 retries + class LimitedRetryHook: + def __init__(self, max_retries=2): + self.max_retries = max_retries + self.retry_count = 0 + self.call_count = 0 + + def register_hooks(self, registry): + registry.add_callback(strands.hooks.AfterModelCallEvent, self.handle_after_model_call) + 
+ async def handle_after_model_call(self, event): + self.call_count += 1 + if event.exception and isinstance(event.exception, CustomException): + if self.retry_count < self.max_retries: + self.retry_count += 1 + event.retry = True + # else: let exception propagate + + retry_hook = LimitedRetryHook(max_retries=2) + agent = Agent(model=model, hooks=[retry_hook]) + + with pytest.raises(CustomException, match="Attempt 3 fails"): + await agent("Test limited retries") + + # Should be called 3 times: initial + 2 retries + assert retry_hook.call_count == 3 + assert retry_hook.retry_count == 2 + + +@pytest.mark.asyncio +async def test_hook_retry_multiple_hooks(alist, mock_sleep): + """Test that multiple hooks can modify retry_model and last one wins.""" + + class CustomException(Exception): + pass + + model = MagicMock() + model.stream.side_effect = [ + CustomException("First attempt fails"), + MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "Success"}], + }, + ] + ).stream([]), + ] + + async def retry_enabler(event: AfterModelCallEvent): + if event.exception: + event.retry = True + + async def another_retry_enabler(event: AfterModelCallEvent): + if event.exception: + event.retry = True + + agent = Agent(model=model) + agent.hooks.add_callback(AfterModelCallEvent, retry_enabler) + agent.hooks.add_callback(AfterModelCallEvent, another_retry_enabler) + + result = agent("Test multiple hooks") + + assert result.stop_reason == "end_turn" + assert result.message["content"][0]["text"] == "Success" + + +@pytest.mark.asyncio +async def test_hook_retry_last_hook_wins(alist, mock_sleep): + """Test that when multiple hooks set retry_model, the last-called hook wins. + + Note: AfterModelCallEvent callbacks are invoked in reverse order, so the first + registered hook is called last. 
+ """ + + class CustomException(Exception): + pass + + call_count = [0] + + def mock_stream(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + raise CustomException("First attempt fails") + else: + raise CustomException(f"Should not be called (call {call_count[0]})") + + model = MagicMock() + model.stream = mock_stream + + async def retry_enabler(event: AfterModelCallEvent): + """Called first due to reverse order.""" + if event.exception: + event.retry = True + + async def retry_disabler(event: AfterModelCallEvent): + """Called last, so it wins.""" + if event.exception: + event.retry = False + + agent = Agent(model=model) + agent.hooks.add_callback(AfterModelCallEvent, retry_disabler) # Registered first, called last + agent.hooks.add_callback(AfterModelCallEvent, retry_enabler) # Registered second, called first + + # Should raise exception since last-called hook disabled retry + with pytest.raises(CustomException, match="First attempt fails"): + agent("Test last hook wins") + + # Verify stream was only called once + assert call_count[0] == 1 + + +@pytest.mark.asyncio +async def test_hook_retry_with_throttle_exception(alist, mock_sleep): + """Test that hook retry works alongside existing throttle retry.""" + + class CustomException(Exception): + pass + + model = MagicMock() + model.stream.side_effect = [ + CustomException("Custom error"), + ModelThrottledException("ThrottlingException"), + ModelThrottledException("ThrottlingException"), + MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "Success after mixed retries"}], + }, + ] + ).stream([]), + ] + + async def handle_after_model_call(event: AfterModelCallEvent): + if event.exception and isinstance(event.exception, CustomException): + event.retry = True + + agent = Agent(model=model) + agent.hooks.add_callback(AfterModelCallEvent, handle_after_model_call) + + result = agent("Test mixed retries") + + # Should succeed after: custom retry + 2 throttle retries + assert 
result.stop_reason == "end_turn" + assert result.message["content"][0]["text"] == "Success after mixed retries" From b5d9468df40de39030ad81d4acdabda0e17b0968 Mon Sep 17 00:00:00 2001 From: Josh Samuel <3156090+jsamuel1@users.noreply.github.com> Date: Sat, 3 Jan 2026 02:27:05 +1100 Subject: [PATCH 037/279] fix: emit deprecation warning only when deprecated aliases are accessed (#1380) Previously, the deprecation warning was emitted at module import time, which triggered whenever `strands` was imported because other modules import from `experimental.hooks`. Changed to use `__getattr__` to lazily emit the warning only when the deprecated aliases (BeforeToolInvocationEvent, AfterToolInvocationEvent, BeforeModelInvocationEvent, AfterModelInvocationEvent) are actually accessed. Fixes #1236 --------- Co-authored-by: Mackenzie Zastrow --- src/strands/experimental/hooks/__init__.py | 19 ++++++++--- src/strands/experimental/hooks/events.py | 32 ++++++++++++------- .../experimental/hooks/test_hook_aliases.py | 18 ++++++----- 3 files changed, 45 insertions(+), 24 deletions(-) diff --git a/src/strands/experimental/hooks/__init__.py b/src/strands/experimental/hooks/__init__.py index c76b57ea4..f2219bf7b 100644 --- a/src/strands/experimental/hooks/__init__.py +++ b/src/strands/experimental/hooks/__init__.py @@ -1,19 +1,28 @@ """Experimental hook functionality that has not yet reached stability.""" +from typing import Any + from .events import ( - AfterModelInvocationEvent, - AfterToolInvocationEvent, - BeforeModelInvocationEvent, - BeforeToolInvocationEvent, + BidiAfterConnectionRestartEvent, BidiAfterInvocationEvent, BidiAfterToolCallEvent, BidiAgentInitializedEvent, + BidiBeforeConnectionRestartEvent, BidiBeforeInvocationEvent, BidiBeforeToolCallEvent, BidiInterruptionEvent, BidiMessageAddedEvent, ) +# Deprecated aliases are accessed via __getattr__ to emit warnings only on use + + +def __getattr__(name: str) -> Any: + from . 
import events + + return getattr(events, name) + + __all__ = [ "BeforeToolInvocationEvent", "AfterToolInvocationEvent", @@ -27,4 +36,6 @@ "BidiBeforeToolCallEvent", "BidiAfterToolCallEvent", "BidiInterruptionEvent", + "BidiBeforeConnectionRestartEvent", + "BidiAfterConnectionRestartEvent", ] diff --git a/src/strands/experimental/hooks/events.py b/src/strands/experimental/hooks/events.py index 8a8d80629..081190af3 100644 --- a/src/strands/experimental/hooks/events.py +++ b/src/strands/experimental/hooks/events.py @@ -5,7 +5,7 @@ import warnings from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal, TypeAlias +from typing import TYPE_CHECKING, Any, Literal from ...hooks.events import AfterModelCallEvent, AfterToolCallEvent, BeforeModelCallEvent, BeforeToolCallEvent from ...hooks.registry import BaseHookEvent @@ -16,17 +16,25 @@ from ..bidi.agent.agent import BidiAgent from ..bidi.models import BidiModelTimeoutError -warnings.warn( - "BeforeModelCallEvent, AfterModelCallEvent, BeforeToolCallEvent, and AfterToolCallEvent are no longer experimental." - "Import from strands.hooks instead.", - DeprecationWarning, - stacklevel=2, -) - -BeforeToolInvocationEvent: TypeAlias = BeforeToolCallEvent -AfterToolInvocationEvent: TypeAlias = AfterToolCallEvent -BeforeModelInvocationEvent: TypeAlias = BeforeModelCallEvent -AfterModelInvocationEvent: TypeAlias = AfterModelCallEvent +# Deprecated aliases - warning emitted on access via __getattr__ +_DEPRECATED_ALIASES = { + "BeforeToolInvocationEvent": BeforeToolCallEvent, + "AfterToolInvocationEvent": AfterToolCallEvent, + "BeforeModelInvocationEvent": BeforeModelCallEvent, + "AfterModelInvocationEvent": AfterModelCallEvent, +} + + +def __getattr__(name: str) -> Any: + if name in _DEPRECATED_ALIASES: + warnings.warn( + f"{name} has been moved to production with an updated name. 
" + f"Use {_DEPRECATED_ALIASES[name].__name__} from strands.hooks instead.", + DeprecationWarning, + stacklevel=2, + ) + return _DEPRECATED_ALIASES[name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") # BidiAgent Hook Events diff --git a/tests/strands/experimental/hooks/test_hook_aliases.py b/tests/strands/experimental/hooks/test_hook_aliases.py index f4899f2ab..2da8a6f90 100644 --- a/tests/strands/experimental/hooks/test_hook_aliases.py +++ b/tests/strands/experimental/hooks/test_hook_aliases.py @@ -112,18 +112,20 @@ def experimental_callback(event: BeforeToolInvocationEvent): assert received_event is test_event -def test_deprecation_warning_on_import(captured_warnings): - """Verify that importing from experimental module emits deprecation warning.""" +def test_deprecation_warning_on_access(captured_warnings): + """Verify that accessing deprecated aliases emits deprecation warning.""" + import strands.experimental.hooks.events as events_module - module = sys.modules.get("strands.experimental.hooks.events") - if module: - importlib.reload(module) - else: - importlib.import_module("strands.experimental.hooks.events") + # Clear any existing warnings + captured_warnings.clear() + + # Access a deprecated alias - this should trigger the warning + _ = events_module.BeforeToolInvocationEvent assert len(captured_warnings) == 1 assert issubclass(captured_warnings[0].category, DeprecationWarning) - assert "are no longer experimental" in str(captured_warnings[0].message) + assert "BeforeToolInvocationEvent" in str(captured_warnings[0].message) + assert "BeforeToolCallEvent" in str(captured_warnings[0].message) def test_deprecation_warning_on_import_only_for_experimental(captured_warnings): From 695ca66541853e7983b9b39c8649801cd7018875 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 5 Jan 2026 19:21:57 +0200 Subject: [PATCH 038/279] docs: update github agent action to reference AGENT_SESSIONS_BUCKET secret (#1418) --- 
.github/actions/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/README.md b/.github/actions/README.md index a3ec3fa2d..6559462cb 100644 --- a/.github/actions/README.md +++ b/.github/actions/README.md @@ -198,7 +198,7 @@ Your IAM role must have these permissions in order to execute: 3. **Create S3 Bucket** for session storage 4. **Add GitHub Secrets**: - `AWS_ROLE_ARN`: The created role ARN - - `STRANDS_SESSION_BUCKET`: The S3 bucket name + - `AGENT_SESSIONS_BUCKET`: The S3 bucket name ## Security From 50e5e74c53de58c7de2b9ecbfbca488fcdbf456e Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Tue, 6 Jan 2026 15:47:26 -0500 Subject: [PATCH 039/279] feat: provide extra command content as the prompt to the agent (#1419) Previously triggering the agent would always provide the prompt of "review and continue" to the agent; this meant that if you gave the agent explicit commands in the comment it wouldn't necessarily receive/act on those. 
For example: /strands you didn't do X, please do it It would not actually receive the extra text; this updates it so that everything after the "strands command" is added as the prompt, defaulting to "review and continue" if none is provided --------- Co-authored-by: Mackenzie Zastrow --- .../actions/strands-agent-runner/action.yml | 2 +- .github/scripts/javascript/process-input.cjs | 20 +++++++++++++++++-- .github/scripts/python/agent_runner.py | 13 ++++++------ 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/.github/actions/strands-agent-runner/action.yml b/.github/actions/strands-agent-runner/action.yml index 6d4c2d7fb..d0e93effe 100644 --- a/.github/actions/strands-agent-runner/action.yml +++ b/.github/actions/strands-agent-runner/action.yml @@ -149,7 +149,7 @@ runs: STRANDS_TOOL_CONSOLE_MODE: 'enabled' BYPASS_TOOL_CONSENT: 'true' run: | - uv run --no-project ${{ runner.temp }}/strands-agent-runner/.github/scripts/python/agent_runner.py "$INPUT_TASK" + uv run --no-project ${{ runner.temp }}/strands-agent-runner/.github/scripts/python/agent_runner.py - name: Capture repository state shell: bash diff --git a/.github/scripts/javascript/process-input.cjs b/.github/scripts/javascript/process-input.cjs index b7ed29263..395e37b64 100644 --- a/.github/scripts/javascript/process-input.cjs +++ b/.github/scripts/javascript/process-input.cjs @@ -8,9 +8,10 @@ async function getIssueInfo(github, context, inputs) { const issueId = context.eventName === 'workflow_dispatch' ? inputs.issue_id : context.payload.issue.number.toString(); + const commentBody = context.payload.comment?.body || ''; const command = context.eventName === 'workflow_dispatch' ? inputs.command - : (context.payload.comment.body.match(/^\/strands\s*(.*?)$/m)?.[1]?.trim() || ''); + : (commentBody.startsWith('/strands') ? 
commentBody.slice('/strands'.length).trim() : ''); console.log(`Event: ${context.eventName}, Issue ID: ${issueId}, Command: "${command}"`); @@ -76,10 +77,25 @@ function buildPrompts(mode, issueId, isPullRequest, command, branchName, inputs) const scriptFile = scriptFiles[mode] || scriptFiles['refiner']; const systemPrompt = fs.readFileSync(scriptFile, 'utf8'); + // Extract the user's feedback/instructions after the mode keyword + // e.g., "release-notes Move #123 to Major Features" -> "Move #123 to Major Features" + const modeKeywords = { + 'release-notes': /^(?:release-notes|release notes)\s*/i, + 'implementer': /^implement\s*/i, + 'refiner': /^refine\s*/i + }; + + const modePattern = modeKeywords[mode]; + const userFeedback = modePattern ? command.replace(modePattern, '').trim() : command.trim(); + let prompt = (isPullRequest) ? 'The pull request id is:' : 'The issue id is:'; - prompt += `${issueId}\n${command}\nreview and continue`; + prompt += `${issueId}\n`; + + // If there's any user feedback beyond the command keyword, include it as the main instruction, + // otherwise default to "review and continue" + prompt += userFeedback || 'review and continue'; return { sessionId, systemPrompt, prompt }; } diff --git a/.github/scripts/python/agent_runner.py b/.github/scripts/python/agent_runner.py index db10ceadb..9d92c2ac4 100644 --- a/.github/scripts/python/agent_runner.py +++ b/.github/scripts/python/agent_runner.py @@ -142,13 +142,12 @@ def run_agent(query: str): def main() -> None: """Main entry point for the agent runner.""" try: - # Read task from command line arguments - if len(sys.argv) < 2: - raise ValueError("Task argument is required") - - task = " ".join(sys.argv[1:]) - if not task.strip(): - raise ValueError("Task cannot be empty") + # Prefer INPUT_TASK env var (avoids shell escaping issues), fall back to CLI args + task = os.getenv("INPUT_TASK", "").strip() + if not task and len(sys.argv) > 1: + task = " ".join(sys.argv[1:]).strip() + if not task: + 
raise ValueError("Task is required (via INPUT_TASK env var or CLI argument)") print(f"🤖 Running agent with task: {task}") run_agent(task) From 3bc34acc76373c6465354fa57ced0498342cde32 Mon Sep 17 00:00:00 2001 From: AI Ape Wisdom Date: Wed, 7 Jan 2026 05:12:23 +0800 Subject: [PATCH 040/279] [FEATURE] add MCP resource operations in MCP Tools (#1117) * feat(tools): Add MCP resource operations * feat(tools): Add MCP resource operations * tests: add integ tests for mcp resources * fix: broken merge --------- Co-authored-by: Dean Schmigelski --- src/strands/tools/mcp/mcp_client.py | 87 +++++++++++- tests/strands/tools/mcp/test_mcp_client.py | 154 ++++++++++++++++++++- tests_integ/mcp/echo_server.py | 19 +++ tests_integ/mcp/test_mcp_resources.py | 130 +++++++++++++++++ 4 files changed, 388 insertions(+), 2 deletions(-) create mode 100644 tests_integ/mcp/test_mcp_resources.py diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 6ce591bc5..37b99d021 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -21,11 +21,20 @@ import anyio from mcp import ClientSession, ListToolsResult from mcp.client.session import ElicitationFnT -from mcp.types import BlobResourceContents, GetPromptResult, ListPromptsResult, TextResourceContents +from mcp.types import ( + BlobResourceContents, + GetPromptResult, + ListPromptsResult, + ListResourcesResult, + ListResourceTemplatesResult, + ReadResourceResult, + TextResourceContents, +) from mcp.types import CallToolResult as MCPCallToolResult from mcp.types import EmbeddedResource as MCPEmbeddedResource from mcp.types import ImageContent as MCPImageContent from mcp.types import TextContent as MCPTextContent +from pydantic import AnyUrl from typing_extensions import Protocol, TypedDict from ...experimental.tools import ToolProvider @@ -449,6 +458,82 @@ async def _get_prompt_async() -> GetPromptResult: return get_prompt_result + def list_resources_sync(self, 
pagination_token: Optional[str] = None) -> ListResourcesResult: + """Synchronously retrieves the list of available resources from the MCP server. + + This method calls the asynchronous list_resources method on the MCP session + and returns the raw ListResourcesResult with pagination support. + + Args: + pagination_token: Optional token for pagination + + Returns: + ListResourcesResult: The raw MCP response containing resources and pagination info + """ + self._log_debug_with_thread("listing MCP resources synchronously") + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + async def _list_resources_async() -> ListResourcesResult: + return await cast(ClientSession, self._background_thread_session).list_resources(cursor=pagination_token) + + list_resources_result: ListResourcesResult = self._invoke_on_background_thread(_list_resources_async()).result() + self._log_debug_with_thread("received %d resources from MCP server", len(list_resources_result.resources)) + + return list_resources_result + + def read_resource_sync(self, uri: AnyUrl | str) -> ReadResourceResult: + """Synchronously reads a resource from the MCP server. 
+ + Args: + uri: The URI of the resource to read + + Returns: + ReadResourceResult: The resource content from the MCP server + """ + self._log_debug_with_thread("reading MCP resource synchronously: %s", uri) + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + async def _read_resource_async() -> ReadResourceResult: + # Convert string to AnyUrl if needed + resource_uri = AnyUrl(uri) if isinstance(uri, str) else uri + return await cast(ClientSession, self._background_thread_session).read_resource(resource_uri) + + read_resource_result: ReadResourceResult = self._invoke_on_background_thread(_read_resource_async()).result() + self._log_debug_with_thread("received resource content from MCP server") + + return read_resource_result + + def list_resource_templates_sync(self, pagination_token: Optional[str] = None) -> ListResourceTemplatesResult: + """Synchronously retrieves the list of available resource templates from the MCP server. + + Resource templates define URI patterns that can be used to access resources dynamically. 
+ + Args: + pagination_token: Optional token for pagination + + Returns: + ListResourceTemplatesResult: The raw MCP response containing resource templates and pagination info + """ + self._log_debug_with_thread("listing MCP resource templates synchronously") + if not self._is_session_active(): + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) + + async def _list_resource_templates_async() -> ListResourceTemplatesResult: + return await cast(ClientSession, self._background_thread_session).list_resource_templates( + cursor=pagination_token + ) + + list_resource_templates_result: ListResourceTemplatesResult = self._invoke_on_background_thread( + _list_resource_templates_async() + ).result() + self._log_debug_with_thread( + "received %d resource templates from MCP server", len(list_resource_templates_result.resourceTemplates) + ) + + return list_resource_templates_result + def call_tool_sync( self, tool_use_id: str, diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index f5040de1b..35f11f47f 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -5,9 +5,21 @@ import pytest from mcp import ListToolsResult from mcp.types import CallToolResult as MCPCallToolResult -from mcp.types import GetPromptResult, ListPromptsResult, Prompt, PromptMessage +from mcp.types import ( + GetPromptResult, + ListPromptsResult, + ListResourcesResult, + ListResourceTemplatesResult, + Prompt, + PromptMessage, + ReadResourceResult, + Resource, + ResourceTemplate, + TextResourceContents, +) from mcp.types import TextContent as MCPTextContent from mcp.types import Tool as MCPTool +from pydantic import AnyUrl from strands.tools.mcp import MCPClient from strands.tools.mcp.mcp_types import MCPToolResult @@ -772,3 +784,143 @@ def test_call_tool_sync_with_meta_and_structured_content(mock_transport, mock_se assert result["metadata"] == metadata assert "structuredContent" in 
result assert result["structuredContent"] == structured_content + + +# Resource Tests - Sync Methods + + +def test_list_resources_sync(mock_transport, mock_session): + """Test that list_resources_sync correctly retrieves resources.""" + mock_resource = Resource( + uri=AnyUrl("file://documents/test.txt"), name="test.txt", description="A test document", mimeType="text/plain" + ) + mock_session.list_resources.return_value = ListResourcesResult(resources=[mock_resource]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.list_resources_sync() + + mock_session.list_resources.assert_called_once_with(cursor=None) + assert len(result.resources) == 1 + assert result.resources[0].name == "test.txt" + assert str(result.resources[0].uri) == "file://documents/test.txt" + assert result.nextCursor is None + + +def test_list_resources_sync_with_pagination_token(mock_transport, mock_session): + """Test that list_resources_sync correctly passes pagination token and returns next cursor.""" + mock_resource = Resource( + uri=AnyUrl("file://documents/test.txt"), name="test.txt", description="A test document", mimeType="text/plain" + ) + mock_session.list_resources.return_value = ListResourcesResult(resources=[mock_resource], nextCursor="next_page") + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.list_resources_sync(pagination_token="current_page") + + mock_session.list_resources.assert_called_once_with(cursor="current_page") + assert len(result.resources) == 1 + assert result.resources[0].name == "test.txt" + assert result.nextCursor == "next_page" + + +def test_list_resources_sync_session_not_active(): + """Test that list_resources_sync raises an error when session is not active.""" + client = MCPClient(MagicMock()) + + with pytest.raises(MCPClientInitializationError, match="client session is not running"): + client.list_resources_sync() + + +def test_read_resource_sync(mock_transport, mock_session): + """Test 
that read_resource_sync correctly reads a resource.""" + mock_content = TextResourceContents( + uri=AnyUrl("file://documents/test.txt"), text="Resource content", mimeType="text/plain" + ) + mock_session.read_resource.return_value = ReadResourceResult(contents=[mock_content]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.read_resource_sync("file://documents/test.txt") + + # Verify the session method was called + mock_session.read_resource.assert_called_once() + # Check the URI argument (it will be wrapped as AnyUrl) + call_args = mock_session.read_resource.call_args[0] + assert str(call_args[0]) == "file://documents/test.txt" + + assert len(result.contents) == 1 + assert result.contents[0].text == "Resource content" + + +def test_read_resource_sync_with_anyurl(mock_transport, mock_session): + """Test that read_resource_sync correctly handles AnyUrl input.""" + mock_content = TextResourceContents( + uri=AnyUrl("file://documents/test.txt"), text="Resource content", mimeType="text/plain" + ) + mock_session.read_resource.return_value = ReadResourceResult(contents=[mock_content]) + + with MCPClient(mock_transport["transport_callable"]) as client: + uri = AnyUrl("file://documents/test.txt") + result = client.read_resource_sync(uri) + + mock_session.read_resource.assert_called_once() + call_args = mock_session.read_resource.call_args[0] + assert str(call_args[0]) == "file://documents/test.txt" + + assert len(result.contents) == 1 + assert result.contents[0].text == "Resource content" + + +def test_read_resource_sync_session_not_active(): + """Test that read_resource_sync raises an error when session is not active.""" + client = MCPClient(MagicMock()) + + with pytest.raises(MCPClientInitializationError, match="client session is not running"): + client.read_resource_sync("file://documents/test.txt") + + +def test_list_resource_templates_sync(mock_transport, mock_session): + """Test that list_resource_templates_sync correctly retrieves 
resource templates.""" + mock_template = ResourceTemplate( + uriTemplate="file://documents/{name}", + name="document_template", + description="Template for documents", + mimeType="text/plain", + ) + mock_session.list_resource_templates.return_value = ListResourceTemplatesResult(resourceTemplates=[mock_template]) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.list_resource_templates_sync() + + mock_session.list_resource_templates.assert_called_once_with(cursor=None) + assert len(result.resourceTemplates) == 1 + assert result.resourceTemplates[0].name == "document_template" + assert result.resourceTemplates[0].uriTemplate == "file://documents/{name}" + assert result.nextCursor is None + + +def test_list_resource_templates_sync_with_pagination_token(mock_transport, mock_session): + """Test that list_resource_templates_sync correctly passes pagination token and returns next cursor.""" + mock_template = ResourceTemplate( + uriTemplate="file://documents/{name}", + name="document_template", + description="Template for documents", + mimeType="text/plain", + ) + mock_session.list_resource_templates.return_value = ListResourceTemplatesResult( + resourceTemplates=[mock_template], nextCursor="next_page" + ) + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.list_resource_templates_sync(pagination_token="current_page") + + mock_session.list_resource_templates.assert_called_once_with(cursor="current_page") + assert len(result.resourceTemplates) == 1 + assert result.resourceTemplates[0].name == "document_template" + assert result.nextCursor == "next_page" + + +def test_list_resource_templates_sync_session_not_active(): + """Test that list_resource_templates_sync raises an error when session is not active.""" + client = MCPClient(MagicMock()) + + with pytest.raises(MCPClientInitializationError, match="client session is not running"): + client.list_resource_templates_sync() diff --git 
a/tests_integ/mcp/echo_server.py b/tests_integ/mcp/echo_server.py index a23a87b5c..151f913d6 100644 --- a/tests_integ/mcp/echo_server.py +++ b/tests_integ/mcp/echo_server.py @@ -16,12 +16,15 @@ """ import base64 +import json from typing import Literal from mcp.server import FastMCP from mcp.types import BlobResourceContents, CallToolResult, EmbeddedResource, TextContent, TextResourceContents from pydantic import BaseModel +TEST_IMAGE_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" + class EchoResponse(BaseModel): """Response model for echo with structured content.""" @@ -102,6 +105,22 @@ def get_weather(location: Literal["New York", "London", "Tokyo"] = "New York"): ) ] + # Resources + @mcp.resource("test://static-text") + def static_text_resource() -> str: + """A static text resource for testing""" + return "This is the content of the static text resource." + + @mcp.resource("test://static-binary") + def static_binary_resource() -> bytes: + """A static binary resource (image) for testing""" + return base64.b64decode(TEST_IMAGE_BASE64) + + @mcp.resource("test://template/{id}/data") + def template_resource(id: str) -> str: + """A resource template with parameter substitution""" + return json.dumps({"id": id, "templateTest": True, "data": f"Data for ID: {id}"}) + mcp.run(transport="stdio") diff --git a/tests_integ/mcp/test_mcp_resources.py b/tests_integ/mcp/test_mcp_resources.py new file mode 100644 index 000000000..dccf3b808 --- /dev/null +++ b/tests_integ/mcp/test_mcp_resources.py @@ -0,0 +1,130 @@ +""" +Integration tests for MCP client resource functionality. + +This module tests the resource-related methods in MCPClient: +- list_resources_sync() +- read_resource_sync() +- list_resource_templates_sync() + +The tests use the echo server which has been extended with resource functionality. 
+""" + +import base64 +import json + +import pytest +from mcp import StdioServerParameters, stdio_client +from mcp.shared.exceptions import McpError +from mcp.types import BlobResourceContents, TextResourceContents +from pydantic import AnyUrl + +from strands.tools.mcp.mcp_client import MCPClient + + +def test_mcp_resources_list_and_read(): + """Test listing and reading various types of resources.""" + mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with mcp_client: + # Test list_resources_sync + resources_result = mcp_client.list_resources_sync() + assert len(resources_result.resources) >= 2 # At least our 2 static resources + + # Verify resource URIs exist (only static resources, not templates) + resource_uris = [str(r.uri) for r in resources_result.resources] + assert "test://static-text" in resource_uris + assert "test://static-binary" in resource_uris + # Template resources are not listed in static resources + + # Test reading text resource + text_resource = mcp_client.read_resource_sync("test://static-text") + assert len(text_resource.contents) == 1 + content = text_resource.contents[0] + assert isinstance(content, TextResourceContents) + assert "This is the content of the static text resource." 
in content.text + + # Test reading binary resource + binary_resource = mcp_client.read_resource_sync("test://static-binary") + assert len(binary_resource.contents) == 1 + binary_content = binary_resource.contents[0] + assert isinstance(binary_content, BlobResourceContents) + # Verify it's valid base64 encoded data + decoded_data = base64.b64decode(binary_content.blob) + assert len(decoded_data) > 0 + + +def test_mcp_resources_templates(): + """Test listing resource templates and reading from template resources.""" + mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with mcp_client: + # Test list_resource_templates_sync + templates_result = mcp_client.list_resource_templates_sync() + assert len(templates_result.resourceTemplates) >= 1 + + # Verify template URIs exist + template_uris = [t.uriTemplate for t in templates_result.resourceTemplates] + assert "test://template/{id}/data" in template_uris + + # Test reading from template resource + template_resource = mcp_client.read_resource_sync("test://template/123/data") + assert len(template_resource.contents) == 1 + template_content = template_resource.contents[0] + assert isinstance(template_content, TextResourceContents) + + # Parse the JSON response + parsed_json = json.loads(template_content.text) + assert parsed_json["id"] == "123" + assert parsed_json["templateTest"] is True + assert "Data for ID: 123" in parsed_json["data"] + + +def test_mcp_resources_pagination(): + """Test pagination support for resources.""" + mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with mcp_client: + # Test with pagination token (should work even if server doesn't implement pagination) + resources_result = mcp_client.list_resources_sync(pagination_token=None) + assert len(resources_result.resources) >= 0 + + # Test resource templates pagination + 
templates_result = mcp_client.list_resource_templates_sync(pagination_token=None) + assert len(templates_result.resourceTemplates) >= 0 + + +def test_mcp_resources_error_handling(): + """Test error handling for resource operations.""" + mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with mcp_client: + # Test reading non-existent resource + with pytest.raises(McpError, match="Unknown resource"): + mcp_client.read_resource_sync("test://nonexistent") + + +def test_mcp_resources_uri_types(): + """Test that both string and AnyUrl types work for read_resource_sync.""" + mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with mcp_client: + # Test with string URI + text_resource_str = mcp_client.read_resource_sync("test://static-text") + assert len(text_resource_str.contents) == 1 + + # Test with AnyUrl URI + text_resource_url = mcp_client.read_resource_sync(AnyUrl("test://static-text")) + assert len(text_resource_url.contents) == 1 + + # Both should return the same content + assert text_resource_str.contents[0].text == text_resource_url.contents[0].text From 514f40243b280f0645f31a0c48a8bac9e2f1354f Mon Sep 17 00:00:00 2001 From: mehtarac Date: Tue, 6 Jan 2026 13:26:31 -0800 Subject: [PATCH 041/279] fix: import errors for models with optional imports (#1384) * fix: import errors for models with optional imports * Addressed comments: added return type, changed error message * Addressed comments: updated imports --- src/strands/models/__init__.py | 57 +++++++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/src/strands/models/__init__.py b/src/strands/models/__init__.py index ead290a35..d5f88d09a 100644 --- a/src/strands/models/__init__.py +++ b/src/strands/models/__init__.py @@ -3,8 +3,63 @@ This package includes an abstract base Model class along with concrete 
implementations for specific providers. """ +from typing import Any + from . import bedrock, model from .bedrock import BedrockModel from .model import Model -__all__ = ["bedrock", "model", "BedrockModel", "Model"] +__all__ = [ + "bedrock", + "model", + "BedrockModel", + "Model", +] + + +def __getattr__(name: str) -> Any: + """Lazy load model implementations only when accessed. + + This defers the import of optional dependencies until actually needed. + """ + if name == "AnthropicModel": + from .anthropic import AnthropicModel + + return AnthropicModel + if name == "GeminiModel": + from .gemini import GeminiModel + + return GeminiModel + if name == "LiteLLMModel": + from .litellm import LiteLLMModel + + return LiteLLMModel + if name == "LlamaAPIModel": + from .llamaapi import LlamaAPIModel + + return LlamaAPIModel + if name == "LlamaCppModel": + from .llamacpp import LlamaCppModel + + return LlamaCppModel + if name == "MistralModel": + from .mistral import MistralModel + + return MistralModel + if name == "OllamaModel": + from .ollama import OllamaModel + + return OllamaModel + if name == "OpenAIModel": + from .openai import OpenAIModel + + return OpenAIModel + if name == "SageMakerAIModel": + from .sagemaker import SageMakerAIModel + + return SageMakerAIModel + if name == "WriterModel": + from .writer import WriterModel + + return WriterModel + raise AttributeError(f"cannot import name '{name}' from '{__name__}' ({__file__})") From 9fd22d18adbb8f55e0a3cf3450db8d39492d4ea2 Mon Sep 17 00:00:00 2001 From: mehtarac Date: Tue, 6 Jan 2026 13:26:53 -0800 Subject: [PATCH 042/279] add BidiGeminiLiveModel and BidiOpenAIRealtimeModel to the init (#1383) * add BidiGeminiLiveModel and BidiOpenAIRealtimeModel to the init * Address comments - re-word error message, add return type * Addressed comments: updated imports --- .../experimental/bidi/models/__init__.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git 
a/src/strands/experimental/bidi/models/__init__.py b/src/strands/experimental/bidi/models/__init__.py index cc62c9987..6e5817046 100644 --- a/src/strands/experimental/bidi/models/__init__.py +++ b/src/strands/experimental/bidi/models/__init__.py @@ -1,5 +1,7 @@ """Bidirectional model interfaces and implementations.""" +from typing import Any + from .model import BidiModel, BidiModelTimeoutError from .nova_sonic import BidiNovaSonicModel @@ -8,3 +10,22 @@ "BidiModelTimeoutError", "BidiNovaSonicModel", ] + + +def __getattr__(name: str) -> Any: + """ + Lazy load bidi model implementations only when accessed. + + This defers the import of optional dependencies until actually needed: + - BidiGeminiLiveModel requires google-generativeai (lazy loaded) + - BidiOpenAIRealtimeModel requires openai (lazy loaded) + """ + if name == "BidiGeminiLiveModel": + from .gemini_live import BidiGeminiLiveModel + + return BidiGeminiLiveModel + if name == "BidiOpenAIRealtimeModel": + from .openai_realtime import BidiOpenAIRealtimeModel + + return BidiOpenAIRealtimeModel + raise AttributeError(f"cannot import name '{name}' from '{__name__}' ({__file__})") From 2b1cf6bde77f701f9c6185651ea932e564607e09 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 7 Jan 2026 10:35:47 -0500 Subject: [PATCH 043/279] bidi - async - remove cancelling call (#1357) --- .../experimental/bidi/_async/_task_group.py | 38 +++++++++++-------- .../bidi/_async/test_task_group.py | 16 +++++++- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/src/strands/experimental/bidi/_async/_task_group.py b/src/strands/experimental/bidi/_async/_task_group.py index 26c67326d..33cf63dca 100644 --- a/src/strands/experimental/bidi/_async/_task_group.py +++ b/src/strands/experimental/bidi/_async/_task_group.py @@ -6,17 +6,17 @@ """ import asyncio -from typing import Any, Coroutine +from typing import Any, Coroutine, cast class _TaskGroup: """Shim of asyncio.TaskGroup for use in Python 3.10. 
Attributes: - _tasks: List of tasks in group. + _tasks: Set of tasks in group. """ - _tasks: list[asyncio.Task] + _tasks: set[asyncio.Task] def create_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task: """Create an async task and add to group. @@ -25,12 +25,12 @@ def create_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task: The created task. """ task = asyncio.create_task(coro) - self._tasks.append(task) + self._tasks.add(task) return task async def __aenter__(self) -> "_TaskGroup": """Setup self managed task group context.""" - self._tasks = [] + self._tasks = set() return self async def __aexit__(self, *_: Any) -> None: @@ -42,20 +42,28 @@ async def __aexit__(self, *_: Any) -> None: - The context re-raises CancelledErrors to the caller only if the context itself was cancelled. """ try: - await asyncio.gather(*self._tasks) + pending_tasks = self._tasks + while pending_tasks: + done_tasks, pending_tasks = await asyncio.wait(pending_tasks, return_when=asyncio.FIRST_EXCEPTION) - except (Exception, asyncio.CancelledError) as error: + if any(exception := done_task.exception() for done_task in done_tasks if not done_task.cancelled()): + break + + else: # all tasks completed/cancelled successfully + return + + for pending_task in pending_tasks: + pending_task.cancel() + + await asyncio.gather(*pending_tasks, return_exceptions=True) + raise cast(BaseException, exception) + + except asyncio.CancelledError: # context itself was cancelled for task in self._tasks: task.cancel() await asyncio.gather(*self._tasks, return_exceptions=True) - - if not isinstance(error, asyncio.CancelledError): - raise - - context_task = asyncio.current_task() - if context_task and context_task.cancelling() > 0: # context itself was cancelled - raise + raise finally: - self._tasks = [] + self._tasks = set() diff --git a/tests/strands/experimental/bidi/_async/test_task_group.py b/tests/strands/experimental/bidi/_async/test_task_group.py index 23ff821f9..b9a30ef5b 100644 --- 
a/tests/strands/experimental/bidi/_async/test_task_group.py +++ b/tests/strands/experimental/bidi/_async/test_task_group.py @@ -17,7 +17,7 @@ async def test_task_group__aexit__(): @pytest.mark.asyncio -async def test_task_group__aexit__exception(): +async def test_task_group__aexit__task_exception(): wait_event = asyncio.Event() async def wait(): await wait_event.wait() @@ -35,7 +35,19 @@ async def fail(): @pytest.mark.asyncio -async def test_task_group__aexit__cancelled(): +async def test_task_group__aexit__task_cancelled(): + async def wait(): + asyncio.current_task().cancel() + await asyncio.sleep(0) + + async with _TaskGroup() as task_group: + wait_task = task_group.create_task(wait()) + + assert wait_task.cancelled() + + +@pytest.mark.asyncio +async def test_task_group__aexit__context_cancelled(): wait_event = asyncio.Event() async def wait(): await wait_event.wait() From 08bf5638ee9a5cdf78b0eeabcf40a6a33f6c7659 Mon Sep 17 00:00:00 2001 From: Aleksei Iancheruk <113924163+aiancheruk@users.noreply.github.com> Date: Wed, 7 Jan 2026 16:54:00 +0100 Subject: [PATCH 044/279] feat(bedrock): add guardrail_latest_message option (#1224) * feat(bedrock): add guardrail_last_turn_only option * fix(bedrock): include assistant response in guardrail_last_turn_only context * fix: optimize code * feat: rewrtie the logic, include last user message in guardContent when feature flag is true * fix: remove uncessary integ tests and simplify guardrail logic * fix: rename feature flag, remove uncessary tests,add image to guardcontent block * fix: simplify logic and make tests more reliable --------- Co-authored-by: Aleksei Iancheruk Co-authored-by: Jack Yuan --- src/strands/models/bedrock.py | 22 +++++++++++-- tests/strands/models/test_bedrock.py | 44 +++++++++++++++++++++++++ tests_integ/test_bedrock_guardrails.py | 45 ++++++++++++++++++++++++++ 3 files changed, 109 insertions(+), 2 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 
08d8f400c..8e1558ca7 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -82,6 +82,8 @@ class BedrockConfig(TypedDict, total=False): guardrail_redact_input_message: If a Bedrock Input guardrail triggers, replace the input with this message. guardrail_redact_output: Flag to redact output if guardrail is triggered. Defaults to False. guardrail_redact_output_message: If a Bedrock Output guardrail triggers, replace output with this message. + guardrail_latest_message: Flag to send only the latest user message to guardrails. + Defaults to False. max_tokens: Maximum number of tokens to generate in the response model_id: The Bedrock model ID (e.g., "us.anthropic.claude-sonnet-4-20250514-v1:0") include_tool_result_status: Flag to include status field in tool results. @@ -105,6 +107,7 @@ class BedrockConfig(TypedDict, total=False): guardrail_redact_input_message: Optional[str] guardrail_redact_output: Optional[bool] guardrail_redact_output_message: Optional[str] + guardrail_latest_message: Optional[bool] max_tokens: Optional[int] model_id: str include_tool_result_status: Optional[Literal["auto"] | bool] @@ -199,7 +202,6 @@ def _format_request( Args: messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. - system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. system_prompt_content: System prompt content blocks to provide context to the model.
@@ -302,6 +304,7 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: - Filtering out SDK_UNKNOWN_MEMBER content blocks - Eagerly filtering content blocks to only include Bedrock-supported fields - Ensuring all message content blocks are properly formatted for the Bedrock API + - Optionally wrapping the last user message in guardrailConverseContent blocks Args: messages: List of messages to format @@ -321,7 +324,9 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: filtered_unknown_members = False dropped_deepseek_reasoning_content = False - for message in messages: + guardrail_latest_message = self.config.get("guardrail_latest_message", False) + + for idx, message in enumerate(messages): cleaned_content: list[dict[str, Any]] = [] for content_block in message["content"]: @@ -338,6 +343,19 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: # Format content blocks for Bedrock API compatibility formatted_content = self._format_request_message_content(content_block) + + # Wrap text or image content in guardrailContent if this is the last user message + if ( + guardrail_latest_message + and idx == len(messages) - 1 + and message["role"] == "user" + and ("text" in formatted_content or "image" in formatted_content) + ): + if "text" in formatted_content: + formatted_content = {"guardContent": {"text": {"text": formatted_content["text"]}}} + elif "image" in formatted_content: + formatted_content = {"guardContent": {"image": formatted_content["image"]}} + cleaned_content.append(formatted_content) # Create new message with cleaned content (skip if empty) diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 33be44b1b..7697c5e03 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2196,3 +2196,47 @@ async def test_citations_content_preserves_tagged_union_structure(bedrock_client "(documentChar, 
documentPage, documentChunk, searchResultLocation, or web) " "with the location fields nested inside." ) + + +@pytest.mark.asyncio +async def test_format_request_with_guardrail_latest_message(model): + """Test that guardrail_latest_message wraps the latest user message with text and image.""" + model.update_config( + guardrail_id="test-guardrail", + guardrail_version="DRAFT", + guardrail_latest_message=True, + ) + + messages = [ + {"role": "user", "content": [{"text": "First message"}]}, + {"role": "assistant", "content": [{"text": "First response"}]}, + { + "role": "user", + "content": [ + {"text": "Look at this image"}, + {"image": {"format": "png", "source": {"bytes": b"fake_image_data"}}}, + ], + }, + ] + + request = model._format_request(messages) + formatted_messages = request["messages"] + + # All messages should be in the request + assert len(formatted_messages) == 3 + + # First user message should NOT be wrapped + assert "text" in formatted_messages[0]["content"][0] + assert formatted_messages[0]["content"][0]["text"] == "First message" + + # Assistant message should NOT be wrapped + assert "text" in formatted_messages[1]["content"][0] + assert formatted_messages[1]["content"][0]["text"] == "First response" + + # Latest user message text should be wrapped + assert "guardContent" in formatted_messages[2]["content"][0] + assert formatted_messages[2]["content"][0]["guardContent"]["text"]["text"] == "Look at this image" + + # Latest user message image should also be wrapped + assert "guardContent" in formatted_messages[2]["content"][1] + assert formatted_messages[2]["content"][1]["guardContent"]["image"]["format"] == "png" diff --git a/tests_integ/test_bedrock_guardrails.py b/tests_integ/test_bedrock_guardrails.py index 37fa6028c..058597026 100644 --- a/tests_integ/test_bedrock_guardrails.py +++ b/tests_integ/test_bedrock_guardrails.py @@ -289,6 +289,51 @@ def list_users() -> str: assert tool_result["content"][0]["text"] == INPUT_REDACT_MESSAGE +def 
test_guardrail_latest_message(boto_session, bedrock_guardrail, yellow_img): + """Test that guardrail_latest_message wraps both text and image in the latest user message.""" + bedrock_model = BedrockModel( + guardrail_id=bedrock_guardrail, + guardrail_version="DRAFT", + guardrail_latest_message=True, + boto_session=boto_session, + ) + + # Create agent with valid content + agent1 = Agent( + model=bedrock_model, + system_prompt="You are a helpful assistant.", + callback_handler=None, + messages=[ + {"role": "user", "content": [{"text": "First message"}]}, + {"role": "assistant", "content": [{"text": "Hello!"}]}, + ], + ) + + response = agent1("What do you see?") + assert response.stop_reason != "guardrail_intervened" + + # Create agent with multimodal content in latest user message + agent2 = Agent( + model=bedrock_model, + system_prompt="You are a helpful assistant.", + callback_handler=None, + messages=[ + {"role": "user", "content": [{"text": "First message"}]}, + {"role": "assistant", "content": [{"text": "Hello!"}]}, + { + "role": "user", + "content": [ + {"text": "CACTUS"}, + {"image": {"format": "png", "source": {"bytes": yellow_img}}}, + ], + }, + ], + ) + + response = agent2("What do you see?") + assert response.stop_reason == "guardrail_intervened" + + def test_guardrail_input_intervention_properly_redacts_in_session(boto_session, bedrock_guardrail, temp_dir): bedrock_model = BedrockModel( guardrail_id=bedrock_guardrail, From 1e27d79bd8e6e13f7a77ccd391e4a581415ab90d Mon Sep 17 00:00:00 2001 From: Evan Mattiza Date: Wed, 7 Jan 2026 11:44:28 -0600 Subject: [PATCH 045/279] fix(gemini): Gemini UnboundLocal Exception raised during stream (#1420) --- src/strands/models/gemini.py | 5 ++++- tests/strands/models/test_gemini.py | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index cf7cc604a..45f7f4e18 100644 --- a/src/strands/models/gemini.py +++
b/src/strands/models/gemini.py @@ -426,6 +426,8 @@ async def stream( yield self._format_chunk({"chunk_type": "content_start", "data_type": "text"}) tool_used = False + candidate = None + event = None async for event in response: candidates = event.candidates candidate = candidates[0] if candidates else None @@ -455,7 +457,8 @@ async def stream( "data": "TOOL_USE" if tool_used else (candidate.finish_reason if candidate else "STOP"), } ) - yield self._format_chunk({"chunk_type": "metadata", "data": event.usage_metadata}) + if event: + yield self._format_chunk({"chunk_type": "metadata", "data": event.usage_metadata}) except genai.errors.ClientError as error: if not error.message: diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index c552a892a..08be9188d 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -566,6 +566,25 @@ async def test_stream_response_none_candidates(gemini_client, model, messages, a assert tru_chunks == exp_chunks +@pytest.mark.asyncio +async def test_stream_response_empty_stream(gemini_client, model, messages, agenerator, alist): + """Test that empty stream doesn't raise UnboundLocalError. + + When the stream yields no events, the candidate variable must be initialized + to None to avoid UnboundLocalError when referenced in message_stop chunk. 
+ """ + gemini_client.aio.models.generate_content_stream.return_value = agenerator([]) + + tru_chunks = await alist(model.stream(messages)) + exp_chunks = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + assert tru_chunks == exp_chunks + + @pytest.mark.asyncio async def test_stream_response_throttled_exception(gemini_client, model, messages): gemini_client.aio.models.generate_content_stream.side_effect = genai.errors.ClientError( From 2f04bc0f9c786e6afa0837819d55accaa68b6896 Mon Sep 17 00:00:00 2001 From: schleidl Date: Thu, 8 Jan 2026 16:50:24 +0100 Subject: [PATCH 046/279] feat(litellm): handle litellm non streaming responses (#512) --------- Co-authored-by: Daniel Schleicher Co-authored-by: Dean Schmigelski --- src/strands/models/litellm.py | 253 ++++++++++++++++------ tests/strands/models/test_litellm.py | 255 ++++++++++++++++++++++- tests_integ/models/test_model_litellm.py | 30 ++- 3 files changed, 461 insertions(+), 77 deletions(-) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 1f1e999d2..c120b0eda 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -269,75 +269,29 @@ async def stream( ) logger.debug("request=<%s>", request) - logger.debug("invoking model") - try: - if kwargs.get("stream") is False: - raise ValueError("stream parameter cannot be explicitly set to False") - response = await litellm.acompletion(**self.client_args, **request) - except ContextWindowExceededError as e: - logger.warning("litellm client raised context window overflow") - raise ContextWindowOverflowException(e) from e + # Check if streaming is disabled in the params + config = self.get_config() + params = config.get("params") or {} + is_streaming = params.get("stream", True) - logger.debug("got response from model") - yield self.format_chunk({"chunk_type": "message_start"}) + litellm_request = {**request} 
- tool_calls: dict[int, list[Any]] = {} - data_type: str | None = None + litellm_request["stream"] = is_streaming - async for event in response: - # Defensive: skip events with empty or missing choices - if not getattr(event, "choices", None): - continue - choice = event.choices[0] + logger.debug("invoking model with stream=%s", litellm_request.get("stream")) - if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content: - chunks, data_type = self._stream_switch_content("reasoning_content", data_type) - for chunk in chunks: + try: + if is_streaming: + async for chunk in self._handle_streaming_response(litellm_request): yield chunk - - yield self.format_chunk( - { - "chunk_type": "content_delta", - "data_type": data_type, - "data": choice.delta.reasoning_content, - } - ) - - if choice.delta.content: - chunks, data_type = self._stream_switch_content("text", data_type) - for chunk in chunks: + else: + async for chunk in self._handle_non_streaming_response(litellm_request): yield chunk + except ContextWindowExceededError as e: + logger.warning("litellm client raised context window overflow") + raise ContextWindowOverflowException(e) from e - yield self.format_chunk( - {"chunk_type": "content_delta", "data_type": data_type, "data": choice.delta.content} - ) - - for tool_call in choice.delta.tool_calls or []: - tool_calls.setdefault(tool_call.index, []).append(tool_call) - - if choice.finish_reason: - if data_type: - yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type}) - break - - for tool_deltas in tool_calls.values(): - yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}) - - for tool_delta in tool_deltas: - yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}) - - yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) - - yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason}) - 
- # Skip remaining events as we don't have use for anything except the final usage payload - async for event in response: - _ = event - - if event.usage: - yield self.format_chunk({"chunk_type": "metadata", "data": event.usage}) - - logger.debug("finished streaming response from model") + logger.debug("finished processing response from model") @override async def structured_output( @@ -422,6 +376,181 @@ async def _structured_output_using_tool( except (json.JSONDecodeError, TypeError, ValueError) as e: raise ValueError(f"Failed to parse or load content into model: {e}") from e + async def _process_choice_content( + self, choice: Any, data_type: str | None, tool_calls: dict[int, list[Any]], is_streaming: bool = True + ) -> AsyncGenerator[tuple[str | None, StreamEvent], None]: + """Process content from a choice object (streaming or non-streaming). + + Args: + choice: The choice object from the response. + data_type: Current data type being processed. + tool_calls: Dictionary to collect tool calls. + is_streaming: Whether this is from a streaming response. + + Yields: + Tuples of (updated_data_type, stream_event). 
+ """ + # Get the content source - this is the only difference between streaming/non-streaming + # We use duck typing here: both choice.delta and choice.message have the same interface + # (reasoning_content, content, tool_calls attributes) but different object structures + content_source = choice.delta if is_streaming else choice.message + + # Process reasoning content + if hasattr(content_source, "reasoning_content") and content_source.reasoning_content: + chunks, data_type = self._stream_switch_content("reasoning_content", data_type) + for chunk in chunks: + yield data_type, chunk + chunk = self.format_chunk( + { + "chunk_type": "content_delta", + "data_type": "reasoning_content", + "data": content_source.reasoning_content, + } + ) + yield data_type, chunk + + # Process text content + if hasattr(content_source, "content") and content_source.content: + chunks, data_type = self._stream_switch_content("text", data_type) + for chunk in chunks: + yield data_type, chunk + chunk = self.format_chunk( + { + "chunk_type": "content_delta", + "data_type": "text", + "data": content_source.content, + } + ) + yield data_type, chunk + + # Process tool calls + if hasattr(content_source, "tool_calls") and content_source.tool_calls: + if is_streaming: + # Streaming: tool calls have index attribute for out-of-order delivery + for tool_call in content_source.tool_calls: + tool_calls.setdefault(tool_call.index, []).append(tool_call) + else: + # Non-streaming: tool calls arrive in order, use enumerated index + for i, tool_call in enumerate(content_source.tool_calls): + tool_calls.setdefault(i, []).append(tool_call) + + async def _process_tool_calls(self, tool_calls: dict[int, list[Any]]) -> AsyncGenerator[StreamEvent, None]: + """Process and yield tool call events. + + Args: + tool_calls: Dictionary of tool calls indexed by their position. + + Yields: + Formatted tool call chunks. 
+ """ + for tool_deltas in tool_calls.values(): + yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}) + + for tool_delta in tool_deltas: + yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}) + + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) + + async def _handle_non_streaming_response( + self, litellm_request: dict[str, Any] + ) -> AsyncGenerator[StreamEvent, None]: + """Handle non-streaming response from LiteLLM. + + Args: + litellm_request: The formatted request for LiteLLM. + + Yields: + Formatted message chunks from the model. + """ + response = await litellm.acompletion(**self.client_args, **litellm_request) + + logger.debug("got non-streaming response from model") + yield self.format_chunk({"chunk_type": "message_start"}) + + tool_calls: dict[int, list[Any]] = {} + data_type: str | None = None + finish_reason: str | None = None + + if hasattr(response, "choices") and response.choices and len(response.choices) > 0: + choice = response.choices[0] + + if hasattr(choice, "message") and choice.message: + # Process content using shared logic + async for updated_data_type, chunk in self._process_choice_content( + choice, data_type, tool_calls, is_streaming=False + ): + data_type = updated_data_type + yield chunk + + if hasattr(choice, "finish_reason"): + finish_reason = choice.finish_reason + + # Stop the current content block if we have one + if data_type: + yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type}) + + # Process tool calls + async for chunk in self._process_tool_calls(tool_calls): + yield chunk + + yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason}) + + # Add usage information if available + if hasattr(response, "usage"): + yield self.format_chunk({"chunk_type": "metadata", "data": response.usage}) + + async def _handle_streaming_response(self, litellm_request: dict[str, 
Any]) -> AsyncGenerator[StreamEvent, None]: + """Handle streaming response from LiteLLM. + + Args: + litellm_request: The formatted request for LiteLLM. + + Yields: + Formatted message chunks from the model. + """ + # For streaming, use the streaming API + response = await litellm.acompletion(**self.client_args, **litellm_request) + + logger.debug("got response from model") + yield self.format_chunk({"chunk_type": "message_start"}) + + tool_calls: dict[int, list[Any]] = {} + data_type: str | None = None + finish_reason: str | None = None + + async for event in response: + # Defensive: skip events with empty or missing choices + if not getattr(event, "choices", None): + continue + choice = event.choices[0] + + # Process content using shared logic + async for updated_data_type, chunk in self._process_choice_content( + choice, data_type, tool_calls, is_streaming=True + ): + data_type = updated_data_type + yield chunk + + if choice.finish_reason: + finish_reason = choice.finish_reason + if data_type: + yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type}) + break + + # Process tool calls + async for chunk in self._process_tool_calls(tool_calls): + yield chunk + + yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason}) + + # Skip remaining events as we don't have use for anything except the final usage payload + async for event in response: + _ = event + if event.usage: + yield self.format_chunk({"chunk_type": "metadata", "data": event.usage}) + + logger.debug("finished streaming response from model") + def _apply_proxy_prefix(self) -> None: """Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True. 
diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 832b5c836..99df22a3f 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -285,7 +285,7 @@ async def test_stream_empty(litellm_acompletion, api_key, model_id, model, agene mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) - mock_event_3 = unittest.mock.Mock() + mock_event_3 = unittest.mock.Mock(usage=None) mock_event_4 = unittest.mock.Mock(usage=None) litellm_acompletion.side_effect = unittest.mock.AsyncMock( @@ -408,16 +408,6 @@ async def test_context_window_maps_to_typed_exception(litellm_acompletion, model pass -@pytest.mark.asyncio -async def test_stream_raises_error_when_stream_is_false(model): - """Test that stream raises ValueError when stream parameter is explicitly False.""" - messages = [{"role": "user", "content": [{"text": "test"}]}] - - with pytest.raises(ValueError, match="stream parameter cannot be explicitly set to False"): - async for _ in model.stream(messages, stream=False): - pass - - def test_format_request_messages_with_system_prompt_content(): """Test format_request_messages with system_prompt_content parameter.""" messages = [{"role": "user", "content": [{"text": "Hello"}]}] @@ -478,3 +468,246 @@ def test_format_request_messages_cache_point_support(): ] assert result == expected + + +@pytest.mark.asyncio +async def test_stream_non_streaming(litellm_acompletion, api_key, model_id, alist): + """Test LiteLLM model with streaming disabled (stream=False). + + This test verifies that the LiteLLM model works correctly when streaming is disabled, + which was the issue reported in GitHub issue #477. 
+ """ + + mock_function = unittest.mock.Mock() + mock_function.name = "calculator" + mock_function.arguments = '{"expression": "123981723 + 234982734"}' + + mock_tool_call = unittest.mock.Mock(index=0, function=mock_function, id="tool_call_id_123") + + mock_message = unittest.mock.Mock() + mock_message.content = "I'll calculate that for you" + mock_message.reasoning_content = "Let me think about this calculation" + mock_message.tool_calls = [mock_tool_call] + + mock_choice = unittest.mock.Mock() + mock_choice.message = mock_message + mock_choice.finish_reason = "tool_calls" + + mock_response = unittest.mock.Mock() + mock_response.choices = [mock_choice] + + # Create a more explicit usage mock that doesn't have cache-related attributes + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 10 + mock_usage.completion_tokens = 20 + mock_usage.total_tokens = 30 + mock_usage.prompt_tokens_details = None + mock_usage.cache_creation_input_tokens = None + mock_response.usage = mock_usage + + litellm_acompletion.side_effect = unittest.mock.AsyncMock(return_value=mock_response) + + model = LiteLLMModel( + client_args={"api_key": api_key}, + model_id=model_id, + params={"stream": False}, # This is the key setting that was causing the #477 issue + ) + + messages = [{"role": "user", "content": [{"type": "text", "text": "What is 123981723 + 234982734?"}]}] + response = model.stream(messages) + + tru_events = await alist(response) + + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Let me think about this calculation"}}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "I'll calculate that for you"}}}, + {"contentBlockStop": {}}, + { + "contentBlockStart": { + "start": {"toolUse": {"name": "calculator", "toolUseId": mock_message.tool_calls[0].id}} + } + }, + {"contentBlockDelta": {"delta":
{"toolUse": {"input": '{"expression": "123981723 + 234982734"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + { + "metadata": { + "usage": { + "inputTokens": 10, + "outputTokens": 20, + "totalTokens": 30, + }, + "metrics": {"latencyMs": 0}, + } + }, + ] + + assert len(tru_events) == len(exp_events) + + for i, (tru, exp) in enumerate(zip(tru_events, exp_events, strict=False)): + assert tru == exp, f"Event {i} mismatch: {tru} != {exp}" + + expected_request = { + "api_key": api_key, + "model": model_id, + "messages": [{"role": "user", "content": [{"text": "What is 123981723 + 234982734?", "type": "text"}]}], + "stream": False, # Verify that stream=False was passed to litellm + "stream_options": {"include_usage": True}, + "tools": [], + } + litellm_acompletion.assert_called_once_with(**expected_request) + + +@pytest.mark.asyncio +async def test_stream_path_validation(litellm_acompletion, api_key, model_id, model, agenerator, alist): + """Test that we're taking the correct streaming path and validate stream parameter.""" + mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None) + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + mock_event_2 = unittest.mock.Mock(usage=None) + + litellm_acompletion.side_effect = unittest.mock.AsyncMock(return_value=agenerator([mock_event_1, mock_event_2])) + + messages = [{"role": "user", "content": []}] + response = model.stream(messages) + + # Consume the response + await alist(response) + + # Validate that litellm.acompletion was called with the expected parameters + call_args = litellm_acompletion.call_args + assert call_args is not None, "litellm.acompletion should have been called" + + # Check if stream parameter is being set + called_kwargs = call_args.kwargs + + # Validate we're going down the streaming path (should have stream=True) + assert called_kwargs.get("stream") is True, f"Expected stream=True, got 
{called_kwargs.get('stream')}" + + +def test_format_request_message_content_reasoning(): + """Test formatting reasoning content.""" + content = {"reasoningContent": {"reasoningText": {"signature": "test_sig", "text": "test_thinking"}}} + + result = LiteLLMModel.format_request_message_content(content) + expected = {"signature": "test_sig", "thinking": "test_thinking", "type": "thinking"} + + assert result == expected + + +def test_format_request_message_content_video(): + """Test formatting video content.""" + content = {"video": {"source": {"bytes": "base64videodata"}}} + + result = LiteLLMModel.format_request_message_content(content) + expected = {"type": "video_url", "video_url": {"detail": "auto", "url": "base64videodata"}} + + assert result == expected + + +def test_apply_proxy_prefix_with_use_litellm_proxy(): + """Test _apply_proxy_prefix when use_litellm_proxy is True.""" + model = LiteLLMModel(client_args={"use_litellm_proxy": True}, model_id="openai/gpt-4") + + assert model.get_config()["model_id"] == "litellm_proxy/openai/gpt-4" + + +def test_apply_proxy_prefix_already_has_prefix(): + """Test _apply_proxy_prefix when model_id already has prefix.""" + model = LiteLLMModel(client_args={"use_litellm_proxy": True}, model_id="litellm_proxy/openai/gpt-4") + + # Should not add another prefix + assert model.get_config()["model_id"] == "litellm_proxy/openai/gpt-4" + + +def test_apply_proxy_prefix_disabled(): + """Test _apply_proxy_prefix when use_litellm_proxy is False.""" + model = LiteLLMModel(client_args={"use_litellm_proxy": False}, model_id="openai/gpt-4") + + assert model.get_config()["model_id"] == "openai/gpt-4" + + +def test_format_chunk_metadata_with_cache_tokens(): + """Test format_chunk for metadata with cache tokens.""" + model = LiteLLMModel(model_id="test") + + # Mock usage data with cache tokens + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 100 + mock_usage.completion_tokens = 50 + mock_usage.total_tokens = 150 + + # Mock 
cache-related attributes + mock_tokens_details = unittest.mock.Mock() + mock_tokens_details.cached_tokens = 25 + mock_usage.prompt_tokens_details = mock_tokens_details + mock_usage.cache_creation_input_tokens = 10 + + event = {"chunk_type": "metadata", "data": mock_usage} + + result = model.format_chunk(event) + + assert result["metadata"]["usage"]["inputTokens"] == 100 + assert result["metadata"]["usage"]["outputTokens"] == 50 + assert result["metadata"]["usage"]["totalTokens"] == 150 + assert result["metadata"]["usage"]["cacheReadInputTokens"] == 25 + assert result["metadata"]["usage"]["cacheWriteInputTokens"] == 10 + + +def test_format_chunk_metadata_without_cache_tokens(): + """Test format_chunk for metadata without cache tokens.""" + model = LiteLLMModel(model_id="test") + + # Mock usage data without cache tokens + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 100 + mock_usage.completion_tokens = 50 + mock_usage.total_tokens = 150 + mock_usage.prompt_tokens_details = None + mock_usage.cache_creation_input_tokens = None + + event = {"chunk_type": "metadata", "data": mock_usage} + + result = model.format_chunk(event) + + assert result["metadata"]["usage"]["inputTokens"] == 100 + assert result["metadata"]["usage"]["outputTokens"] == 50 + assert result["metadata"]["usage"]["totalTokens"] == 150 + assert "cacheReadInputTokens" not in result["metadata"]["usage"] + assert "cacheWriteInputTokens" not in result["metadata"]["usage"] + + +def test_stream_switch_content_same_type(): + """Test _stream_switch_content when data_type is the same as prev_data_type.""" + model = LiteLLMModel(model_id="test") + + chunks, data_type = model._stream_switch_content("text", "text") + + assert chunks == [] + assert data_type == "text" + + +def test_stream_switch_content_different_type_with_prev(): + """Test _stream_switch_content when switching from one type to another.""" + model = LiteLLMModel(model_id="test") + + chunks, data_type = 
model._stream_switch_content("text", "reasoning_content") + + assert len(chunks) == 2 + assert chunks[0]["contentBlockStop"] == {} + assert chunks[1]["contentBlockStart"] == {"start": {}} + assert data_type == "text" + + +def test_stream_switch_content_different_type_no_prev(): + """Test _stream_switch_content when switching to a type with no previous type.""" + model = LiteLLMModel(model_id="test") + + chunks, data_type = model._stream_switch_content("text", None) + + assert len(chunks) == 1 + assert chunks[0]["contentBlockStart"] == {"start": {}} + assert data_type == "text" diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py index d72937641..80e21bdfd 100644 --- a/tests_integ/models/test_model_litellm.py +++ b/tests_integ/models/test_model_litellm.py @@ -14,6 +14,16 @@ def model(): return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0") +@pytest.fixture +def streaming_model(): + return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0", params={"stream": True}) + + +@pytest.fixture +def non_streaming_model(): + return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0", params={"stream": False}) + + @pytest.fixture def tools(): @strands.tool @@ -95,15 +105,21 @@ def lower(_, value): return Color(simple_color_name="yellow") -def test_agent_invoke(agent): +@pytest.mark.parametrize("model_fixture", ["streaming_model", "non_streaming_model"]) +def test_agent_invoke(model_fixture, tools, request): + model = request.getfixturevalue(model_fixture) + agent = Agent(model=model, tools=tools) result = agent("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() assert all(string in text for string in ["12:00", "sunny"]) +@pytest.mark.parametrize("model_fixture", ["streaming_model", "non_streaming_model"]) @pytest.mark.asyncio -async def test_agent_invoke_async(agent): +async def 
test_agent_invoke_async(model_fixture, tools, request): + model = request.getfixturevalue(model_fixture) + agent = Agent(model=model, tools=tools) result = await agent.invoke_async("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() @@ -138,14 +154,20 @@ def test_agent_invoke_reasoning(agent, model): assert result.message["content"][0]["reasoningContent"]["reasoningText"]["text"] -def test_structured_output(agent, weather): +@pytest.mark.parametrize("model_fixture", ["streaming_model", "non_streaming_model"]) +def test_structured_output(model_fixture, weather, request): + model = request.getfixturevalue(model_fixture) + agent = Agent(model=model) tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny") exp_weather = weather assert tru_weather == exp_weather +@pytest.mark.parametrize("model_fixture", ["streaming_model", "non_streaming_model"]) @pytest.mark.asyncio -async def test_agent_structured_output_async(agent, weather): +async def test_agent_structured_output_async(model_fixture, weather, request): + model = request.getfixturevalue(model_fixture) + agent = Agent(model=model) tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny") exp_weather = weather assert tru_weather == exp_weather From 0ef228878bde27f361f4aa9ecb1d962986ee4cbf Mon Sep 17 00:00:00 2001 From: Arron <139703460+awsarron@users.noreply.github.com> Date: Fri, 9 Jan 2026 17:19:06 +0100 Subject: [PATCH 047/279] feat(agent): introduce AgentBase Protocol as the interface for agent classes to implement (#1126) --- src/strands/__init__.py | 2 ++ src/strands/agent/__init__.py | 2 ++ src/strands/agent/agent.py | 2 +- src/strands/agent/base.py | 66 +++++++++++++++++++++++++++++++++++ 4 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 src/strands/agent/base.py diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 3718a29c5..bc17497a0 
100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -2,11 +2,13 @@ from . import agent, models, telemetry, types from .agent.agent import Agent +from .agent.base import AgentBase from .tools.decorator import tool from .types.tools import ToolContext __all__ = [ "Agent", + "AgentBase", "agent", "models", "tool", diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index 6618d3328..c00623dc2 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -8,6 +8,7 @@ from .agent import Agent from .agent_result import AgentResult +from .base import AgentBase from .conversation_manager import ( ConversationManager, NullConversationManager, @@ -17,6 +18,7 @@ __all__ = [ "Agent", + "AgentBase", "AgentResult", "ConversationManager", "NullConversationManager", diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 9e726ca0b..c4ebc0b54 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -87,7 +87,7 @@ class _DefaultCallbackHandlerSentinel: class Agent: - """Core Agent interface. + """Core Agent implementation. An agent orchestrates the following workflow: diff --git a/src/strands/agent/base.py b/src/strands/agent/base.py new file mode 100644 index 000000000..b35ade8c4 --- /dev/null +++ b/src/strands/agent/base.py @@ -0,0 +1,66 @@ +"""Agent Interface. + +Defines the minimal interface that all agent types must implement. +""" + +from typing import Any, AsyncIterator, Protocol, runtime_checkable + +from ..types.agent import AgentInput +from .agent_result import AgentResult + + +@runtime_checkable +class AgentBase(Protocol): + """Protocol defining the interface for all agent types in Strands. + + This protocol defines the minimal contract that all agent implementations + must satisfy. + """ + + async def invoke_async( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AgentResult: + """Asynchronously invoke the agent with the given prompt. 
+ + Args: + prompt: Input to the agent. + **kwargs: Additional arguments. + + Returns: + AgentResult containing the agent's response. + """ + ... + + def __call__( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AgentResult: + """Synchronously invoke the agent with the given prompt. + + Args: + prompt: Input to the agent. + **kwargs: Additional arguments. + + Returns: + AgentResult containing the agent's response. + """ + ... + + def stream_async( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AsyncIterator[Any]: + """Stream agent execution asynchronously. + + Args: + prompt: Input to the agent. + **kwargs: Additional arguments. + + Yields: + Events representing the streaming execution. + """ + ... From 10a8e4a1fdfabebe7341ab5d1584731657e15ddc Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 9 Jan 2026 12:04:39 -0500 Subject: [PATCH 048/279] ci: update pytest requirement from <9.0.0,>=8.0.0 to >=8.0.0,<10.0.0 in the dev-dependencies group (#1161) * ci: update pytest requirement in the dev-dependencies group Updates the requirements on [pytest](https://github.com/pytest-dev/pytest) to permit the latest version. Updates `pytest` to 9.0.0 - [Release notes](https://github.com/pytest-dev/pytest/releases) - [Changelog](https://github.com/pytest-dev/pytest/blob/main/CHANGELOG.rst) - [Commits](https://github.com/pytest-dev/pytest/compare/8.0.0...9.0.0) --- updated-dependencies: - dependency-name: pytest dependency-version: 9.0.0 dependency-type: direct:development dependency-group: dev-dependencies ... 
Signed-off-by: dependabot[bot] * bump pytest version floor --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Dean Schmigelski --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 040babe67..05a385ca9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,7 @@ dev = [ "moto>=5.1.0,<6.0.0", "mypy>=1.15.0,<2.0.0", "pre-commit>=3.2.0,<4.6.0", - "pytest>=8.0.0,<9.0.0", + "pytest>=9.0.0,<10.0.0", "pytest-cov>=7.0.0,<8.0.0", "pytest-asyncio>=1.0.0,<1.4.0", "pytest-xdist>=3.0.0,<4.0.0", @@ -142,7 +142,7 @@ installer = "uv" features = ["all"] extra-args = ["-n", "auto", "-vv"] dependencies = [ - "pytest>=8.0.0,<9.0.0", + "pytest>=9.0.0,<10.0.0", "pytest-cov>=7.0.0,<8.0.0", "pytest-asyncio>=1.0.0,<1.4.0", "pytest-xdist>=3.0.0,<4.0.0", From cd6570b197a13957a5dc9d42ce367c20bbd6875d Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Mon, 12 Jan 2026 10:27:02 -0800 Subject: [PATCH 049/279] feat(models): pass invocation_state to model providers (#1414) --------- Co-authored-by: Tirth Patel Co-authored-by: Dean Schmigelski --- src/strands/event_loop/event_loop.py | 1 + src/strands/event_loop/streaming.py | 3 +++ src/strands/models/model.py | 2 ++ tests/strands/agent/test_agent.py | 10 +++++++++- tests/strands/event_loop/test_event_loop.py | 1 + tests/strands/event_loop/test_streaming.py | 5 +++++ .../event_loop/test_streaming_structured_output.py | 2 ++ 7 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index fcb530a0d..231cfa56a 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -345,6 +345,7 @@ async def _handle_model_execution( tool_specs, system_prompt_content=agent._system_prompt_content, tool_choice=structured_output_context.tool_choice, + invocation_state=invocation_state, ): 
yield event diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 804f90a1d..7840bfcef 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -425,6 +425,7 @@ async def stream_messages( *, tool_choice: Optional[Any] = None, system_prompt_content: Optional[list[SystemContentBlock]] = None, + invocation_state: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> AsyncGenerator[TypedEvent, None]: """Streams messages to the model and processes the response. @@ -437,6 +438,7 @@ async def stream_messages( tool_choice: Optional tool choice constraint for forcing specific tool usage. system_prompt_content: The authoritative system prompt content blocks that always contains the system prompt data. + invocation_state: Caller-provided state/context that was passed to the agent when it was invoked. **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -453,6 +455,7 @@ async def stream_messages( system_prompt, tool_choice=tool_choice, system_prompt_content=system_prompt_content, + invocation_state=invocation_state, ) async for event in process_stream(chunks, start_time): diff --git a/src/strands/models/model.py b/src/strands/models/model.py index b2fa73802..6b7dd78d7 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -73,6 +73,7 @@ def stream( *, tool_choice: ToolChoice | None = None, system_prompt_content: list[SystemContentBlock] | None = None, + invocation_state: dict[str, Any] | None = None, **kwargs: Any, ) -> AsyncIterable[StreamEvent]: """Stream conversation with the model. @@ -89,6 +90,7 @@ def stream( system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. system_prompt_content: System prompt content blocks for advanced features like caching. + invocation_state: Caller-provided state/context that was passed to the agent when it was invoked. 
**kwargs: Additional keyword arguments for future extensibility. Yields: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index f133400a8..351eadc84 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -36,7 +36,11 @@ @pytest.fixture def mock_model(request): async def stream(*args, **kwargs): - result = mock.mock_stream(*copy.deepcopy(args), **copy.deepcopy(kwargs)) + # Skip deep copy of invocation_state which contains non-serializable objects (agent, spans, etc.) + copied_kwargs = { + key: value if key == "invocation_state" else copy.deepcopy(value) for key, value in kwargs.items() + } + result = mock.mock_stream(*copy.deepcopy(args), **copied_kwargs) # If result is already an async generator, yield from it if hasattr(result, "__aiter__"): async for item in result: @@ -325,6 +329,7 @@ def test_agent__call__( system_prompt, tool_choice=None, system_prompt_content=[{"text": system_prompt}], + invocation_state=unittest.mock.ANY, ), unittest.mock.call( [ @@ -363,6 +368,7 @@ def test_agent__call__( system_prompt, tool_choice=None, system_prompt_content=[{"text": system_prompt}], + invocation_state=unittest.mock.ANY, ), ], ) @@ -484,6 +490,7 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool, agener unittest.mock.ANY, tool_choice=None, system_prompt_content=unittest.mock.ANY, + invocation_state=unittest.mock.ANY, ) conversation_manager_spy.reduce_context.assert_called_once() @@ -629,6 +636,7 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene unittest.mock.ANY, tool_choice=None, system_prompt_content=unittest.mock.ANY, + invocation_state=unittest.mock.ANY, ) assert conversation_manager_spy.reduce_context.call_count == 2 diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 6b23bd592..639e60ea0 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ 
b/tests/strands/event_loop/test_event_loop.py @@ -383,6 +383,7 @@ async def test_event_loop_cycle_tool_result( "p1", tool_choice=None, system_prompt_content=unittest.mock.ANY, + invocation_state=unittest.mock.ANY, ) diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index c6e44b78a..b2cc152cb 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -1117,6 +1117,7 @@ async def test_stream_messages(agenerator, alist): "test prompt", tool_choice=None, system_prompt_content=[{"text": "test prompt"}], + invocation_state=None, ) @@ -1150,6 +1151,7 @@ async def test_stream_messages_with_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=system_prompt_content, + invocation_state=None, ) @@ -1183,6 +1185,7 @@ async def test_stream_messages_single_text_block_backwards_compatibility(agenera "You are a helpful assistant.", tool_choice=None, system_prompt_content=system_prompt_content, + invocation_state=None, ) @@ -1214,6 +1217,7 @@ async def test_stream_messages_empty_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=[], + invocation_state=None, ) @@ -1245,6 +1249,7 @@ async def test_stream_messages_none_system_prompt_content(agenerator, alist): None, tool_choice=None, system_prompt_content=None, + invocation_state=None, ) # Ensure that we're getting typed events coming out of process_stream diff --git a/tests/strands/event_loop/test_streaming_structured_output.py b/tests/strands/event_loop/test_streaming_structured_output.py index 4645e1724..4c4082c00 100644 --- a/tests/strands/event_loop/test_streaming_structured_output.py +++ b/tests/strands/event_loop/test_streaming_structured_output.py @@ -66,6 +66,7 @@ async def test_stream_messages_with_tool_choice(agenerator, alist): "test prompt", tool_choice=tool_choice, system_prompt_content=[{"text": "test prompt"}], + invocation_state=None, ) # Verify we 
get the expected events @@ -131,6 +132,7 @@ async def test_stream_messages_with_forced_structured_output(agenerator, alist): "Extract user information", tool_choice=tool_choice, system_prompt_content=[{"text": "Extract user information"}], + invocation_state=None, ) assert len(tru_events) > 0 From 37d0e470e6dcc796cb1c40a388f3b2ba432cc5f2 Mon Sep 17 00:00:00 2001 From: Jonathan Segev Date: Mon, 12 Jan 2026 13:30:48 -0500 Subject: [PATCH 050/279] Add Security.md file (#1454) * Create SECURITY.md --- SECURITY.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 SECURITY.md diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..b520ee1fb --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,20 @@ +# Security Policy + +## Supported Versions + +| Version | Supported | +| ------- | ------------------ | +| 1.x.x | :white_check_mark: | +| < 1.0 | :x: | + +## Reporting Security Issues + +Amazon Web Services (AWS) is dedicated to the responsible disclosure of security vulnerabilities. + +We kindly ask that you **do not** open a public GitHub issue to report security concerns. + +Instead, please submit the issue to the AWS Vulnerability Disclosure Program via [HackerOne](https://hackerone.com/aws_vdp) or send your report via [email](mailto:aws-security@amazon.com). + +For more details, visit the [AWS Vulnerability Reporting Page](http://aws.amazon.com/security/vulnerability-reporting/). + +Thank you in advance for collaborating with us to help protect our customers. From 845c6f76f84975f20d17a32ad02fbb86c3d695d1 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Mon, 12 Jan 2026 13:58:53 -0500 Subject: [PATCH 051/279] chore: Update release notes sop (#1456) Update the release notes SOP that was most effective after several iterations of having it generate the v1.21.0 release notes - final flow is at https://github.com/zastrowm/sdk-python/issues/19. 
The flow on this was tweaking the SOP and then running the agent, and attempting to have it update in response. The biggest problems to solve were that: - It blindly trusted examples in PRs, resulting in examples that were wrong because of stale PRs - It would not generate useful tests I do not have the step-by-step results with the different prompt variants as I deleted the older issues to avoid too many references being added to old PRs, but viewing the [final issue with results](https://github.com/zastrowm/sdk-python/issues/19) shows the SOP results with one shot run after verifying which features should go in. --------- Co-authored-by: Mackenzie Zastrow --- .github/agent-sops/task-release-notes.sop.md | 462 +++++++++++++------ 1 file changed, 324 insertions(+), 138 deletions(-) diff --git a/.github/agent-sops/task-release-notes.sop.md b/.github/agent-sops/task-release-notes.sop.md index 5f024da82..e32a0f2eb 100644 --- a/.github/agent-sops/task-release-notes.sop.md +++ b/.github/agent-sops/task-release-notes.sop.md @@ -8,6 +8,22 @@ You analyze merged pull requests between two git references (tags or branches), **Important**: You are executing in an ephemeral environment. Any files you create (test files, notes, etc.) will be discarded after execution. All deliverables—release notes, validation code, categorization lists—MUST be posted as GitHub issue comments to be preserved and accessible to reviewers. +## Key Principles + +These principles apply throughout the entire workflow and are referenced by name in later sections. + +### Principle 1: Ephemeral Environment +You are executing in an ephemeral environment. All deliverables MUST be posted as GitHub issue comments to be preserved. + +### Principle 2: PR Descriptions May Be Stale +PR descriptions are written at PR creation and may become outdated after code review. 
Reviewers often request structural changes, API modifications, or feature adjustments that are implemented but NOT reflected in the original description. You MUST cross-reference descriptions with review comments and treat merged code as the source of truth. + +### Principle 3: Validation Is Mandatory +You MUST attempt to validate EVERY code example with behavioral tests. The engineer review fallback is only for cases where you have genuinely tried and failed with documented evidence. + +### Principle 4: Never Remove Features +You MUST NOT remove a feature from release notes because validation failed. Always include a code sample—either validated or marked for engineer review. + ## Steps ### 1. Setup and Input Processing @@ -62,10 +78,10 @@ For each PR identified (from release or API query), fetch additional metadata ne - You MUST retrieve additional metadata for PRs being considered for Major Features or Major Bug Fixes: - PR description/body (essential for understanding the change) - PR labels (if any) + - PR review comments and conversation threads (per **Principle 2**) - You SHOULD retrieve for Major Feature candidates: - Files changed in the PR (to find code examples) -- You MAY retrieve: - - PR review comments if helpful for understanding the change +- You MUST retrieve PR review comments for Major Feature and Major Bug Fix candidates to identify post-description changes - You SHOULD minimize API calls by only fetching detailed metadata for PRs that appear significant based on title/prefix - You MUST track this data for use in categorization and release notes generation @@ -89,18 +105,24 @@ Extract categorization signals from PR titles using conventional commit prefixes - You SHOULD record the prefix-based category for each PR - You MAY encounter PRs without conventional commit prefixes -#### 2.2 Analyze PR Descriptions +#### 2.2 Analyze PR Descriptions and Review Comments Use LLM analysis to understand the significance and user impact of each change. 
**Constraints:** - You MUST read and analyze the PR description for each PR +- Per **Principle 2**, you MUST also review PR comments and review threads to identify changes made after the initial description: + - Look for reviewer comments requesting changes to the implementation + - Look for author responses confirming changes were made + - Look for "LGTM" or approval comments that reference specific modifications + - Pay special attention to comments about API changes, renamed methods, or restructured code +- You MUST treat the actual merged code as the source of truth when descriptions conflict with review feedback - You MUST assess the user-facing impact of the change: - Does it introduce new functionality users will interact with? - Does it fix a bug that users experienced? - Is it purely internal with no user-visible changes? - You MUST identify if the change introduces breaking changes -- You SHOULD identify if the PR includes code examples in its description +- You SHOULD identify if the PR includes code examples in its description (but verify they match the final implementation) - You SHOULD note any links to documentation or related issues - You MAY consider the size and complexity of the change @@ -152,6 +174,10 @@ Present the categorized PRs to the user for review and confirmation. - You MUST wait for user confirmation or recategorization before proceeding - You SHOULD update your categorization based on user feedback - You MAY iterate on categorization if the user requests changes +- When the user promotes a PR to "Major Features" that was not previously in that category: + - You MUST perform Step 3 (Code Snippet Extraction) for the newly promoted PR + - You MUST perform Step 4 (Code Validation) for any code snippets extracted or generated + - You MUST include the validation code for newly promoted features in the Validation Comment (Step 6.1) ### 3. 
Code Snippet Extraction and Generation @@ -163,12 +189,16 @@ Search merged PRs for existing code that demonstrates the new feature. **Constraints:** - You MUST search each Major Feature PR for existing code examples in: - - Test files (especially integration tests or example tests) + - Test files (especially integration tests or example tests) - these are most reliable as they reflect the final implementation - Example applications or scripts in `examples/` directory - - Code snippets in the PR description + - Code snippets in the PR description (but verify per **Principle 2**) - Documentation updates that include code examples - README updates with usage examples -- You MUST prioritize test files that show real usage of the feature +- You MUST cross-reference any examples from PR descriptions with: + - Review comments that may have requested API changes + - The actual merged code to ensure the example is still accurate + - Test files which reflect the working implementation +- You MUST prioritize test files that show real usage of the feature (these are validated against the final code) - You SHOULD look for the simplest, most focused examples - You SHOULD prefer examples that are already validated (from test files) - You MAY examine multiple PRs if a feature spans several PRs @@ -208,60 +238,178 @@ When existing examples are insufficient, generate new code snippets. ### 4. Code Validation -**Note**: This phase is REQUIRED for all code snippets (extracted or generated) that will appear in Major Features sections. Validation must occur AFTER snippets have been extracted or generated in Step 3. +**Note**: This phase is REQUIRED for all code snippets (extracted or generated) that will appear in Major Features sections. Per **Principle 3**, you MUST attempt validation for every example. -#### 4.1 Create Temporary Test Files +#### 4.1 Validation Requirements -Create temporary test files to validate the code snippets. 
+Validation tests MUST verify the actual behavior of the feature, not just syntax correctness. A test that only checks whether code parses or imports succeed is NOT valid validation. + +**Available Testing Resources:** +- **Amazon Bedrock**: You have access to Bedrock models for testing. Use Bedrock when a feature requires a real model provider. +- **Project test fixtures**: The project includes mocked model providers and test utilities (commonly in `tests/fixtures/`, `__mocks__/`, or similar) +- **Integration test patterns**: Examine integration test directories (commonly in `tests_integ/` or `test/integ`) for patterns that test real model interactions + +**Features that genuinely cannot be validated (rare):** +- Features requiring paid third-party API credentials with no mock option AND no Bedrock alternative +- Features requiring specific hardware (GPU, TPU) +- Features requiring live network access to specific external services that cannot be mocked **Constraints:** - You MUST create a temporary test file for each code snippet - You MUST place test files in an appropriate test directory based on the project structure - You MUST include all necessary imports and setup code in the test file - You MUST wrap the snippet in a proper test case +- You MUST include assertions that verify the feature's actual behavior: + - Assert that outputs match expected values + - Assert that state changes occur as expected + - Assert that callbacks/hooks are invoked correctly + - Assert that return types and structures are correct +- You MUST NOT write tests that only verify: + - Code parses without syntax errors + - Imports succeed + - Objects can be instantiated without checking behavior + - Functions can be called without checking results - You SHOULD use the project's testing framework -- You MAY need to mock dependencies or setup test fixtures +- You SHOULD mock external dependencies (APIs, databases) but still verify behavior with mocks +- You MAY need to setup test fixtures 
that enable behavioral verification - You MAY include additional test code that doesn't appear in the release notes -**Example test file structure** (language-specific format will vary): +**Example of GOOD validation** (verifies behavior) - adapt syntax to project language: +```python +def test_structured_output_validation(): + """Verify that structured output actually validates against the schema.""" + from pydantic import BaseModel + + class UserResponse(BaseModel): + name: str + age: int + + agent = Agent(model=mock_model, output_schema=UserResponse) + result = agent("Get user info") + + # Behavioral assertions - verify the feature works + assert isinstance(result.output, UserResponse) + assert hasattr(result.output, 'name') + assert hasattr(result.output, 'age') + assert isinstance(result.output.age, int) ``` -# Test structure depends on the project's testing framework -# Include necessary imports, setup, and the snippet being validated -# Add assertions to verify the code works correctly + +**Example of BAD validation** (only verifies syntax) - adapt syntax to project language: +```python +def test_structured_output_syntax(): + """BAD: This only verifies the code runs without errors.""" + from pydantic import BaseModel + + class UserResponse(BaseModel): + name: str + age: int + + # BAD: No assertions about behavior + agent = Agent(model=mock_model, output_schema=UserResponse) + # BAD: Just calling without checking results proves nothing + agent("Get user info") ``` -#### 4.2 Run Validation Tests +#### 4.2 Validation Workflow -Execute tests to ensure code snippets are valid and functional. +For each Major Feature, follow this workflow in order: + +1. **Write a test file** with behavioral assertions +2. **Run the test** using the project's test framework +3. 
**If it fails**, try these approaches in order: + - Try using Bedrock instead of other model providers + - Try installing missing dependencies + - Try mocking external services + - Try using project test fixtures (e.g., mocked model providers) + - Try simplifying the example +4. **Document each attempt** and its result in the Validation Comment +5. **Only after documented failures** can you use the engineer review fallback **Constraints:** - You MUST run the appropriate test command for the project (e.g., `npm test`, `pytest`, `go test`) - You MUST verify that the test passes successfully +- You MUST verify that assertions actually executed (not skipped or short-circuited) - You MUST check that the code compiles without errors in compiled languages +- You MUST ensure tests include meaningful assertions about feature behavior - You SHOULD run type checking if applicable (e.g., `npm run type-check`, `mypy`) +- You SHOULD review test output to confirm behavioral assertions passed - You MAY need to adjust imports or setup code if tests fail -- You MAY need to install additional dependencies if required -**Fallback validation** (if test execution fails or is not possible): -- You MUST at minimum validate syntax using the appropriate language tools -- You MUST ensure the code is syntactically correct -- You MUST verify all referenced types and modules exist +**Installing Dependencies:** +- You MUST attempt to install missing dependencies when tests fail due to import errors +- You SHOULD check the project's dependency manifest (`pyproject.toml`, `package.json`, `Cargo.toml`, etc.) 
for optional dependency groups +- You SHOULD use the project's package manager to install dependencies (e.g., `pip install`, `npm install`, `cargo add`) +- For projects with optional extras, use the appropriate syntax (e.g., `pip install -e ".[extra]"` for Python, `npm install --save-dev` for Node.js) +- You SHOULD only fall back to mocking if the dependency cannot be installed (e.g., requires paid API keys, proprietary software) + +**Example of mocking external dependencies** - adapt syntax to project language: +```python +def test_custom_http_client(): + """Verify custom HTTP client is passed to the provider.""" + from unittest.mock import Mock, patch + + custom_client = Mock() + + with patch('strands.models.openai.OpenAI') as mock_openai: + from strands.models.openai import OpenAIModel + model = OpenAIModel(http_client=custom_client) + + # Verify the custom client was passed + mock_openai.assert_called_once() + call_kwargs = mock_openai.call_args[1] + assert call_kwargs.get('http_client') == custom_client +``` + +#### 4.3 Engineer Review Fallback -#### 4.3 Handle Validation Failures +When validation genuinely fails after documented attempts, use this fallback. Per **Principle 4**, you MUST still include the feature with a code sample. -Address any validation failures before including snippets in release notes. +**Required proof before using this fallback:** +1. Created an actual test file (show the code in the validation comment) +2. Ran the test and received an actual error (show the error message) +3. Tried at least ONE alternative approach (Bedrock, mocking, simplified example) +4. 
Documented each attempt and its failure reason **Constraints:** -- You MUST NOT include unvalidated code snippets in release notes -- You MUST revise the code snippet if validation fails -- You MUST re-run validation after making changes -- You SHOULD examine the actual implementation in the PR if generated code fails -- You SHOULD simplify the example if complexity is causing validation issues -- You MAY extract a different example from the PR if the current one cannot be validated -- You MAY seek clarification if you cannot create a valid example -- You MUST preserve the test file content to include in the GitHub issue comment (Step 6.2) +- You MUST NOT mark examples as needing validation without actually attempting validation first +- You MUST NOT use vague reasons like "complex setup required" - be specific about what you tried and what error you got +- You MUST show your test code and error messages in the Validation Comment +- You MUST try Bedrock for any feature that works with multiple model providers before giving up +- You MUST try mocking for provider-specific features before giving up +- You MUST document all validation attempts (successful AND failed) in the Validation Comment +- You MUST preserve the test file content to include in the GitHub issue comment (Step 6.1) +- You MUST note in the validation comment what specific behavior each test verifies - You MAY delete temporary test files after capturing their content, as the environment is ephemeral +**Process when validation genuinely fails:** +1. **Extract a code sample from the PR** - Use code from: + - The PR description's code examples + - Test files added in the PR + - The actual implementation (simplified for readability) + - Documentation updates in the PR +2. **Include the sample in the release notes** with a clear callout that it needs engineer validation +3. 
**Document the validation attempts and failures** in the Validation Comment (Step 6.1) + +**Format for unvalidated code examples:** +```markdown +### Feature Name - [PR#123](link) + +Description of the feature and its impact. + +\`\`\`python +# ⚠️ NEEDS ENGINEER VALIDATION +# Validation attempted: [describe test created and error received] +# Alternative attempts: [what else you tried and why it failed] + +# Code sample extracted from PR description/tests +from strands import Agent +from strands.models.openai import OpenAIModel + +model = OpenAIModel(http_client=custom_client) +agent = Agent(model=model) +\`\`\` +``` + ### 5. Release Notes Formatting #### 5.1 Format Major Features Section @@ -289,9 +437,16 @@ Create the Major Features section with concise descriptions and code examples. Agents can now validate responses against predefined schemas with configurable retry behavior for non-conforming outputs. -\`\`\`[language] -# Code example in the project's programming language -# Show the feature in action with clear, focused code +\`\`\`python +from strands import Agent +from pydantic import BaseModel + +class Response(BaseModel): + answer: str + +agent = Agent(output_schema=Response) +result = agent("What is 2+2?") +print(result.output.answer) \`\`\` See the [Structured Output docs](https://docs.example.com/structured-output) for configuration options. @@ -336,63 +491,82 @@ Add a horizontal rule to separate your content from GitHub's auto-generated sect - This visually separates your curated content from GitHub's auto-generated "What's Changed" and "New Contributors" sections - You MUST NOT include a "Full Changelog" link—GitHub adds this automatically -**Example format**: -```markdown -## Major Bug Fixes - -- **Critical Fix** - [PR#124](https://github.com/owner/repo/pull/124) - Description of what was fixed. - ---- -``` - ### 6. Output Delivery -**Critical**: You are running in an ephemeral environment. 
All files created during execution (test files, temporary notes, etc.) will be deleted when the workflow completes. You MUST post all deliverables as GitHub issue comments—this is the only way to preserve your work and make it accessible to reviewers. +Per **Principle 1**, all deliverables must be posted as GitHub issue comments. -**Comment Structure**: Post exactly two comments on the GitHub issue: +**Comment Structure**: Post exactly three comments on the GitHub issue: 1. **Validation Comment** (first): Contains all validation code for all features in one batched comment 2. **Release Notes Comment** (second): Contains the final formatted release notes +3. **Exclusions Comment** (third): Documents any features that were excluded and why + +This ordering allows reviewers to see the validation evidence, review the release notes, and understand any exclusion decisions. -This ordering allows reviewers to see the validation evidence before reviewing the release notes. +**Iteration Comments**: If the user requests changes after the initial comments are posted: +- Post additional validation comments for any re-validated code +- Post updated release notes as new comments (do not edit previous comments) +- This creates an audit trail of changes and validations #### 6.1 Post Validation Code Comment Batch all validation code into a single GitHub issue comment. 
**Constraints:** -- You MUST post ONE comment containing ALL validation code for ALL features +- You MUST post ONE comment containing validation attempts for ALL Major Features +- You MUST show test code for EVERY feature - both successful and failed attempts - You MUST NOT post separate comments for each feature's validation - You MUST post this comment BEFORE the release notes comment - You MUST include all test files created during validation (Step 4) in this single comment +- You MUST document what specific behavior each test verifies (not just "validates the code works") - You MUST NOT reference local file paths—the ephemeral environment will be destroyed - You MUST clearly label this comment as "Code Validation Tests" -- You MUST include a note explaining that this code was used to validate the snippets in the release notes -- You SHOULD use collapsible `
<details>` sections to organize validation code by feature: - ```markdown - ## Code Validation Tests
+- You SHOULD use collapsible `<details>` sections to organize validation code by feature +- You SHOULD include a brief description of what behavior is being verified for each test - The following test code was used to validate the code examples in the release notes. +**Format:** +```markdown +## Code Validation Tests + +The following test code was used to validate the code examples in the release notes. - <details>
- <summary>Validation: Feature Name 1</summary> +<details>
+<summary>✅ Validated: Feature Name 1</summary> - \`\`\`typescript - [Full test file for feature 1] - \`\`\` +**Behavior verified:** This test confirms that the new `output_schema` parameter causes the agent to return a validated Pydantic model instance with the correct field types. - </details>
+\`\`\`python +[Full test file for feature 1 with behavioral assertions] +\`\`\` - <details>
- <summary>Validation: Feature Name 2</summary> +**Test output:** PASSED - \`\`\`typescript - [Full test file for feature 2] - \`\`\` +</details>
- </details>
- ``` -- This allows reviewers to copy and run the validation code themselves +<details>
+<summary>⚠️ Could Not Validate: Feature Name 2</summary> + +**Attempt 1: Direct test with mocked model** +\`\`\`python +[Test code that was attempted] +\`\`\` +**Error received:** +\`\`\` +[Actual error message from running the test] +\`\`\` + +**Attempt 2: Test with Bedrock** +\`\`\`python +[Alternative test code attempted] +\`\`\` +**Error received:** +\`\`\` +[Actual error message] +\`\`\` + +**Conclusion:** Could not validate because [specific reason based on actual errors]. Code sample in release notes extracted from PR description. + +</details>
+``` #### 6.2 Post Release Notes Comment @@ -408,95 +582,117 @@ Post the formatted release notes as a single GitHub issue comment. - You MAY use markdown formatting in the comment - If comment posting is deferred, continue with the workflow and note the deferred status -## Examples +#### 6.3 Post Exclusions Comment -### Example 1: Major Features Section with Code +Document any features with unvalidated code samples and any other notable decisions. +**Constraints:** +- You MUST post this comment as the FINAL comment on the GitHub issue +- You MUST include this comment if ANY of the following occurred: + - A Major Feature has an unvalidated code sample (marked for engineer review) + - A feature's scope or description was significantly different from the PR description + - You relied on review comments rather than the PR description to understand a feature +- You MUST clearly explain the reasoning for each unvalidated sample +- You SHOULD include this comment even if all code samples were validated, with a simple note: "All code samples were successfully validated. No engineer review required." +- You MUST NOT skip this comment—it provides critical transparency for reviewers + +**Format:** ```markdown -## Major Features - -### Managed MCP Connections - [PR#895](https://github.com/org/repo/pull/895) - -MCP Connections via ToolProviders allow the Agent to manage connection lifecycles automatically, eliminating the need for manual context managers. This experimental interface simplifies MCP tool integration significantly. +## Release Notes Review Notes -\`\`\`[language] -# Code example in the project's programming language -# Demonstrate the key feature usage -# Keep it focused and concise -\`\`\` +The following items require attention during review: -See the [MCP docs](https://docs.example.com/mcp) for details. 
+### ⚠️ Features with Unvalidated Code Samples -### Async Streaming for Multi-Agent Systems - [PR#961](https://github.com/org/repo/pull/961) +These features have code samples extracted from PRs but could not be automatically validated. An engineer must verify these examples before publishing: -Multi-agent systems now support async streaming, enabling real-time event streaming from agent teams as they collaborate. +- **PR#123 - Feature Title**: + - Code source: PR description / test files / implementation + - Validation attempted: [what you tried] + - Failure reason: [why it failed, e.g., "requires OpenAI API credentials", "complex multi-service integration"] + - Action needed: Engineer should verify the code sample works as shown -\`\`\`[language] -# Another code example -# Show the feature in action -# Include only essential code -\`\`\` +### Description vs. Implementation Discrepancies +- **PR#101 - Feature Title**: PR description stated [X] but review comments and final implementation show [Y]. Release notes reflect the actual merged behavior. ``` -### Example 2: Major Bug Fixes Section +#### 6.4 Handle User Feedback on Release Notes -```markdown ---- +When the user requests changes to the release notes after they have been posted, re-validate as needed. -## Major Bug Fixes - -- **Guardrails Redaction Fix** - [PR#1072](https://github.com/strands-agents/sdk-python/pull/1072) - Fixed input/output message redaction when `guardrails_trace="enabled_full"`, ensuring sensitive data is properly protected in traces. - -- **Tool Result Block Redaction** - [PR#1080](https://github.com/strands-agents/sdk-python/pull/1080) - Properly redact tool result blocks to prevent conversation corruption when using content filtering or PII redaction. 
+**Constraints:** +- You MUST re-run validation (Step 4) when the user requests changes that affect code examples: + - Modified code snippets + - New code examples for features that previously had none + - Replacement examples for features +- You MUST perform full extraction (Step 3) and validation (Step 4) when the user requests: + - Adding a new feature to the release notes that wasn't previously included + - Promoting a bug fix to include a code example +- You MUST NOT make changes to code examples without re-validating them +- You MUST post updated validation code as a new comment when re-validation occurs +- You MUST post the revised release notes as a new comment (do not edit previous comments) +- You SHOULD note in the updated release notes comment what changed from the previous version +- You MAY skip re-validation only for changes that do not affect code: + - Wording changes to descriptions + - Fixing typos + - Reordering features + - Removing features (no validation needed for removal) -- **Orphaned Tool Use Fix** - [PR#1123](https://github.com/strands-agents/sdk-python/pull/1123) - Fixed broken conversations caused by orphaned `toolUse` blocks, improving reliability when tools fail or are interrupted. -``` +## Examples -### Example 3: Complete Release Notes Structure +### Example 1: Complete Release Notes ```markdown ## Major Features -### Feature Name - [PR#123](https://github.com/owner/repo/pull/123) +### Managed MCP Connections - [PR#895](https://github.com/org/repo/pull/895) -Description of the feature and its impact. +MCP Connections via ToolProviders allow the Agent to manage connection lifecycles automatically, eliminating the need for manual context managers. This experimental interface simplifies MCP tool integration significantly. 
+ +\`\`\`python +from strands import Agent +from strands.tools import MCPToolProvider -\`\`\`[language] -# Code example demonstrating the feature +provider = MCPToolProvider(server_config) +agent = Agent(tools=[provider]) +result = agent("Use the MCP tools") \`\`\` ---- +See the [MCP docs](https://docs.example.com/mcp) for details. -## Major Bug Fixes +### Custom HTTP Client Support - [PR#1366](https://github.com/org/repo/pull/1366) -- **Critical Fix** - [PR#124](https://github.com/owner/repo/pull/124) - Description of what was fixed and why it matters. +OpenAI model provider now accepts a custom HTTP client, enabling proxy configuration, custom timeouts, and request logging. ---- -``` +\`\`\`python +# ⚠️ NEEDS ENGINEER VALIDATION +# Validation attempted: mocked OpenAI client, received import error +# Alternative attempts: Bedrock (not applicable - OpenAI-specific) -Note: The trailing `---` separates your content from GitHub's auto-generated "What's Changed" and "New Contributors" sections that follow. +from strands.models.openai import OpenAIModel +import httpx -### Example 4: Issue Comment with Release Notes +custom_client = httpx.Client(proxy="http://proxy.example.com:8080") +model = OpenAIModel(client_args={"http_client": custom_client}) +\`\`\` -```markdown -Release notes for v1.15.0: +--- -## Major Features +## Major Bug Fixes -### Managed MCP Connections - [PR#895](https://github.com/strands-agents/sdk-typescript/pull/895) +- **Guardrails Redaction Fix** - [PR#1072](https://github.com/strands-agents/sdk-python/pull/1072) + Fixed input/output message redaction when `guardrails_trace="enabled_full"`, ensuring sensitive data is properly protected in traces. -We've introduced MCP Connections via ToolProviders... +- **Tool Result Block Redaction** - [PR#1080](https://github.com/strands-agents/sdk-python/pull/1080) + Properly redact tool result blocks to prevent conversation corruption when using content filtering or PII redaction. -[... rest of release notes ...] 
+- **Orphaned Tool Use Fix** - [PR#1123](https://github.com/strands-agents/sdk-python/pull/1123) + Fixed broken conversations caused by orphaned `toolUse` blocks, improving reliability when tools fail or are interrupted. --- ``` -When this content is added to the GitHub release, GitHub will automatically append the "What's Changed" and "New Contributors" sections below the separator. +Note: The trailing `---` separates your content from GitHub's auto-generated "What's Changed" and "New Contributors" sections that follow. ## Troubleshooting @@ -519,14 +715,7 @@ If you encounter GitHub API rate limit errors: ### Code Validation Failures -If code validation fails for a snippet: -1. Review the test output to understand the failure reason -2. Check if the feature requires additional dependencies or setup -3. Examine the actual implementation in the PR to understand correct usage -4. Try simplifying the example to focus on core functionality -5. Consider using a different example from the PR -6. If unable to validate, note the issue in the release notes comment and skip the code example for that feature -7. Leave a comment on the issue noting which features couldn't include validated code examples +Follow the validation workflow in Section 4.2. If all attempts fail, use the engineer review fallback per Section 4.3. Per **Principle 4**, always include a code sample. ### Large PR Sets (>100 PRs) @@ -561,22 +750,19 @@ When GitHub tools or git operations are deferred (GITHUB_WRITE=false): - The operations will be executed after agent completion - Do not retry or attempt alternative approaches for deferred operations -### Unable to Extract Suitable Code Examples +### Stale PR Descriptions -If no suitable code examples can be found or generated for a feature: -1. Examine the PR description more carefully for usage information -2. Look at related documentation changes -3. Consider whether the feature actually needs a code example (some features are self-explanatory) -4. 
Generate a minimal example based on the API changes, even if you can't fully validate it -5. Mark the example as "conceptual" if validation isn't possible -6. Consider omitting the code example if it would be misleading +Per **Principle 2**: Review PR comments for context on what changed, examine merged code (especially test files), and use test files as the authoritative source for code examples. ## Desired Outcome * Focused release notes highlighting Major Features and Major Bug Fixes with concise descriptions (2-3 sentences, no bullet points) -* Working, validated code examples for all major features +* Code examples for ALL major features - either validated or marked for engineer review +* Validated code examples have passing behavioral tests +* Unvalidated code examples are clearly marked with the engineer validation warning and extracted from PR sources * Well-formatted markdown that renders properly on GitHub * Release notes posted as a comment on the GitHub issue for review +* Review notes comment documenting any features with unvalidated code samples that need engineer attention **Important**: Your generated release notes will be prepended to GitHub's auto-generated release notes. GitHub automatically generates: - "What's Changed" section listing all PRs with authors and links From 3ffc327071396fa24f805c03524da6b71e5f73cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=2E/c=C2=B2?= Date: Mon, 12 Jan 2026 14:08:49 -0500 Subject: [PATCH 052/279] fix(integ): make calculator tool more robust to LLM output variations (#1445) The test_tool_use_with_structured_output test was flaky because the LLM sometimes uses '+' instead of 'add' as the operation string. The calculator tool now accepts both formats for all operations. 
Changes: - Accept both word and symbol forms: add/+, subtract/-, multiply/*, divide//, power/** - Also accept common abbreviations: sub, mul, div, pow - Normalize input with lower() and strip() - Fix divide operation (was b/a, now a/b) - Improve docstring with Args section This makes the integ tests more resilient to LLM output variations. Co-authored-by: Strands Coder --- .../test_structured_output_agent_loop.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/tests_integ/test_structured_output_agent_loop.py b/tests_integ/test_structured_output_agent_loop.py index 188f57777..390bd3cff 100644 --- a/tests_integ/test_structured_output_agent_loop.py +++ b/tests_integ/test_structured_output_agent_loop.py @@ -132,16 +132,23 @@ def validate_first_name(cls, value: str) -> str: @tool def calculator(operation: str, a: float, b: float) -> float: - """Simple calculator tool for testing.""" - if operation == "add": + """Simple calculator tool for testing. + + Args: + operation: The operation to perform. 
One of: add, subtract, multiply, divide, power + a: The first number + b: The second number + """ + op = operation.lower().strip() + if op in ("add", "+"): return a + b - elif operation == "subtract": + elif op in ("subtract", "-", "sub"): return a - b - elif operation == "multiply": + elif op in ("multiply", "*", "mul"): return a * b - elif operation == "divide": - return b / a if a != 0 else 0 - elif operation == "power": + elif op in ("divide", "/", "div"): + return a / b if b != 0 else 0 + elif op in ("power", "**", "pow"): return a**b else: return 0 From 56676c19297b95e0396f74c0fb9c2afc0a96c25e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=2E/c=C2=B2?= Date: Mon, 12 Jan 2026 14:16:43 -0500 Subject: [PATCH 053/279] fix(mcp): resolve string formatting error in MCP client error handling (#1446) --- src/strands/tools/mcp/mcp_client.py | 2 +- tests/strands/tools/mcp/test_mcp_client.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 37b99d021..db21b9ef2 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -713,7 +713,7 @@ async def _handle_error_message(self, message: Exception | Any) -> None: if isinstance(message, Exception): error_msg = str(message).lower() if any(pattern in error_msg for pattern in _NON_FATAL_ERROR_PATTERNS): - self._log_debug_with_thread("ignoring non-fatal MCP session error", message) + self._log_debug_with_thread("ignoring non-fatal MCP session error: %s", message) else: raise message await anyio.lowlevel.checkpoint() diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index 35f11f47f..f784da414 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -924,3 +924,21 @@ def test_list_resource_templates_sync_session_not_active(): with pytest.raises(MCPClientInitializationError, match="client 
session is not running"): client.list_resource_templates_sync() + + +@pytest.mark.asyncio +async def test_handle_error_message_with_percent_in_message(): + """Test that _handle_error_message handles messages containing % characters without string formatting errors. + + This is a regression test for issue #1244 where MCP error messages containing '%' characters + (e.g., from URLs like "https://example.com/path?param=value%20encoded") would cause a + TypeError: not all arguments converted during string formatting. + """ + client = MCPClient(MagicMock()) + + # Test with a message that contains % characters (like URL-encoded strings) + # This simulates the error that occurs when MCP servers return messages with % in them + error_with_percent = Exception("unknown request id: abc%20123%30def") + + # This should not raise TypeError and should not raise the exception (since it's non-fatal) + await client._handle_error_message(error_with_percent) From 318573d0618c283f6f16ace1117d5e3d76279568 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 12 Jan 2026 14:34:30 -0500 Subject: [PATCH 054/279] bidi - move 3.12 check to nova sonic module (#1439) --- src/strands/experimental/bidi/__init__.py | 8 --- src/strands/experimental/bidi/agent/agent.py | 15 +++--- .../experimental/bidi/models/__init__.py | 15 +++--- src/strands/experimental/bidi/models/model.py | 3 +- .../experimental/bidi/models/nova_sonic.py | 9 ++++ .../bidi/_async/test_task_group.py | 3 ++ .../experimental/bidi/agent/test_agent.py | 20 +++++--- .../experimental/bidi/agent/test_loop.py | 4 +- .../bidi/models/test_nova_sonic.py | 50 ++++++++++++------- .../bidi/models/test_openai_realtime.py | 3 +- 10 files changed, 75 insertions(+), 55 deletions(-) diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py index 57986062e..1c0e74aae 100644 --- a/src/strands/experimental/bidi/__init__.py +++ b/src/strands/experimental/bidi/__init__.py @@ -1,10 +1,5 @@ """Bidirectional 
streaming package.""" -import sys - -if sys.version_info < (3, 12): - raise ImportError("bidi only supported for >= Python 3.12") - # Main components - Primary user interface # Re-export standard agent events for tool handling from ...types._events import ( @@ -19,7 +14,6 @@ # Model interface (for custom implementations) from .models.model import BidiModel -from .models.nova_sonic import BidiNovaSonicModel # Built-in tools from .tools import stop_conversation @@ -48,8 +42,6 @@ "BidiAgent", # IO channels "BidiAudioIO", - # Model providers - "BidiNovaSonicModel", # Built-in tools "stop_conversation", # Input Event types diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 5ddb181ea..11bea96e5 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -32,7 +32,6 @@ from ...tools import ToolProvider from .._async import _TaskGroup, stop_all from ..models.model import BidiModel -from ..models.nova_sonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput from ..types.events import ( BidiAudioInputEvent, @@ -100,13 +99,13 @@ def __init__( ValueError: If model configuration is invalid or state is invalid type. TypeError: If model type is unsupported. 
""" - self.model = ( - BidiNovaSonicModel() - if not model - else BidiNovaSonicModel(model_id=model) - if isinstance(model, str) - else model - ) + if isinstance(model, BidiModel): + self.model = model + else: + from ..models.nova_sonic import BidiNovaSonicModel + + self.model = BidiNovaSonicModel(model_id=model) if isinstance(model, str) else BidiNovaSonicModel() + self.system_prompt = system_prompt self.messages = messages or [] diff --git a/src/strands/experimental/bidi/models/__init__.py b/src/strands/experimental/bidi/models/__init__.py index 6e5817046..7b87e09fe 100644 --- a/src/strands/experimental/bidi/models/__init__.py +++ b/src/strands/experimental/bidi/models/__init__.py @@ -3,27 +3,26 @@ from typing import Any from .model import BidiModel, BidiModelTimeoutError -from .nova_sonic import BidiNovaSonicModel __all__ = [ "BidiModel", "BidiModelTimeoutError", - "BidiNovaSonicModel", ] def __getattr__(name: str) -> Any: - """ - Lazy load bidi model implementations only when accessed. - - This defers the import of optional dependencies until actually needed: - - BidiGeminiLiveModel requires google-generativeai (lazy loaded) - - BidiOpenAIRealtimeModel requires openai (lazy loaded) + """Lazy load bidi model implementations only when accessed. + + This defers the import of optional dependencies until actually needed. 
""" if name == "BidiGeminiLiveModel": from .gemini_live import BidiGeminiLiveModel return BidiGeminiLiveModel + if name == "BidiNovaSonicModel": + from .nova_sonic import BidiNovaSonicModel + + return BidiNovaSonicModel if name == "BidiOpenAIRealtimeModel": from .openai_realtime import BidiOpenAIRealtimeModel diff --git a/src/strands/experimental/bidi/models/model.py b/src/strands/experimental/bidi/models/model.py index f5e34aa50..5941d7e41 100644 --- a/src/strands/experimental/bidi/models/model.py +++ b/src/strands/experimental/bidi/models/model.py @@ -14,7 +14,7 @@ """ import logging -from typing import Any, AsyncIterable, Protocol +from typing import Any, AsyncIterable, Protocol, runtime_checkable from ....types._events import ToolResultEvent from ....types.content import Messages @@ -27,6 +27,7 @@ logger = logging.getLogger(__name__) +@runtime_checkable class BidiModel(Protocol): """Protocol for bidirectional streaming models. diff --git a/src/strands/experimental/bidi/models/nova_sonic.py b/src/strands/experimental/bidi/models/nova_sonic.py index 6a2477e22..1c946220d 100644 --- a/src/strands/experimental/bidi/models/nova_sonic.py +++ b/src/strands/experimental/bidi/models/nova_sonic.py @@ -11,8 +11,15 @@ - Tool execution with content containers and identifier tracking - 8-minute connection limits with proper cleanup sequences - Interruption detection through stopReason events + +Note, BidiNovaSonicModel is only supported for Python 3.12+ """ +import sys + +if sys.version_info < (3, 12): + raise ImportError("BidiNovaSonicModel is only supported for Python 3.12+") + import asyncio import base64 import json @@ -93,6 +100,8 @@ class BidiNovaSonicModel(BidiModel): Manages Nova Sonic's complex event sequencing, audio format conversion, and tool execution patterns while providing the standard BidiModel interface. + Note, BidiNovaSonicModel is only supported for Python 3.12+. + Attributes: _stream: open bedrock stream to nova sonic. 
""" diff --git a/tests/strands/experimental/bidi/_async/test_task_group.py b/tests/strands/experimental/bidi/_async/test_task_group.py index b9a30ef5b..255ead15e 100644 --- a/tests/strands/experimental/bidi/_async/test_task_group.py +++ b/tests/strands/experimental/bidi/_async/test_task_group.py @@ -19,6 +19,7 @@ async def test_task_group__aexit__(): @pytest.mark.asyncio async def test_task_group__aexit__task_exception(): wait_event = asyncio.Event() + async def wait(): await wait_event.wait() @@ -49,12 +50,14 @@ async def wait(): @pytest.mark.asyncio async def test_task_group__aexit__context_cancelled(): wait_event = asyncio.Event() + async def wait(): await wait_event.wait() tasks = [] run_event = asyncio.Event() + async def run(): async with _TaskGroup() as task_group: tasks.append(task_group.create_task(wait())) diff --git a/tests/strands/experimental/bidi/agent/test_agent.py b/tests/strands/experimental/bidi/agent/test_agent.py index 7b03ab717..50c9afef9 100644 --- a/tests/strands/experimental/bidi/agent/test_agent.py +++ b/tests/strands/experimental/bidi/agent/test_agent.py @@ -1,13 +1,13 @@ """Unit tests for BidiAgent.""" import asyncio +import sys import unittest.mock from uuid import uuid4 import pytest from strands.experimental.bidi.agent.agent import BidiAgent -from strands.experimental.bidi.models.nova_sonic import BidiNovaSonicModel from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, @@ -125,13 +125,6 @@ def test_bidi_agent_init_with_various_configurations(): assert agent_with_config.system_prompt == system_prompt assert agent_with_config.agent_id == "test_agent" - # Test with string model ID - model_id = "amazon.nova-sonic-v1:0" - agent_with_string = BidiAgent(model=model_id) - - assert isinstance(agent_with_string.model, BidiNovaSonicModel) - assert agent_with_string.model.model_id == model_id - # Test model config access config = agent.model.config assert config["audio"]["input_rate"] == 16000 @@ -139,6 
+132,17 @@ def test_bidi_agent_init_with_various_configurations(): assert config["audio"]["channels"] == 1 +@pytest.mark.skipif(sys.version_info < (3, 12), reason="BidiNovaSonicModel is only supported for Python 3.12+") +def test_bidi_agent_init_with_model_id(): + from strands.experimental.bidi.models.nova_sonic import BidiNovaSonicModel + + model_id = "amazon.nova-sonic-v1:0" + agent = BidiAgent(model=model_id) + + assert isinstance(agent.model, BidiNovaSonicModel) + assert agent.model.model_id == model_id + + @pytest.mark.asyncio async def test_bidi_agent_start_stop_lifecycle(agent): """Test agent start/stop lifecycle and state management.""" diff --git a/tests/strands/experimental/bidi/agent/test_loop.py b/tests/strands/experimental/bidi/agent/test_loop.py index da8578f55..fac52658e 100644 --- a/tests/strands/experimental/bidi/agent/test_loop.py +++ b/tests/strands/experimental/bidi/agent/test_loop.py @@ -5,7 +5,7 @@ from strands import tool from strands.experimental.bidi import BidiAgent -from strands.experimental.bidi.models import BidiModelTimeoutError +from strands.experimental.bidi.models import BidiModel, BidiModelTimeoutError from strands.experimental.bidi.types.events import BidiConnectionRestartEvent, BidiTextInputEvent from strands.types._events import ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent @@ -21,7 +21,7 @@ async def func(): @pytest.fixture def agent(time_tool): - return BidiAgent(model=unittest.mock.AsyncMock(), tools=[time_tool]) + return BidiAgent(model=unittest.mock.AsyncMock(spec=BidiModel), tools=[time_tool]) @pytest_asyncio.fixture diff --git a/tests/strands/experimental/bidi/models/test_nova_sonic.py b/tests/strands/experimental/bidi/models/test_nova_sonic.py index 933fd2088..7435d4ad2 100644 --- a/tests/strands/experimental/bidi/models/test_nova_sonic.py +++ b/tests/strands/experimental/bidi/models/test_nova_sonic.py @@ -4,10 +4,17 @@ covering connection lifecycle, event conversion, audio streaming, and tool execution. 
""" +import sys + +if sys.version_info < (3, 12): + import pytest + + pytest.skip(reason="BidiNovaSonicModel is only supported for Python 3.12+", allow_module_level=True) + import asyncio import base64 import json -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest import pytest_asyncio @@ -39,9 +46,8 @@ def model_id(): @pytest.fixture -def region(): - """AWS region.""" - return "us-east-1" +def boto_session(): + return Mock(region_name="us-east-1") @pytest.fixture @@ -67,11 +73,11 @@ def mock_client(mock_stream): @pytest_asyncio.fixture -def nova_model(model_id, region, mock_client): +def nova_model(model_id, boto_session, mock_client): """Create Nova Sonic model instance.""" _ = mock_client - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) + model = BidiNovaSonicModel(model_id=model_id, client_config={"boto_session": boto_session}) yield model @@ -79,12 +85,12 @@ def nova_model(model_id, region, mock_client): @pytest.mark.asyncio -async def test_model_initialization(model_id, region): +async def test_model_initialization(model_id, boto_session): """Test model initialization with configuration.""" - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) + model = BidiNovaSonicModel(model_id=model_id, client_config={"boto_session": boto_session}) assert model.model_id == model_id - assert model.region == region + assert model.region == "us-east-1" assert model._connection_id is None @@ -92,9 +98,9 @@ async def test_model_initialization(model_id, region): @pytest.mark.asyncio -async def test_audio_config_defaults(model_id, region): +async def test_audio_config_defaults(model_id, boto_session): """Test default audio configuration.""" - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) + model = BidiNovaSonicModel(model_id=model_id, client_config={"boto_session": boto_session}) assert model.config["audio"]["input_rate"] == 
16000 assert model.config["audio"]["output_rate"] == 16000 @@ -104,10 +110,12 @@ async def test_audio_config_defaults(model_id, region): @pytest.mark.asyncio -async def test_audio_config_partial_override(model_id, region): +async def test_audio_config_partial_override(model_id, boto_session): """Test partial audio configuration override.""" provider_config = {"audio": {"output_rate": 24000, "voice": "ruth"}} - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}, provider_config=provider_config) + model = BidiNovaSonicModel( + model_id=model_id, client_config={"boto_session": boto_session}, provider_config=provider_config + ) # Overridden values assert model.config["audio"]["output_rate"] == 24000 @@ -120,7 +128,7 @@ async def test_audio_config_partial_override(model_id, region): @pytest.mark.asyncio -async def test_audio_config_full_override(model_id, region): +async def test_audio_config_full_override(model_id, boto_session): """Test full audio configuration override.""" provider_config = { "audio": { @@ -131,7 +139,9 @@ async def test_audio_config_full_override(model_id, region): "voice": "stephen", } } - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}, provider_config=provider_config) + model = BidiNovaSonicModel( + model_id=model_id, client_config={"boto_session": boto_session}, provider_config=provider_config + ) assert model.config["audio"]["input_rate"] == 48000 assert model.config["audio"]["output_rate"] == 48000 @@ -527,11 +537,13 @@ async def test_message_history_empty_and_edge_cases(nova_model): @pytest.mark.asyncio -async def test_custom_audio_rates_in_events(model_id, region): +async def test_custom_audio_rates_in_events(model_id, boto_session): """Test that audio events use configured sample rates.""" # Create model with custom audio configuration provider_config = {"audio": {"output_rate": 48000, "channels": 2}} - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}, 
provider_config=provider_config) + model = BidiNovaSonicModel( + model_id=model_id, client_config={"boto_session": boto_session}, provider_config=provider_config + ) # Test audio output event uses custom configuration audio_bytes = b"test audio data" @@ -548,10 +560,10 @@ async def test_custom_audio_rates_in_events(model_id, region): @pytest.mark.asyncio -async def test_default_audio_rates_in_events(model_id, region): +async def test_default_audio_rates_in_events(model_id, boto_session): """Test that audio events use default sample rates when no custom config.""" # Create model without custom audio configuration - model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) + model = BidiNovaSonicModel(model_id=model_id, client_config={"boto_session": boto_session}) # Test audio output event uses defaults audio_bytes = b"test audio data" diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py index 1cabbc92b..09f4c8bc8 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -9,6 +9,7 @@ """ import base64 +import itertools import json import unittest.mock @@ -522,7 +523,7 @@ async def test_receive_lifecycle_events(mock_websocket, model): @unittest.mock.patch("strands.experimental.bidi.models.openai_realtime.time.time") @pytest.mark.asyncio async def test_receive_timeout(mock_time, model): - mock_time.side_effect = [1, 2] + mock_time.side_effect = itertools.count() model.timeout_s = 1 await model.start() From 68257a3e9f792b06d4be771677c6e38c10e99da7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 12 Jan 2026 17:12:19 -0500 Subject: [PATCH 055/279] ci: update sphinx requirement from <9.0.0,>=5.0.0 to >=5.0.0,<10.0.0 (#1426) Updates the requirements on [sphinx](https://github.com/sphinx-doc/sphinx) to permit the 
latest version. - [Release notes](https://github.com/sphinx-doc/sphinx/releases) - [Changelog](https://github.com/sphinx-doc/sphinx/blob/master/CHANGES.rst) - [Commits](https://github.com/sphinx-doc/sphinx/compare/v5.0.0...v9.1.0) --- updated-dependencies: - dependency-name: sphinx dependency-version: 9.1.0 dependency-type: direct:development ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 05a385ca9..62e0e04b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ sagemaker = [ ] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ - "sphinx>=5.0.0,<9.0.0", + "sphinx>=5.0.0,<10.0.0", "sphinx-rtd-theme>=1.0.0,<2.0.0", "sphinx-autodoc-typehints>=1.12.0,<4.0.0", ] From 02738013252a3c39c9dfc3f972ba88f3cfb02afe Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Tue, 13 Jan 2026 10:30:27 -0500 Subject: [PATCH 056/279] fix: add concurrency protection to prevent parallel invocations from corrupting agent state (#1453) When multiple invocations occur concurrently on the same Agent instance the internal agent state can become corrupted, causing subsequent invocations to fail. The most common result is that the number of toolUse blocks end up out of sync with subsequent toolResult blocks, resulting in ValidationExceptions as reported in the bug report (#1176). To block multiple concurrent agent invocations, we'll raise a new ConcurrencyException before any state modification occurs. 
--------- Co-authored-by: Strands Agent Co-authored-by: Mackenzie Zastrow --- src/strands/agent/agent.py | 88 ++++++---- src/strands/tools/_caller.py | 79 +++++---- src/strands/types/exceptions.py | 11 ++ tests/strands/agent/test_agent.py | 260 +++++++++++++++++++++++++++++- 4 files changed, 372 insertions(+), 66 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index c4ebc0b54..7126644e6 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -10,6 +10,7 @@ """ import logging +import threading import warnings from typing import ( TYPE_CHECKING, @@ -59,7 +60,7 @@ from ..types._events import AgentResultEvent, EventLoopStopEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent from ..types.agent import AgentInput from ..types.content import ContentBlock, Message, Messages, SystemContentBlock -from ..types.exceptions import ContextWindowOverflowException +from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException from ..types.traces import AttributeValue from .agent_result import AgentResult from .conversation_manager import ( @@ -245,6 +246,11 @@ def __init__( self._interrupt_state = _InterruptState() + # Initialize lock for guarding concurrent invocations + # Using threading.Lock instead of asyncio.Lock because run_async() creates + # separate event loops in different threads, so asyncio.Lock wouldn't work + self._invocation_lock = threading.Lock() + # Initialize session management functionality self._session_manager = session_manager if self._session_manager: @@ -554,6 +560,7 @@ async def stream_async( - And other event data provided by the callback handler Raises: + ConcurrencyException: If another invocation is already in progress on this agent instance. Exception: Any exceptions from the agent invocation will be propagated to the caller. 
Example: @@ -563,50 +570,63 @@ async def stream_async( yield event["data"] ``` """ - self._interrupt_state.resume(prompt) + # Acquire lock to prevent concurrent invocations + # Using threading.Lock instead of asyncio.Lock because run_async() creates + # separate event loops in different threads + acquired = self._invocation_lock.acquire(blocking=False) + if not acquired: + raise ConcurrencyException( + "Agent is already processing a request. Concurrent invocations are not supported." + ) - self.event_loop_metrics.reset_usage_metrics() + try: + self._interrupt_state.resume(prompt) - merged_state = {} - if kwargs: - warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2) - merged_state.update(kwargs) - if invocation_state is not None: - merged_state["invocation_state"] = invocation_state - else: - if invocation_state is not None: - merged_state = invocation_state + self.event_loop_metrics.reset_usage_metrics() - callback_handler = self.callback_handler - if kwargs: - callback_handler = kwargs.get("callback_handler", self.callback_handler) + merged_state = {} + if kwargs: + warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2) + merged_state.update(kwargs) + if invocation_state is not None: + merged_state["invocation_state"] = invocation_state + else: + if invocation_state is not None: + merged_state = invocation_state - # Process input and get message to add (if any) - messages = await self._convert_prompt_to_messages(prompt) + callback_handler = self.callback_handler + if kwargs: + callback_handler = kwargs.get("callback_handler", self.callback_handler) - self.trace_span = self._start_agent_trace_span(messages) + # Process input and get message to add (if any) + messages = await self._convert_prompt_to_messages(prompt) - with trace_api.use_span(self.trace_span): - try: - events = self._run_loop(messages, merged_state, structured_output_model) + self.trace_span = 
self._start_agent_trace_span(messages) - async for event in events: - event.prepare(invocation_state=merged_state) + with trace_api.use_span(self.trace_span): + try: + events = self._run_loop(messages, merged_state, structured_output_model) + + async for event in events: + event.prepare(invocation_state=merged_state) - if event.is_callback_event: - as_dict = event.as_dict() - callback_handler(**as_dict) - yield as_dict + if event.is_callback_event: + as_dict = event.as_dict() + callback_handler(**as_dict) + yield as_dict - result = AgentResult(*event["stop"]) - callback_handler(result=result) - yield AgentResultEvent(result=result).as_dict() + result = AgentResult(*event["stop"]) + callback_handler(result=result) + yield AgentResultEvent(result=result).as_dict() - self._end_agent_trace_span(response=result) + self._end_agent_trace_span(response=result) - except Exception as e: - self._end_agent_trace_span(error=e) - raise + except Exception as e: + self._end_agent_trace_span(error=e) + raise + + finally: + self._invocation_lock.release() async def _run_loop( self, diff --git a/src/strands/tools/_caller.py b/src/strands/tools/_caller.py index 97485d068..bfec5886d 100644 --- a/src/strands/tools/_caller.py +++ b/src/strands/tools/_caller.py @@ -15,6 +15,7 @@ from ..tools.executors._executor import ToolExecutor from ..types._events import ToolInterruptEvent from ..types.content import ContentBlock, Message +from ..types.exceptions import ConcurrencyException from ..types.tools import ToolResult, ToolUse if TYPE_CHECKING: @@ -73,46 +74,64 @@ def caller( if self._agent._interrupt_state.activated: raise RuntimeError("cannot directly call tool during interrupt") - normalized_name = self._find_normalized_tool_name(name) + if record_direct_tool_call is not None: + should_record_direct_tool_call = record_direct_tool_call + else: + should_record_direct_tool_call = self._agent.record_direct_tool_call - # Create unique tool ID and set up the tool request - tool_id = 
f"tooluse_{name}_{random.randint(100000000, 999999999)}" - tool_use: ToolUse = { - "toolUseId": tool_id, - "name": normalized_name, - "input": kwargs.copy(), - } - tool_results: list[ToolResult] = [] - invocation_state = kwargs + should_lock = should_record_direct_tool_call - async def acall() -> ToolResult: - async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): - if isinstance(event, ToolInterruptEvent): - self._agent._interrupt_state.deactivate() - raise RuntimeError("cannot raise interrupt in direct tool call") + from ..agent import Agent # Locally imported to avoid circular reference - tool_result = tool_results[0] + acquired_lock = ( + should_lock + and isinstance(self._agent, Agent) + and self._agent._invocation_lock.acquire_lock(blocking=False) + ) + if should_lock and not acquired_lock: + raise ConcurrencyException( + "Direct tool call cannot be made while the agent is in the middle of an invocation. " + "Set record_direct_tool_call=False to allow direct tool calls during agent invocation." 
+ ) - if record_direct_tool_call is not None: - should_record_direct_tool_call = record_direct_tool_call - else: - should_record_direct_tool_call = self._agent.record_direct_tool_call + try: + normalized_name = self._find_normalized_tool_name(name) - if should_record_direct_tool_call: - # Create a record of this tool execution in the message history - await self._record_tool_execution(tool_use, tool_result, user_message_override) + # Create unique tool ID and set up the tool request + tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" + tool_use: ToolUse = { + "toolUseId": tool_id, + "name": normalized_name, + "input": kwargs.copy(), + } + tool_results: list[ToolResult] = [] + invocation_state = kwargs - return tool_result + async def acall() -> ToolResult: + async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): + if isinstance(event, ToolInterruptEvent): + self._agent._interrupt_state.deactivate() + raise RuntimeError("cannot raise interrupt in direct tool call") + + tool_result = tool_results[0] + + if should_record_direct_tool_call: + # Create a record of this tool execution in the message history + await self._record_tool_execution(tool_use, tool_result, user_message_override) - tool_result = run_async(acall) + return tool_result - # TODO: https://github.com/strands-agents/sdk-python/issues/1311 - from ..agent import Agent + tool_result = run_async(acall) - if isinstance(self._agent, Agent): - self._agent.conversation_manager.apply_management(self._agent) + # TODO: https://github.com/strands-agents/sdk-python/issues/1311 + if isinstance(self._agent, Agent): + self._agent.conversation_manager.apply_management(self._agent) + + return tool_result - return tool_result + finally: + if acquired_lock and isinstance(self._agent, Agent): + self._agent._invocation_lock.release() return caller diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index b9c5bc769..1d1983abd 100644 --- 
a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -94,3 +94,14 @@ def __init__(self, message: str): """ self.message = message super().__init__(message) + + +class ConcurrencyException(Exception): + """Exception raised when concurrent invocations are attempted on an agent instance. + + Agent instances maintain internal state that cannot be safely accessed concurrently. + This exception is raised when an invocation is attempted while another invocation + is already in progress on the same agent instance. + """ + + pass diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 351eadc84..81ce65989 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1,17 +1,21 @@ +import asyncio import copy import importlib import json import os import textwrap +import threading +import time import unittest.mock import warnings +from typing import Any, AsyncGenerator from uuid import uuid4 import pytest from pydantic import BaseModel import strands -from strands import Agent +from strands import Agent, ToolContext from strands.agent import AgentResult from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager @@ -24,7 +28,7 @@ from strands.telemetry.tracer import serialize from strands.types._events import EventLoopStopEvent, ModelStreamEvent from strands.types.content import Messages -from strands.types.exceptions import ContextWindowOverflowException, EventLoopException +from strands.types.exceptions import ConcurrencyException, ContextWindowOverflowException, EventLoopException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType from tests.fixtures.mock_session_repository import MockedSessionRepository from tests.fixtures.mocked_model_provider import MockedModelProvider @@ -189,6 +193,15 @@ class 
User(BaseModel): return User(name="Jane Doe", age=30, email="jane@doe.com") +class SlowMockedModel(MockedModelProvider): + async def stream( + self, messages, tool_specs=None, system_prompt=None, tool_choice=None, **kwargs + ) -> AsyncGenerator[Any, None]: + await asyncio.sleep(0.15) # Add async delay to ensure concurrency + async for event in super().stream(messages, tool_specs, system_prompt, tool_choice, **kwargs): + yield event + + def test_agent__init__tool_loader_format(tool_decorated, tool_module, tool_imported, tool_registry): _ = tool_registry @@ -2190,3 +2203,246 @@ def test_agent_skips_fix_for_valid_conversation(mock_model, agenerator): # Should not have added any toolResult messages # Only the new user message and assistant response should be added assert len(agent.messages) == original_length + 2 + + +# ============================================================================ +# Concurrency Exception Tests +# ============================================================================ + + +def test_agent_concurrent_call_raises_exception(): + """Test that concurrent __call__() calls raise ConcurrencyException.""" + model = SlowMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + {"role": "assistant", "content": [{"text": "world"}]}, + ] + ) + agent = Agent(model=model) + + results = [] + errors = [] + lock = threading.Lock() + + def invoke(): + try: + result = agent("test") + with lock: + results.append(result) + except ConcurrencyException as e: + with lock: + errors.append(e) + + # Create two threads that will try to invoke concurrently + t1 = threading.Thread(target=invoke) + t2 = threading.Thread(target=invoke) + + t1.start() + t2.start() + t1.join() + t2.join() + + # One should succeed, one should raise ConcurrencyException + assert len(results) == 1, f"Expected 1 success, got {len(results)}" + assert len(errors) == 1, f"Expected 1 error, got {len(errors)}" + assert "concurrent" in str(errors[0]).lower() and "invocation" 
in str(errors[0]).lower() + + +def test_agent_concurrent_structured_output_raises_exception(): + """Test that concurrent structured_output() calls raise ConcurrencyException. + + Note: This test validates that the sync invocation path is protected. + The concurrent __call__() test already validates the core functionality. + """ + model = SlowMockedModel( + [ + {"role": "assistant", "content": [{"text": "response1"}]}, + {"role": "assistant", "content": [{"text": "response2"}]}, + ] + ) + agent = Agent(model=model) + + results = [] + errors = [] + lock = threading.Lock() + + def invoke(): + try: + result = agent("test") + with lock: + results.append(result) + except ConcurrencyException as e: + with lock: + errors.append(e) + + # Create two threads that will try to invoke concurrently + t1 = threading.Thread(target=invoke) + t2 = threading.Thread(target=invoke) + + t1.start() + time.sleep(0.05) # Small delay to ensure first thread acquires lock + t2.start() + t1.join() + t2.join() + + # One should succeed, one should raise ConcurrencyException + assert len(results) == 1, f"Expected 1 success, got {len(results)}" + assert len(errors) == 1, f"Expected 1 error, got {len(errors)}" + assert "concurrent" in str(errors[0]).lower() and "invocation" in str(errors[0]).lower() + + +@pytest.mark.asyncio +async def test_agent_sequential_invocations_work(): + """Test that sequential invocations work correctly after lock is released.""" + model = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "response1"}]}, + {"role": "assistant", "content": [{"text": "response2"}]}, + {"role": "assistant", "content": [{"text": "response3"}]}, + ] + ) + agent = Agent(model=model) + + # All sequential calls should succeed + result1 = await agent.invoke_async("test1") + assert result1.message["content"][0]["text"] == "response1" + + result2 = await agent.invoke_async("test2") + assert result2.message["content"][0]["text"] == "response2" + + result3 = await 
agent.invoke_async("test3") + assert result3.message["content"][0]["text"] == "response3" + + +@pytest.mark.asyncio +async def test_agent_lock_released_on_exception(): + """Test that lock is released when an exception occurs during invocation.""" + + # Create a mock model that raises an explicit error + mock_model = unittest.mock.Mock() + + async def failing_stream(*args, **kwargs): + raise RuntimeError("Simulated model failure") + yield # Make this an async generator + + mock_model.stream = failing_stream + + agent = Agent(model=mock_model) + + # First call will fail due to the simulated error + with pytest.raises(RuntimeError, match="Simulated model failure"): + await agent.invoke_async("test") + + # Lock should be released, so this should not raise ConcurrencyException + # It will still raise RuntimeError, but that's expected + with pytest.raises(RuntimeError, match="Simulated model failure"): + await agent.invoke_async("test") + + +def test_agent_direct_tool_call_during_invocation_raises_exception(tool_decorated): + """Test that direct tool call during agent invocation raises ConcurrencyException.""" + + tool_calls = [] + + @strands.tool + def tool_to_invoke(): + tool_calls.append("tool_to_invoke") + return "called" + + @strands.tool(context=True) + def agent_tool(tool_context: ToolContext) -> str: + tool_context.agent.tool.tool_to_invoke(record_direct_tool_call=True) + return "tool result" + + model = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "test-123", + "name": "agent_tool", + "input": {}, + } + } + ], + }, + {"role": "assistant", "content": [{"text": "Done"}]}, + ] + ) + agent = Agent(model=model, tools=[agent_tool, tool_to_invoke]) + agent("Hi") + + # Tool call should have not succeeded + assert len(tool_calls) == 0 + + assert agent.messages[-2] == { + "content": [ + { + "toolResult": { + "content": [ + { + "text": "Error: ConcurrencyException - Direct tool call cannot be made while the agent is " 
+ "in the middle of an invocation. Set record_direct_tool_call=False to allow direct tool " + "calls during agent invocation." + } + ], + "status": "error", + "toolUseId": "test-123", + } + } + ], + "role": "user", + } + + +def test_agent_direct_tool_call_during_invocation_succeeds_with_record_false(tool_decorated): + """Test that direct tool call during agent invocation succeeds when record_direct_tool_call=False.""" + tool_calls = [] + + @strands.tool + def tool_to_invoke(): + tool_calls.append("tool_to_invoke") + return "called" + + @strands.tool(context=True) + def agent_tool(tool_context: ToolContext) -> str: + tool_context.agent.tool.tool_to_invoke(record_direct_tool_call=False) + return "tool result" + + model = MockedModelProvider( + [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "test-123", + "name": "agent_tool", + "input": {}, + } + } + ], + }, + {"role": "assistant", "content": [{"text": "Done"}]}, + ] + ) + agent = Agent(model=model, tools=[agent_tool, tool_to_invoke]) + agent("Hi") + + # Tool call should have succeeded + assert len(tool_calls) == 1 + + assert agent.messages[-2] == { + "content": [ + { + "toolResult": { + "content": [{"text": "tool result"}], + "status": "success", + "toolUseId": "test-123", + } + } + ], + "role": "user", + } From c098b3df9da3ef848eeb5fee066912e30c3b797d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=2E/c=C2=B2?= Date: Tue, 13 Jan 2026 11:52:27 -0500 Subject: [PATCH 057/279] fix(mcp): propagate contextvars to background thread (#1444) Fixes #1440 The MCP client creates a background thread for connection management. Previously, context variables set in the main thread were not accessible in this background thread. This change copies the context from the main thread when starting the background thread, ensuring that contextvars are properly propagated. This is consistent with the fix in PR #1146 which addressed the same issue for tool invocations. 
Changes: - Add contextvars import - Use contextvars.copy_context() and ctx.run() when creating background thread - Add test to verify context propagation Co-authored-by: Strands Coder --- src/strands/tools/mcp/mcp_client.py | 7 +- .../tools/mcp/test_mcp_client_contextvar.py | 88 +++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 tests/strands/tools/mcp/test_mcp_client_contextvar.py diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index db21b9ef2..ea11627b9 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -9,6 +9,7 @@ import asyncio import base64 +import contextvars import logging import threading import uuid @@ -179,7 +180,11 @@ def start(self) -> "MCPClient": raise MCPClientInitializationError("the client session is currently running") self._log_debug_with_thread("entering MCPClient context") - self._background_thread = threading.Thread(target=self._background_task, args=[], daemon=True) + # Copy context vars to propagate to the background thread + # This ensures that context set in the main thread is accessible in the background thread + # See: https://github.com/strands-agents/sdk-python/issues/1440 + ctx = contextvars.copy_context() + self._background_thread = threading.Thread(target=ctx.run, args=(self._background_task,), daemon=True) self._background_thread.start() self._log_debug_with_thread("background thread started, waiting for ready event") try: diff --git a/tests/strands/tools/mcp/test_mcp_client_contextvar.py b/tests/strands/tools/mcp/test_mcp_client_contextvar.py new file mode 100644 index 000000000..d95929b02 --- /dev/null +++ b/tests/strands/tools/mcp/test_mcp_client_contextvar.py @@ -0,0 +1,88 @@ +"""Test for MCP client context variable propagation. + +This test verifies that context variables set in the main thread are +properly propagated to the MCP client's background thread. 
+ +Related: https://github.com/strands-agents/sdk-python/issues/1440 +""" + +import contextvars +import threading +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from strands.tools.mcp import MCPClient + + +@pytest.fixture +def mock_transport(): + """Create mock MCP transport.""" + mock_read_stream = AsyncMock() + mock_write_stream = AsyncMock() + mock_transport_cm = AsyncMock() + mock_transport_cm.__aenter__.return_value = (mock_read_stream, mock_write_stream) + mock_transport_callable = MagicMock(return_value=mock_transport_cm) + + return { + "read_stream": mock_read_stream, + "write_stream": mock_write_stream, + "transport_cm": mock_transport_cm, + "transport_callable": mock_transport_callable, + } + + +@pytest.fixture +def mock_session(): + """Create mock MCP session.""" + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + + mock_session_cm = AsyncMock() + mock_session_cm.__aenter__.return_value = mock_session + + with patch("strands.tools.mcp.mcp_client.ClientSession", return_value=mock_session_cm): + yield mock_session + + +# Context variable for testing +test_contextvar: contextvars.ContextVar[str] = contextvars.ContextVar("test_contextvar", default="default_value") + + +def test_mcp_client_propagates_contextvars_to_background_thread(mock_transport, mock_session): + """Test that context variables are propagated to the MCP client background thread. + + This verifies the fix for https://github.com/strands-agents/sdk-python/issues/1440 + where context variables set in the main thread were not accessible in the + MCP client's background thread. 
+ """ + # Store the value seen in the background thread + background_thread_value = {} + + # Patch _background_task to capture the contextvar value + original_background_task = MCPClient._background_task + + def capturing_background_task(self): + # Capture the contextvar value in the background thread + background_thread_value["contextvar"] = test_contextvar.get() + background_thread_value["thread_id"] = threading.current_thread().ident + # Call the original background task + return original_background_task(self) + + # Set a specific value in the main thread + test_contextvar.set("main_thread_value") + main_thread_id = threading.current_thread().ident + + with patch.object(MCPClient, "_background_task", capturing_background_task): + with MCPClient(mock_transport["transport_callable"]) as client: + # Verify the client started successfully + assert client._background_thread is not None + + # Verify context was propagated to background thread + assert "contextvar" in background_thread_value, "Background task should have run and captured contextvar" + assert background_thread_value["contextvar"] == "main_thread_value", ( + f"Context variable should be propagated to background thread. 
" + f"Expected 'main_thread_value', got '{background_thread_value['contextvar']}'" + ) + # Verify it was indeed a different thread + assert background_thread_value["thread_id"] != main_thread_id, "Background task should run in a different thread" From 06c32974f914a0f181dcd54c33b8be690a526e9b Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Tue, 13 Jan 2026 12:28:21 -0500 Subject: [PATCH 058/279] Update to opus 4.5 (#1471) --- .github/scripts/python/agent_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/scripts/python/agent_runner.py b/.github/scripts/python/agent_runner.py index 9d92c2ac4..1f772241c 100644 --- a/.github/scripts/python/agent_runner.py +++ b/.github/scripts/python/agent_runner.py @@ -39,7 +39,7 @@ from str_replace_based_edit_tool import str_replace_based_edit_tool # Strands configuration constants -STRANDS_MODEL_ID = "global.anthropic.claude-sonnet-4-5-20250929-v1:0" +STRANDS_MODEL_ID = "global.anthropic.claude-opus-4-5-20251101-v1:0" STRANDS_MAX_TOKENS = 64000 STRANDS_BUDGET_TOKENS = 8000 STRANDS_REGION = "us-west-2" From c43dfa930e1f87b3a42dc3ba1dc09046321c865a Mon Sep 17 00:00:00 2001 From: Ratish P <114130421+Ratish1@users.noreply.github.com> Date: Thu, 15 Jan 2026 00:02:58 +0530 Subject: [PATCH 059/279] fix(mcp): prevent agent hang by checking session closure state (#1396) --- src/strands/tools/mcp/mcp_client.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index ea11627b9..c36811c17 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -891,4 +891,10 @@ def _matches_patterns(self, tool: MCPAgentTool, patterns: list[_ToolMatcher]) -> return False def _is_session_active(self) -> bool: - return self._background_thread is not None and self._background_thread.is_alive() + if self._background_thread is None or not self._background_thread.is_alive(): + return False + + if 
self._close_future is not None and self._close_future.done(): + return False + + return True From 368bb0f719bed57741e77f8e2dc4aeab62bdca5f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 14 Jan 2026 23:37:11 -0500 Subject: [PATCH 060/279] ci: update sphinx-rtd-theme requirement (#1466) Updates the requirements on [sphinx-rtd-theme](https://github.com/readthedocs/sphinx_rtd_theme) to permit the latest version. - [Changelog](https://github.com/readthedocs/sphinx_rtd_theme/blob/master/docs/changelog.rst) - [Commits](https://github.com/readthedocs/sphinx_rtd_theme/compare/1.0.0...3.1.0) --- updated-dependencies: - dependency-name: sphinx-rtd-theme dependency-version: 3.1.0 dependency-type: direct:development ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 62e0e04b3..7a6b02d53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ sagemaker = [ otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ "sphinx>=5.0.0,<10.0.0", - "sphinx-rtd-theme>=1.0.0,<2.0.0", + "sphinx-rtd-theme>=1.0.0,<4.0.0", "sphinx-autodoc-typehints>=1.12.0,<4.0.0", ] From c0298319ee00ab7c88ce7087b702a544395e1e3a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 14 Jan 2026 23:39:04 -0500 Subject: [PATCH 061/279] ci: update websockets requirement (#1451) Updates the requirements on [websockets](https://github.com/python-websockets/websockets) to permit the latest version. - [Release notes](https://github.com/python-websockets/websockets/releases) - [Commits](https://github.com/python-websockets/websockets/compare/15.0...16.0) --- updated-dependencies: - dependency-name: websockets dependency-version: '16.0' dependency-type: direct:development ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7a6b02d53..aa5f773c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,7 @@ bidi = [ "smithy-aws-core>=0.0.1; python_version>='3.12'", ] bidi-gemini = ["google-genai>=1.32.0,<2.0.0"] -bidi-openai = ["websockets>=15.0.0,<16.0.0"] +bidi-openai = ["websockets>=15.0.0,<17.0.0"] all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] bidi-all = ["strands-agents[a2a,bidi,bidi-gemini,bidi-openai,docs,otel]"] From 2546aa0d3e7081b4ca6e45340789343539f81d27 Mon Sep 17 00:00:00 2001 From: Max Rabin <927792+maxrabin@users.noreply.github.com> Date: Thu, 15 Jan 2026 17:51:09 +0200 Subject: [PATCH 062/279] style: update ruff configuration to apply pyupgrade to modernize syntax (#1336) --------- Co-authored-by: Dean Schmigelski --- pyproject.toml | 1 + src/strands/_async.py | 3 +- src/strands/agent/agent.py | 59 ++++++------- src/strands/agent/agent_result.py | 3 +- src/strands/agent/base.py | 3 +- .../conversation_manager.py | 6 +- .../null_conversation_manager.py | 4 +- .../sliding_window_conversation_manager.py | 8 +- .../summarizing_conversation_manager.py | 14 +-- src/strands/event_loop/event_loop.py | 3 +- src/strands/event_loop/streaming.py | 11 +-- src/strands/experimental/agent_config.py | 2 +- .../steering/handlers/__init__.py | 2 +- .../steering/handlers/llm/llm_handler.py | 2 +- .../experimental/tools/tool_provider.py | 3 +- src/strands/hooks/events.py | 12 +-- src/strands/hooks/registry.py | 7 +- src/strands/models/_validation.py | 5 +- src/strands/models/anthropic.py | 19 +++-- src/strands/models/bedrock.py | 77 ++++++++--------- src/strands/models/gemini.py | 29 ++++--- src/strands/models/litellm.py | 29 ++++--- src/strands/models/llamaapi.py | 27 +++--- 
src/strands/models/llamacpp.py | 36 ++++---- src/strands/models/mistral.py | 29 ++++--- src/strands/models/model.py | 11 +-- src/strands/models/ollama.py | 33 +++---- src/strands/models/openai.py | 29 +++---- src/strands/models/sagemaker.py | 45 +++++----- src/strands/models/writer.py | 27 +++--- src/strands/multiagent/a2a/executor.py | 6 +- src/strands/multiagent/base.py | 5 +- src/strands/multiagent/graph.py | 35 ++++---- src/strands/multiagent/swarm.py | 11 +-- src/strands/session/file_session_manager.py | 16 ++-- .../session/repository_session_manager.py | 4 +- src/strands/session/s3_session_manager.py | 26 +++--- src/strands/session/session_repository.py | 12 +-- src/strands/telemetry/metrics.py | 49 +++++------ src/strands/telemetry/tracer.py | 85 +++++++++---------- src/strands/tools/_caller.py | 3 +- src/strands/tools/decorator.py | 25 +++--- src/strands/tools/executors/_executor.py | 3 +- src/strands/tools/executors/concurrent.py | 3 +- src/strands/tools/executors/sequential.py | 3 +- src/strands/tools/loader.py | 12 +-- src/strands/tools/mcp/mcp_client.py | 22 ++--- src/strands/tools/mcp/mcp_instrumentation.py | 9 +- src/strands/tools/registry.py | 29 ++++--- .../_structured_output_context.py | 10 +-- .../structured_output_tool.py | 10 +-- .../structured_output_utils.py | 28 +++--- src/strands/tools/watcher.py | 6 +- src/strands/types/_events.py | 3 +- src/strands/types/citations.py | 22 ++--- src/strands/types/collections.py | 4 +- src/strands/types/content.py | 14 +-- src/strands/types/guardrails.py | 24 +++--- src/strands/types/media.py | 6 +- src/strands/types/session.py | 4 +- src/strands/types/streaming.py | 22 +++-- src/strands/types/tools.py | 16 +--- src/strands/types/traces.py | 32 +++---- tests/fixtures/mock_hook_provider.py | 7 +- .../fixtures/mock_multiagent_hook_provider.py | 7 +- tests/fixtures/mocked_model_provider.py | 19 +++-- .../strands/agent/hooks/test_hook_registry.py | 3 +- tests/strands/agent/test_agent_result.py | 4 +- 
.../agent/test_agent_structured_output.py | 3 +- tests/strands/models/test_sagemaker.py | 16 ++-- tests/strands/models/test_writer.py | 6 +- .../session/test_file_session_manager.py | 8 +- .../test_structured_output_context.py | 4 +- .../test_structured_output_tool.py | 5 +- tests/strands/tools/test_decorator.py | 23 ++--- tests/strands/tools/test_structured_output.py | 20 ++--- tests_integ/mcp/echo_server.py | 4 +- tests_integ/mcp/test_mcp_client.py | 6 +- tests_integ/models/providers.py | 4 +- tests_integ/test_function_tools.py | 3 +- tests_integ/test_multiagent_graph.py | 3 +- .../test_structured_output_agent_loop.py | 12 ++- 82 files changed, 626 insertions(+), 629 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index aa5f773c4..b49c74d1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -224,6 +224,7 @@ select = [ "G", # logging format "I", # isort "LOG", # logging + "UP" # pyupgrade ] [tool.ruff.lint.per-file-ignores] diff --git a/src/strands/_async.py b/src/strands/_async.py index 141ca71b7..0ceb038f3 100644 --- a/src/strands/_async.py +++ b/src/strands/_async.py @@ -2,8 +2,9 @@ import asyncio import contextvars +from collections.abc import Awaitable, Callable from concurrent.futures import ThreadPoolExecutor -from typing import Awaitable, Callable, TypeVar +from typing import TypeVar T = TypeVar("T") diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 7126644e6..b58b55f24 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -12,15 +12,10 @@ import logging import threading import warnings +from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping from typing import ( TYPE_CHECKING, Any, - AsyncGenerator, - AsyncIterator, - Callable, - Mapping, - Optional, - Type, TypeVar, Union, cast, @@ -105,26 +100,24 @@ class Agent: def __init__( self, - model: Union[Model, str, None] = None, - messages: Optional[Messages] = None, - tools: Optional[list[Union[str, dict[str, str], "ToolProvider", 
Any]]] = None, - system_prompt: Optional[str | list[SystemContentBlock]] = None, - structured_output_model: Optional[Type[BaseModel]] = None, - callback_handler: Optional[ - Union[Callable[..., Any], _DefaultCallbackHandlerSentinel] - ] = _DEFAULT_CALLBACK_HANDLER, - conversation_manager: Optional[ConversationManager] = None, + model: Model | str | None = None, + messages: Messages | None = None, + tools: list[Union[str, dict[str, str], "ToolProvider", Any]] | None = None, + system_prompt: str | list[SystemContentBlock] | None = None, + structured_output_model: type[BaseModel] | None = None, + callback_handler: Callable[..., Any] | _DefaultCallbackHandlerSentinel | None = _DEFAULT_CALLBACK_HANDLER, + conversation_manager: ConversationManager | None = None, record_direct_tool_call: bool = True, load_tools_from_directory: bool = False, - trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + trace_attributes: Mapping[str, AttributeValue] | None = None, *, - agent_id: Optional[str] = None, - name: Optional[str] = None, - description: Optional[str] = None, - state: Optional[Union[AgentState, dict]] = None, - hooks: Optional[list[HookProvider]] = None, - session_manager: Optional[SessionManager] = None, - tool_executor: Optional[ToolExecutor] = None, + agent_id: str | None = None, + name: str | None = None, + description: str | None = None, + state: AgentState | dict | None = None, + hooks: list[HookProvider] | None = None, + session_manager: SessionManager | None = None, + tool_executor: ToolExecutor | None = None, ): """Initialize the Agent with the specified configuration. 
@@ -190,7 +183,7 @@ def __init__( # If not provided, create a new PrintingCallbackHandler instance # If explicitly set to None, use null_callback_handler # Otherwise use the passed callback_handler - self.callback_handler: Union[Callable[..., Any], PrintingCallbackHandler] + self.callback_handler: Callable[..., Any] | PrintingCallbackHandler if isinstance(callback_handler, _DefaultCallbackHandlerSentinel): self.callback_handler = PrintingCallbackHandler() elif callback_handler is None: @@ -227,7 +220,7 @@ def __init__( # Initialize tracer instance (no-op if not configured) self.tracer = get_tracer() - self.trace_span: Optional[trace_api.Span] = None + self.trace_span: trace_api.Span | None = None # Initialize agent state management if state is not None: @@ -325,7 +318,7 @@ def __call__( prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, - structured_output_model: Type[BaseModel] | None = None, + structured_output_model: type[BaseModel] | None = None, **kwargs: Any, ) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -366,7 +359,7 @@ async def invoke_async( prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, - structured_output_model: Type[BaseModel] | None = None, + structured_output_model: type[BaseModel] | None = None, **kwargs: Any, ) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -403,7 +396,7 @@ async def invoke_async( return cast(AgentResult, event["result"]) - def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> T: + def structured_output(self, output_model: type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. If you pass in a prompt, it will be used temporarily without adding it to the conversation history. 
@@ -434,7 +427,7 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> return run_async(lambda: self.structured_output_async(output_model, prompt)) - async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T: + async def structured_output_async(self, output_model: type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. If you pass in a prompt, it will be used temporarily without adding it to the conversation history. @@ -529,7 +522,7 @@ async def stream_async( prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, - structured_output_model: Type[BaseModel] | None = None, + structured_output_model: type[BaseModel] | None = None, **kwargs: Any, ) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. @@ -632,7 +625,7 @@ async def _run_loop( self, messages: Messages, invocation_state: dict[str, Any], - structured_output_model: Type[BaseModel] | None = None, + structured_output_model: type[BaseModel] | None = None, ) -> AsyncGenerator[TypedEvent, None]: """Execute the agent's event loop with the given message and parameters. @@ -794,8 +787,8 @@ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: def _end_agent_trace_span( self, - response: Optional[AgentResult] = None, - error: Optional[Exception] = None, + response: AgentResult | None = None, + error: Exception | None = None, ) -> None: """Ends a trace span for the agent. diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index ef8a11029..2ab95e5b5 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -3,8 +3,9 @@ This module defines the AgentResult class which encapsulates the complete response from an agent's processing cycle. 
""" +from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Sequence, cast +from typing import Any, cast from pydantic import BaseModel diff --git a/src/strands/agent/base.py b/src/strands/agent/base.py index b35ade8c4..ae8a14e75 100644 --- a/src/strands/agent/base.py +++ b/src/strands/agent/base.py @@ -3,7 +3,8 @@ Defines the minimal interface that all agent types must implement. """ -from typing import Any, AsyncIterator, Protocol, runtime_checkable +from collections.abc import AsyncIterator +from typing import Any, Protocol, runtime_checkable from ..types.agent import AgentInput from .agent_result import AgentResult diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index 47b761abc..690ecbde5 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -1,7 +1,7 @@ """Abstract interface for conversation history management.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ...hooks.registry import HookProvider, HookRegistry from ...types.content import Message @@ -62,7 +62,7 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: """ pass - def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]: + def restore_from_session(self, state: dict[str, Any]) -> list[Message] | None: """Restore the Conversation Manager's state from a session. Args: @@ -98,7 +98,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: pass @abstractmethod - def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None: + def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Called when the model's context window is exceeded. 
This method should implement the specific strategy for reducing the window size when a context overflow occurs. diff --git a/src/strands/agent/conversation_manager/null_conversation_manager.py b/src/strands/agent/conversation_manager/null_conversation_manager.py index 5ff6874e5..11632525d 100644 --- a/src/strands/agent/conversation_manager/null_conversation_manager.py +++ b/src/strands/agent/conversation_manager/null_conversation_manager.py @@ -1,6 +1,6 @@ """Null implementation of conversation management.""" -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from ...agent.agent import Agent @@ -28,7 +28,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: """ pass - def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None: + def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Does not reduce context and raises an exception. Args: diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index a063e55eb..709c876e7 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -1,7 +1,7 @@ """Sliding window conversation history management.""" import logging -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from ...agent.agent import Agent @@ -103,7 +103,7 @@ def get_state(self) -> dict[str, Any]: state["model_call_count"] = self._model_call_count return state - def restore_from_session(self, state: dict[str, Any]) -> Optional[list]: + def restore_from_session(self, state: dict[str, Any]) -> list | None: """Restore the conversation manager's state from a session. 
Args: @@ -136,7 +136,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: return self.reduce_context(agent) - def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None: + def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Trim the oldest messages to reduce the conversation context size. The method handles special cases where trimming the messages leads to: @@ -235,7 +235,7 @@ def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool: return changes_made - def _find_last_message_with_tool_results(self, messages: Messages) -> Optional[int]: + def _find_last_message_with_tool_results(self, messages: Messages) -> int | None: """Find the index of the last message containing tool results. This is useful for identifying messages that might need to be truncated to reduce context size. diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index 12185c286..cc71e4d88 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -1,7 +1,7 @@ """Summarizing conversation history management with configurable options.""" import logging -from typing import TYPE_CHECKING, Any, List, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from typing_extensions import override @@ -62,7 +62,7 @@ def __init__( summary_ratio: float = 0.3, preserve_recent_messages: int = 10, summarization_agent: Optional["Agent"] = None, - summarization_system_prompt: Optional[str] = None, + summarization_system_prompt: str | None = None, ): """Initialize the summarizing conversation manager. 
@@ -87,10 +87,10 @@ def __init__( self.preserve_recent_messages = preserve_recent_messages self.summarization_agent = summarization_agent self.summarization_system_prompt = summarization_system_prompt - self._summary_message: Optional[Message] = None + self._summary_message: Message | None = None @override - def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]: + def restore_from_session(self, state: dict[str, Any]) -> list[Message] | None: """Restores the Summarizing Conversation manager from its previous state in a session. Args: @@ -121,7 +121,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: # No proactive management - summarization only happens on context overflow pass - def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None: + def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Reduce context using summarization. Args: @@ -173,7 +173,7 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs logger.error("Summarization failed: %s", summarization_error) raise summarization_error from e - def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: + def _generate_summary(self, messages: list[Message], agent: "Agent") -> Message: """Generate a summary of the provided messages. Args: @@ -224,7 +224,7 @@ def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: summarization_agent.messages = original_messages summarization_agent.tool_registry = original_tool_registry - def _adjust_split_point_for_tool_pairs(self, messages: List[Message], split_point: int) -> int: + def _adjust_split_point_for_tool_pairs(self, messages: list[Message], split_point: int) -> int: """Adjust the split point to avoid breaking ToolUse/ToolResult pairs. Uses the same logic as SlidingWindowConversationManager for consistency. 
diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 231cfa56a..99c8f5179 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -11,7 +11,8 @@ import asyncio import logging import uuid -from typing import TYPE_CHECKING, Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any from opentelemetry import trace as trace_api diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 7840bfcef..954633807 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -4,7 +4,8 @@ import logging import time import warnings -from typing import Any, AsyncGenerator, AsyncIterable, Optional +from collections.abc import AsyncGenerator, AsyncIterable +from typing import Any from ..models.model import Model from ..tools import InvalidToolUseNameException @@ -419,13 +420,13 @@ async def process_stream( async def stream_messages( model: Model, - system_prompt: Optional[str], + system_prompt: str | None, messages: Messages, tool_specs: list[ToolSpec], *, - tool_choice: Optional[Any] = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, - invocation_state: Optional[dict[str, Any]] = None, + tool_choice: Any | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + invocation_state: dict[str, Any] | None = None, **kwargs: Any, ) -> AsyncGenerator[TypedEvent, None]: """Streams messages to the model and processes the response. 
diff --git a/src/strands/experimental/agent_config.py b/src/strands/experimental/agent_config.py index f65afb57d..e6fb94118 100644 --- a/src/strands/experimental/agent_config.py +++ b/src/strands/experimental/agent_config.py @@ -98,7 +98,7 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> A if not config_path.exists(): raise FileNotFoundError(f"Configuration file not found: {file_path}") - with open(config_path, "r") as f: + with open(config_path) as f: config_dict = json.load(f) elif isinstance(config, dict): config_dict = config.copy() diff --git a/src/strands/experimental/steering/handlers/__init__.py b/src/strands/experimental/steering/handlers/__init__.py index 542126ab5..fe364a5a2 100644 --- a/src/strands/experimental/steering/handlers/__init__.py +++ b/src/strands/experimental/steering/handlers/__init__.py @@ -1,5 +1,5 @@ """Steering handler implementations.""" -from typing import Sequence +from collections.abc import Sequence __all__: Sequence[str] = [] diff --git a/src/strands/experimental/steering/handlers/llm/llm_handler.py b/src/strands/experimental/steering/handlers/llm/llm_handler.py index 9d9b34911..4d90f46c9 100644 --- a/src/strands/experimental/steering/handlers/llm/llm_handler.py +++ b/src/strands/experimental/steering/handlers/llm/llm_handler.py @@ -58,7 +58,7 @@ def __init__( self.prompt_mapper = prompt_mapper or DefaultPromptMapper() self.model = model - async def steer(self, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> SteeringAction: + async def steer(self, agent: Agent, tool_use: ToolUse, **kwargs: Any) -> SteeringAction: """Provide contextual guidance for tool usage. 
Args: diff --git a/src/strands/experimental/tools/tool_provider.py b/src/strands/experimental/tools/tool_provider.py index 2c79ceafc..c40d1b572 100644 --- a/src/strands/experimental/tools/tool_provider.py +++ b/src/strands/experimental/tools/tool_provider.py @@ -1,7 +1,8 @@ """Tool provider interface.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Sequence +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from ...types.tools import AgentTool diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 5e11524d1..340b6d3d2 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -5,7 +5,7 @@ import uuid from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from typing_extensions import override @@ -116,7 +116,7 @@ class BeforeToolCallEvent(HookEvent, _Interruptible): the tool call and use a default cancel message. """ - selected_tool: Optional[AgentTool] + selected_tool: AgentTool | None tool_use: ToolUse invocation_state: dict[str, Any] cancel_tool: bool | str = False @@ -157,11 +157,11 @@ class AfterToolCallEvent(HookEvent): cancel_message: The cancellation message if the user cancelled the tool call. 
""" - selected_tool: Optional[AgentTool] + selected_tool: AgentTool | None tool_use: ToolUse invocation_state: dict[str, Any] result: ToolResult - exception: Optional[Exception] = None + exception: Exception | None = None cancel_message: str | None = None def _can_write(self, name: str) -> bool: @@ -232,8 +232,8 @@ class ModelStopResponse: message: Message stop_reason: StopReason - stop_response: Optional[ModelStopResponse] = None - exception: Optional[Exception] = None + stop_response: ModelStopResponse | None = None + exception: Exception | None = None retry: bool = False def _can_write(self, name: str) -> bool: diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 9edf7ffa7..309e3ba76 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -9,8 +9,9 @@ import inspect import logging +from collections.abc import Awaitable, Generator from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar, runtime_checkable +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable from ..interrupt import Interrupt, InterruptException @@ -154,9 +155,9 @@ class HookRegistry: def __init__(self) -> None: """Initialize an empty hook registry.""" - self._registered_callbacks: dict[Type, list[HookCallback]] = {} + self._registered_callbacks: dict[type, list[HookCallback]] = {} - def add_callback(self, event_type: Type[TEvent], callback: HookCallback[TEvent]) -> None: + def add_callback(self, event_type: type[TEvent], callback: HookCallback[TEvent]) -> None: """Register a callback function for a specific event type. 
Args: diff --git a/src/strands/models/_validation.py b/src/strands/models/_validation.py index 9eabe28a1..1e82bca73 100644 --- a/src/strands/models/_validation.py +++ b/src/strands/models/_validation.py @@ -1,14 +1,15 @@ """Configuration validation utilities for model providers.""" import warnings -from typing import Any, Mapping, Type +from collections.abc import Mapping +from typing import Any from typing_extensions import get_type_hints from ..types.tools import ToolChoice -def validate_config_keys(config_dict: Mapping[str, Any], config_class: Type) -> None: +def validate_config_keys(config_dict: Mapping[str, Any], config_class: type) -> None: """Validate that config keys match the TypedDict fields. Args: diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 68b234729..535c820ee 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -7,7 +7,8 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypedDict, TypeVar, cast import anthropic from pydantic import BaseModel @@ -59,9 +60,9 @@ class AnthropicConfig(TypedDict, total=False): max_tokens: Required[int] model_id: Required[str] - params: Optional[dict[str, Any]] + params: dict[str, Any] | None - def __init__(self, *, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[AnthropicConfig]): + def __init__(self, *, client_args: dict[str, Any] | None = None, **model_config: Unpack[AnthropicConfig]): """Initialize provider instance. 
Args: @@ -198,8 +199,8 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]: def format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, tool_choice: ToolChoice | None = None, ) -> dict[str, Any]: """Format an Anthropic streaming request. @@ -369,8 +370,8 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -419,8 +420,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. 
Args: diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 8e1558ca7..dfcd133c6 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -8,7 +8,8 @@ import logging import os import warnings -from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, ValuesView, cast +from collections.abc import AsyncGenerator, Callable, Iterable, ValuesView +from typing import Any, Literal, TypeVar, cast import boto3 from botocore.config import Config as BotocoreConfig @@ -94,35 +95,35 @@ class BedrockConfig(TypedDict, total=False): top_p: Controls diversity via nucleus sampling (alternative to temperature) """ - additional_args: Optional[dict[str, Any]] - additional_request_fields: Optional[dict[str, Any]] - additional_response_field_paths: Optional[list[str]] - cache_prompt: Optional[str] - cache_tools: Optional[str] - guardrail_id: Optional[str] - guardrail_trace: Optional[Literal["enabled", "disabled", "enabled_full"]] - guardrail_stream_processing_mode: Optional[Literal["sync", "async"]] - guardrail_version: Optional[str] - guardrail_redact_input: Optional[bool] - guardrail_redact_input_message: Optional[str] - guardrail_redact_output: Optional[bool] - guardrail_redact_output_message: Optional[str] - guardrail_latest_message: Optional[bool] - max_tokens: Optional[int] + additional_args: dict[str, Any] | None + additional_request_fields: dict[str, Any] | None + additional_response_field_paths: list[str] | None + cache_prompt: str | None + cache_tools: str | None + guardrail_id: str | None + guardrail_trace: Literal["enabled", "disabled", "enabled_full"] | None + guardrail_stream_processing_mode: Literal["sync", "async"] | None + guardrail_version: str | None + guardrail_redact_input: bool | None + guardrail_redact_input_message: str | None + guardrail_redact_output: bool | None + guardrail_redact_output_message: str | None + guardrail_latest_message: bool | None + max_tokens: int | 
None model_id: str - include_tool_result_status: Optional[Literal["auto"] | bool] - stop_sequences: Optional[list[str]] - streaming: Optional[bool] - temperature: Optional[float] - top_p: Optional[float] + include_tool_result_status: Literal["auto"] | bool | None + stop_sequences: list[str] | None + streaming: bool | None + temperature: float | None + top_p: float | None def __init__( self, *, - boto_session: Optional[boto3.Session] = None, - boto_client_config: Optional[BotocoreConfig] = None, - region_name: Optional[str] = None, - endpoint_url: Optional[str] = None, + boto_session: boto3.Session | None = None, + boto_client_config: BotocoreConfig | None = None, + region_name: str | None = None, + endpoint_url: str | None = None, **model_config: Unpack[BedrockConfig], ): """Initialize provider instance. @@ -193,8 +194,8 @@ def get_config(self) -> BedrockConfig: def _format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, tool_choice: ToolChoice | None = None, ) -> dict[str, Any]: """Format a Bedrock converse stream request. @@ -603,11 +604,11 @@ def _generate_redaction_events(self) -> list[StreamEvent]: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the Bedrock model. @@ -631,13 +632,13 @@ async def stream( ModelThrottledException: If the model service is throttling requests. 
""" - def callback(event: Optional[StreamEvent] = None) -> None: + def callback(event: StreamEvent | None = None) -> None: loop.call_soon_threadsafe(queue.put_nowait, event) if event is None: return loop = asyncio.get_event_loop() - queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue() + queue: asyncio.Queue[StreamEvent | None] = asyncio.Queue() # Handle backward compatibility: if system_prompt is provided but system_prompt_content is None if system_prompt and system_prompt_content is None: @@ -659,8 +660,8 @@ def _stream( self, callback: Callable[..., None], messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, tool_choice: ToolChoice | None = None, ) -> None: """Stream conversation with the Bedrock model. @@ -913,11 +914,11 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool: @override async def structured_output( self, - output_model: Type[T], + output_model: type[T], prompt: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, **kwargs: Any, - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: @@ -962,7 +963,7 @@ async def structured_output( yield {"output": output_model(**output_response)} @staticmethod - def _get_default_model_with_warning(region_name: str, model_config: Optional[BedrockConfig] = None) -> str: + def _get_default_model_with_warning(region_name: str, model_config: BedrockConfig | None = None) -> str: """Get the default Bedrock modelId based on region. 
If the region is not **known** to support inference then we show a helpful warning diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 45f7f4e18..52d45b649 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -6,7 +6,8 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypedDict, TypeVar, cast import pydantic from google import genai @@ -54,8 +55,8 @@ class GeminiConfig(TypedDict, total=False): def __init__( self, *, - client: Optional[genai.Client] = None, - client_args: Optional[dict[str, Any]] = None, + client: genai.Client | None = None, + client_args: dict[str, Any] | None = None, **model_config: Unpack[GeminiConfig], ) -> None: """Initialize provider instance. @@ -219,7 +220,7 @@ def _format_request_content(self, messages: Messages) -> list[genai.types.Conten for message in messages ] - def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[genai.types.Tool | Any]: + def _format_request_tools(self, tool_specs: list[ToolSpec] | None) -> list[genai.types.Tool | Any]: """Format tool specs into Gemini tools. - Docs: https://googleapis.github.io/python-genai/genai.html#genai.types.Tool @@ -248,9 +249,9 @@ def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[ge def _format_request_config( self, - tool_specs: Optional[list[ToolSpec]], - system_prompt: Optional[str], - params: Optional[dict[str, Any]], + tool_specs: list[ToolSpec] | None, + system_prompt: str | None, + params: dict[str, Any] | None, ) -> genai.types.GenerateContentConfig: """Format Gemini request config. 
@@ -273,9 +274,9 @@ def _format_request_config( def _format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]], - system_prompt: Optional[str], - params: Optional[dict[str, Any]], + tool_specs: list[ToolSpec] | None, + system_prompt: str | None, + params: dict[str, Any] | None, ) -> dict[str, Any]: """Format a Gemini streaming request. @@ -394,8 +395,8 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: @@ -483,8 +484,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model using Gemini's native structured output. 
- Docs: https://ai.google.dev/gemini-api/docs/structured-output diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index c120b0eda..ae71cc668 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -5,7 +5,8 @@ import json import logging -from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypedDict, TypeVar, cast import litellm from litellm.exceptions import ContextWindowExceededError @@ -42,9 +43,9 @@ class LiteLLMConfig(TypedDict, total=False): """ model_id: str - params: Optional[dict[str, Any]] + params: dict[str, Any] | None - def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[LiteLLMConfig]) -> None: + def __init__(self, client_args: dict[str, Any] | None = None, **model_config: Unpack[LiteLLMConfig]) -> None: """Initialize provider instance. Args: @@ -137,9 +138,9 @@ def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> @classmethod def _format_system_messages( cls, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, *, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> list[dict[str, Any]]: """Format system messages for LiteLLM with cache point support. @@ -174,9 +175,9 @@ def _format_system_messages( def format_request_messages( cls, messages: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, *, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> list[dict[str, Any]]: """Format a LiteLLM compatible messages array with cache point support. 
@@ -243,11 +244,11 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the LiteLLM model. @@ -295,8 +296,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Some models do not support native structured output via response_format. 
@@ -322,7 +323,7 @@ async def structured_output( yield {"output": result} async def _structured_output_using_response_schema( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None ) -> T: """Get structured output using native response_format support.""" response = await litellm.acompletion( @@ -350,7 +351,7 @@ async def _structured_output_using_response_schema( raise ValueError(f"Failed to parse or load content into model: {e}") from e async def _structured_output_using_tool( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None ) -> T: """Get structured output using tool calling fallback.""" tool_spec = convert_pydantic_to_tool_spec(output_model) diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 013cd2c7d..ce0367bf5 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -8,7 +8,8 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Optional, Type, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypeVar, cast import llama_api_client from llama_api_client import LlamaAPIClient @@ -43,16 +44,16 @@ class LlamaConfig(TypedDict, total=False): """ model_id: str - repetition_penalty: Optional[float] - temperature: Optional[float] - top_p: Optional[float] - max_completion_tokens: Optional[int] - top_k: Optional[int] + repetition_penalty: float | None + temperature: float | None + top_p: float | None + max_completion_tokens: int | None + top_k: int | None def __init__( self, *, - client_args: Optional[dict[str, Any]] = None, + client_args: dict[str, Any] | None = None, **model_config: Unpack[LlamaConfig], ) -> None: """Initialize provider instance. 
@@ -159,7 +160,7 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any "content": [self._format_request_message_content(content) for content in contents], } - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_request_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format a LlamaAPI compatible messages array. Args: @@ -206,7 +207,7 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s return [message for message in formatted_messages if message["content"] or "tool_calls" in message] def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> dict[str, Any]: """Format a Llama API chat streaming request. @@ -328,8 +329,8 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -416,8 +417,8 @@ async def stream( @override def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. 
Args: diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index 22a3a3873..ca838f3d7 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -14,15 +14,11 @@ import logging import mimetypes import time +from collections.abc import AsyncGenerator from typing import ( Any, - AsyncGenerator, - Dict, - Optional, - Type, TypedDict, TypeVar, - Union, cast, ) @@ -133,12 +129,12 @@ class LlamaCppConfig(TypedDict, total=False): """ model_id: str - params: Optional[dict[str, Any]] + params: dict[str, Any] | None def __init__( self, base_url: str = "http://localhost:8080", - timeout: Optional[Union[float, tuple[float, float]]] = None, + timeout: float | tuple[float, float] | None = None, **model_config: Unpack[LlamaCppConfig], ) -> None: """Initialize llama.cpp provider instance. @@ -196,7 +192,7 @@ def get_config(self) -> LlamaCppConfig: """ return self.config # type: ignore[return-value] - def _format_message_content(self, content: Union[ContentBlock, Dict[str, Any]]) -> dict[str, Any]: + def _format_message_content(self, content: ContentBlock | dict[str, Any]) -> dict[str, Any]: """Format a content block for llama.cpp. 
Args: @@ -233,7 +229,7 @@ def _format_message_content(self, content: Union[ContentBlock, Dict[str, Any]]) # Handle audio content (not in standard ContentBlock but supported by llama.cpp) if "audio" in content: - audio_content = cast(Dict[str, Any], content) + audio_content = cast(dict[str, Any], content) audio_data = base64.b64encode(audio_content["audio"]["source"]["bytes"]).decode("utf-8") audio_format = audio_content["audio"].get("format", "wav") return { @@ -284,7 +280,7 @@ def _format_tool_message(self, tool_result: dict[str, Any]) -> dict[str, Any]: "content": [self._format_message_content(content) for content in contents], } - def _format_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format messages for llama.cpp. Args: @@ -343,8 +339,8 @@ def _format_messages(self, messages: Messages, system_prompt: Optional[str] = No def _format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, ) -> dict[str, Any]: """Format a request for the llama.cpp server. 
@@ -428,7 +424,7 @@ def _format_request( request[param] = value # Collect llama.cpp-specific parameters for extra_body - extra_body: Dict[str, Any] = {} + extra_body: dict[str, Any] = {} for param, value in params.items(): if param in llamacpp_specific_params: extra_body[param] = value @@ -511,8 +507,8 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -552,7 +548,7 @@ async def stream( yield self._format_chunk({"chunk_type": "message_start"}) yield self._format_chunk({"chunk_type": "content_start", "data_type": "text"}) - tool_calls: Dict[int, list] = {} + tool_calls: dict[int, list] = {} usage_data = None finish_reason = None @@ -706,11 +702,11 @@ async def stream( @override async def structured_output( self, - output_model: Type[T], + output_model: type[T], prompt: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, **kwargs: Any, - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output using llama.cpp's native JSON schema support. 
This implementation uses llama.cpp's json_schema parameter to constrain @@ -753,7 +749,7 @@ async def structured_output( if "text" in delta: response_text += delta["text"] # Forward events to caller - yield cast(Dict[str, Union[T, Any]], event) + yield cast(dict[str, T | Any], event) # Parse and validate the JSON response data = json.loads(response_text.strip()) diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index b6459d63f..4ec77ccfe 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -6,7 +6,8 @@ import base64 import json import logging -from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypeVar, Union +from collections.abc import AsyncGenerator, Iterable +from typing import Any, TypeVar import mistralai from pydantic import BaseModel @@ -47,16 +48,16 @@ class MistralConfig(TypedDict, total=False): """ model_id: str - max_tokens: Optional[int] - temperature: Optional[float] - top_p: Optional[float] - stream: Optional[bool] + max_tokens: int | None + temperature: float | None + top_p: float | None + stream: bool | None def __init__( self, - api_key: Optional[str] = None, + api_key: str | None = None, *, - client_args: Optional[dict[str, Any]] = None, + client_args: dict[str, Any] | None = None, **model_config: Unpack[MistralConfig], ) -> None: """Initialize provider instance. @@ -115,7 +116,7 @@ def get_config(self) -> MistralConfig: """ return self.config - def _format_request_message_content(self, content: ContentBlock) -> Union[str, dict[str, Any]]: + def _format_request_message_content(self, content: ContentBlock) -> str | dict[str, Any]: """Format a Mistral content block. 
Args: @@ -187,7 +188,7 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any "tool_call_id": tool_result["toolUseId"], } - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_request_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format a Mistral compatible messages array. Args: @@ -236,7 +237,7 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s return formatted_messages def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> dict[str, Any]: """Format a Mistral chat streaming request. @@ -395,8 +396,8 @@ def _handle_non_streaming_response(self, response: Any) -> Iterable[dict[str, An async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -502,8 +503,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. 
Args: diff --git a/src/strands/models/model.py b/src/strands/models/model.py index 6b7dd78d7..e6630f807 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -2,7 +2,8 @@ import abc import logging -from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypeVar, Union +from collections.abc import AsyncGenerator, AsyncIterable +from typing import Any, TypeVar from pydantic import BaseModel @@ -45,8 +46,8 @@ def get_config(self) -> Any: @abc.abstractmethod # pragma: no cover def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: @@ -68,8 +69,8 @@ def structured_output( def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, system_prompt_content: list[SystemContentBlock] | None = None, diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 574b24200..8d72aa534 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -5,7 +5,8 @@ import json import logging -from typing import Any, AsyncGenerator, Optional, Type, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypeVar, cast import ollama from pydantic import BaseModel @@ -46,20 +47,20 @@ class OllamaConfig(TypedDict, total=False): top_p: Controls diversity via nucleus sampling (alternative to temperature). 
""" - additional_args: Optional[dict[str, Any]] - keep_alive: Optional[str] - max_tokens: Optional[int] + additional_args: dict[str, Any] | None + keep_alive: str | None + max_tokens: int | None model_id: str - options: Optional[dict[str, Any]] - stop_sequences: Optional[list[str]] - temperature: Optional[float] - top_p: Optional[float] + options: dict[str, Any] | None + stop_sequences: list[str] | None + temperature: float | None + top_p: float | None def __init__( self, - host: Optional[str], + host: str | None, *, - ollama_client_args: Optional[dict[str, Any]] = None, + ollama_client_args: dict[str, Any] | None = None, **model_config: Unpack[OllamaConfig], ) -> None: """Initialize provider instance. @@ -147,7 +148,7 @@ def _format_request_message_contents(self, role: str, content: ContentBlock) -> raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_request_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format an Ollama compatible messages array. Args: @@ -167,7 +168,7 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s ] def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> dict[str, Any]: """Format an Ollama chat streaming request. 
@@ -285,8 +286,8 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -339,8 +340,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index c381201e4..d9266212b 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -7,8 +7,9 @@ import json import logging import mimetypes +from collections.abc import AsyncGenerator, AsyncIterator from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator, AsyncIterator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast +from typing import Any, Protocol, TypedDict, TypeVar, cast import openai from openai.types.chat.parsed_chat_completion import ParsedChatCompletion @@ -54,12 +55,12 @@ class OpenAIConfig(TypedDict, total=False): """ model_id: str - params: Optional[dict[str, Any]] + params: dict[str, Any] | None def __init__( self, - client: Optional[Client] = None, - client_args: Optional[dict[str, Any]] = None, + client: Client | None = None, + client_args: dict[str, Any] | None = None, **model_config: Unpack[OpenAIConfig], ) -> None: """Initialize provider instance. 
@@ -201,9 +202,7 @@ def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) -> } @classmethod - def _split_tool_message_images( - cls, tool_message: dict[str, Any] - ) -> tuple[dict[str, Any], Optional[dict[str, Any]]]: + def _split_tool_message_images(cls, tool_message: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any] | None]: """Split a tool message into text-only tool message and optional user message with images. OpenAI API restricts images to user role messages only. This method extracts any image @@ -291,9 +290,9 @@ def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str @classmethod def _format_system_messages( cls, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, *, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> list[dict[str, Any]]: """Format system messages for OpenAI-compatible providers. @@ -374,9 +373,9 @@ def _format_regular_messages(cls, messages: Messages, **kwargs: Any) -> list[dic def format_request_messages( cls, messages: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, *, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> list[dict[str, Any]]: """Format an OpenAI compatible messages array. 
@@ -549,8 +548,8 @@ async def _get_client(self) -> AsyncIterator[Any]: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -679,8 +678,8 @@ def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 1fe630fdc..775969290 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -3,8 +3,9 @@ import json import logging import os +from collections.abc import AsyncGenerator from dataclasses import dataclass -from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union +from typing import Any, Literal, TypedDict, TypeVar import boto3 from botocore.config import Config as BotocoreConfig @@ -37,7 +38,7 @@ class UsageMetadata: total_tokens: int completion_tokens: int prompt_tokens: int - prompt_tokens_details: Optional[int] = 0 + prompt_tokens_details: int | None = 0 @dataclass @@ -49,8 +50,8 @@ class FunctionCall: arguments: Arguments to pass to the function """ - name: Union[str, dict[Any, Any]] - arguments: Union[str, dict[Any, Any]] + name: str | dict[Any, Any] + arguments: str | dict[Any, Any] def __init__(self, **kwargs: dict[str, str]): """Initialize function call. 
@@ -108,12 +109,12 @@ class SageMakerAIPayloadSchema(TypedDict, total=False): max_tokens: int stream: bool - temperature: Optional[float] - top_p: Optional[float] - top_k: Optional[int] - stop: Optional[list[str]] - tool_results_as_user_messages: Optional[bool] - additional_args: Optional[dict[str, Any]] + temperature: float | None + top_p: float | None + top_k: int | None + stop: list[str] | None + tool_results_as_user_messages: bool | None + additional_args: dict[str, Any] | None class SageMakerAIEndpointConfig(TypedDict, total=False): """Configuration options for SageMaker models. @@ -127,17 +128,17 @@ class SageMakerAIEndpointConfig(TypedDict, total=False): endpoint_name: str region_name: str - inference_component_name: Union[str, None] - target_model: Union[Optional[str], None] - target_variant: Union[Optional[str], None] - additional_args: Optional[dict[str, Any]] + inference_component_name: str | None + target_model: str | None | None + target_variant: str | None | None + additional_args: dict[str, Any] | None def __init__( self, endpoint_config: SageMakerAIEndpointConfig, payload_config: SageMakerAIPayloadSchema, - boto_session: Optional[boto3.Session] = None, - boto_client_config: Optional[BotocoreConfig] = None, + boto_session: boto3.Session | None = None, + boto_client_config: BotocoreConfig | None = None, ): """Initialize provider instance. 
@@ -199,8 +200,8 @@ def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: i def format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> dict[str, Any]: @@ -300,8 +301,8 @@ def format_request( async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -572,8 +573,8 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. 
Args: diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py index a54fc44c3..f306d649b 100644 --- a/src/strands/models/writer.py +++ b/src/strands/models/writer.py @@ -7,7 +7,8 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Dict, List, Optional, Type, TypedDict, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypedDict, TypeVar, cast import writerai from pydantic import BaseModel @@ -41,13 +42,13 @@ class WriterConfig(TypedDict, total=False): """ model_id: str - max_tokens: Optional[int] - stop: Optional[Union[str, List[str]]] - stream_options: Dict[str, Any] - temperature: Optional[float] - top_p: Optional[float] + max_tokens: int | None + stop: str | list[str] | None + stream_options: dict[str, Any] + temperature: float | None + top_p: float | None - def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[WriterConfig]): + def __init__(self, client_args: dict[str, Any] | None = None, **model_config: Unpack[WriterConfig]): """Initialize provider instance. Args: @@ -201,7 +202,7 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any "content": formatted_contents, } - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_request_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format a Writer compatible messages array. 
Args: @@ -245,7 +246,7 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s return [message for message in formatted_messages if message["content"] or "tool_calls" in message] def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> Any: """Format a streaming request to the underlying model. @@ -353,8 +354,8 @@ def format_chunk(self, event: Any) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -431,8 +432,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. 
Args: diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index 52b6d2ef1..f02b8c6cc 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -313,15 +313,13 @@ def _convert_a2a_parts_to_content_blocks(self, parts: list[Part]) -> list[Conten elif uri_data: # For URI files, create a text representation since Strands ContentBlocks expect bytes content_blocks.append( - ContentBlock( - text="[File: %s (%s)] - Referenced file at: %s" % (file_name, mime_type, uri_data) - ) + ContentBlock(text=f"[File: {file_name} ({mime_type})] - Referenced file at: {uri_data}") ) elif isinstance(part_root, DataPart): # Handle DataPart - convert structured data to JSON text try: data_text = json.dumps(part_root.data, indent=2) - content_blocks.append(ContentBlock(text="[Structured Data]\n%s" % data_text)) + content_blocks.append(ContentBlock(text=f"[Structured Data]\n{data_text}")) except Exception: logger.exception("Failed to serialize data part") except Exception: diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index f163d05b5..dc3258f68 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -6,9 +6,10 @@ import logging import warnings from abc import ABC, abstractmethod +from collections.abc import AsyncIterator, Mapping from dataclasses import dataclass, field from enum import Enum -from typing import Any, AsyncIterator, Mapping, Union +from typing import Any, Union from .._async import run_async from ..agent import AgentResult @@ -95,7 +96,7 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult": raise TypeError("NodeResult.from_dict: missing 'result'") raw = data["result"] - result: Union[AgentResult, "MultiAgentResult", Exception] + result: AgentResult | MultiAgentResult | Exception if isinstance(raw, dict) and raw.get("type") == "agent_result": result = AgentResult.from_dict(raw) elif isinstance(raw, dict) and raw.get("type") == 
"exception": diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 6156d332c..19504ad73 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -18,8 +18,9 @@ import copy import logging import time +from collections.abc import AsyncIterator, Callable, Mapping from dataclasses import dataclass, field -from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, cast +from typing import Any, cast from opentelemetry import trace as trace_api @@ -90,14 +91,14 @@ class GraphState: # Graph structure info total_nodes: int = 0 - edges: list[Tuple["GraphNode", "GraphNode"]] = field(default_factory=list) + edges: list[tuple["GraphNode", "GraphNode"]] = field(default_factory=list) entry_points: list["GraphNode"] = field(default_factory=list) def should_continue( self, - max_node_executions: Optional[int], - execution_timeout: Optional[float], - ) -> Tuple[bool, str]: + max_node_executions: int | None, + execution_timeout: float | None, + ) -> tuple[bool, str]: """Check if the graph should continue execution. 
Returns: (should_continue, reason) @@ -123,7 +124,7 @@ class GraphResult(MultiAgentResult): completed_nodes: int = 0 failed_nodes: int = 0 execution_order: list["GraphNode"] = field(default_factory=list) - edges: list[Tuple["GraphNode", "GraphNode"]] = field(default_factory=list) + edges: list[tuple["GraphNode", "GraphNode"]] = field(default_factory=list) entry_points: list["GraphNode"] = field(default_factory=list) @@ -233,13 +234,13 @@ def __init__(self) -> None: self.entry_points: set[GraphNode] = set() # Configuration options - self._max_node_executions: Optional[int] = None - self._execution_timeout: Optional[float] = None - self._node_timeout: Optional[float] = None + self._max_node_executions: int | None = None + self._execution_timeout: float | None = None + self._node_timeout: float | None = None self._reset_on_revisit: bool = False self._id: str = _DEFAULT_GRAPH_ID - self._session_manager: Optional[SessionManager] = None - self._hooks: Optional[list[HookProvider]] = None + self._session_manager: SessionManager | None = None + self._hooks: list[HookProvider] | None = None def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode: """Add an Agent or MultiAgentBase instance as a node to the graph.""" @@ -408,14 +409,14 @@ def __init__( nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_points: set[GraphNode], - max_node_executions: Optional[int] = None, - execution_timeout: Optional[float] = None, - node_timeout: Optional[float] = None, + max_node_executions: int | None = None, + execution_timeout: float | None = None, + node_timeout: float | None = None, reset_on_revisit: bool = False, - session_manager: Optional[SessionManager] = None, - hooks: Optional[list[HookProvider]] = None, + session_manager: SessionManager | None = None, + hooks: list[HookProvider] | None = None, id: str = _DEFAULT_GRAPH_ID, - trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + trace_attributes: Mapping[str, AttributeValue] 
| None = None, ) -> None: """Initialize Graph with execution limits and reset behavior. diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 7eec49649..6c1149624 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -18,8 +18,9 @@ import json import logging import time +from collections.abc import AsyncIterator, Callable, Mapping from dataclasses import dataclass, field -from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, cast +from typing import Any, Optional, cast from opentelemetry import trace as trace_api @@ -184,7 +185,7 @@ def should_continue( execution_timeout: float, repetitive_handoff_detection_window: int, repetitive_handoff_min_unique_agents: int, - ) -> Tuple[bool, str]: + ) -> tuple[bool, str]: """Check if the swarm should continue. Returns: (should_continue, reason) @@ -239,10 +240,10 @@ def __init__( node_timeout: float = 300.0, repetitive_handoff_detection_window: int = 0, repetitive_handoff_min_unique_agents: int = 0, - session_manager: Optional[SessionManager] = None, - hooks: Optional[list[HookProvider]] = None, + session_manager: SessionManager | None = None, + hooks: list[HookProvider] | None = None, id: str = _DEFAULT_SWARM_ID, - trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + trace_attributes: Mapping[str, AttributeValue] | None = None, ) -> None: """Initialize Swarm with agents and configuration. diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index fc80fc520..0b25d4b5d 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -5,7 +5,7 @@ import os import shutil import tempfile -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from .. 
import _identifier from ..types.exceptions import SessionException @@ -44,7 +44,7 @@ class FileSessionManager(RepositorySessionManager, SessionRepository): def __init__( self, session_id: str, - storage_dir: Optional[str] = None, + storage_dir: str | None = None, **kwargs: Any, ): """Initialize FileSession with filesystem storage. @@ -108,7 +108,7 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> def _read_file(self, path: str) -> dict[str, Any]: """Read JSON file.""" try: - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: return cast(dict[str, Any], json.load(f)) except json.JSONDecodeError as e: raise SessionException(f"Invalid JSON in file {path}: {str(e)}") from e @@ -140,7 +140,7 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: return session - def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: + def read_session(self, session_id: str, **kwargs: Any) -> Session | None: """Read session data.""" session_file = os.path.join(self._get_session_path(session_id), "session.json") if not os.path.exists(session_file): @@ -169,7 +169,7 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A session_data = session_agent.to_dict() self._write_file(agent_file, session_data) - def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: + def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> SessionAgent | None: """Read agent data.""" agent_file = os.path.join(self._get_agent_path(session_id, agent_id), "agent.json") if not os.path.exists(agent_file): @@ -199,7 +199,7 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio session_dict = session_message.to_dict() self._write_file(message_file, session_dict) - def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]: + def read_message(self, 
session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> SessionMessage | None: """Read message data.""" message_path = self._get_message_path(session_id, agent_id, message_id) if not os.path.exists(message_path): @@ -220,7 +220,7 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio self._write_file(message_file, session_message.to_dict()) def list_messages( - self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any + self, session_id: str, agent_id: str, limit: int | None = None, offset: int = 0, **kwargs: Any ) -> list[SessionMessage]: """List messages for an agent with pagination.""" messages_dir = os.path.join(self._get_agent_path(session_id, agent_id), "messages") @@ -269,7 +269,7 @@ def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **k session_data = multi_agent.serialize_state() self._write_file(multi_agent_file, session_data) - def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: + def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> dict[str, Any] | None: """Read multi-agent state from filesystem.""" multi_agent_file = os.path.join(self._get_multi_agent_path(session_id, multi_agent_id), "multi_agent.json") if not os.path.exists(multi_agent_file): diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index a8ac099d9..d23c4a94f 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -1,7 +1,7 @@ """Repository session manager implementation.""" import logging -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ..agent.state import AgentState from ..tools._tool_helpers import generate_missing_tool_result_content @@ -57,7 +57,7 @@ def __init__( self.session = session # Keep track of the latest message of 
each agent in case we need to redact it. - self._latest_agent_message: dict[str, Optional[SessionMessage]] = {} + self._latest_agent_message: dict[str, SessionMessage | None] = {} def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None: """Append a message to the agent's session. diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 7d081cf09..e5713e5b7 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -2,7 +2,7 @@ import json import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast +from typing import TYPE_CHECKING, Any, cast import boto3 from botocore.config import Config as BotocoreConfig @@ -47,9 +47,9 @@ def __init__( session_id: str, bucket: str, prefix: str = "", - boto_session: Optional[boto3.Session] = None, - boto_client_config: Optional[BotocoreConfig] = None, - region_name: Optional[str] = None, + boto_session: boto3.Session | None = None, + boto_client_config: BotocoreConfig | None = None, + region_name: str | None = None, **kwargs: Any, ): """Initialize S3SessionManager with S3 storage. 
@@ -130,7 +130,7 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> agent_path = self._get_agent_path(session_id, agent_id) return f"{agent_path}messages/{MESSAGE_PREFIX}{message_id}.json" - def _read_s3_object(self, key: str) -> Optional[Dict[str, Any]]: + def _read_s3_object(self, key: str) -> dict[str, Any] | None: """Read JSON object from S3.""" try: response = self.client.get_object(Bucket=self.bucket, Key=key) @@ -144,7 +144,7 @@ def _read_s3_object(self, key: str) -> Optional[Dict[str, Any]]: except json.JSONDecodeError as e: raise SessionException(f"Invalid JSON in S3 object {key}: {e}") from e - def _write_s3_object(self, key: str, data: Dict[str, Any]) -> None: + def _write_s3_object(self, key: str, data: dict[str, Any]) -> None: """Write JSON object to S3.""" try: content = json.dumps(data, indent=2, ensure_ascii=False) @@ -171,7 +171,7 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: self._write_s3_object(session_key, session_dict) return session - def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: + def read_session(self, session_id: str, **kwargs: Any) -> Session | None: """Read session data from S3.""" session_key = f"{self._get_session_path(session_id)}session.json" session_data = self._read_s3_object(session_key) @@ -209,7 +209,7 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json" self._write_s3_object(agent_key, agent_dict) - def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: + def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> SessionAgent | None: """Read agent data from S3.""" agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json" agent_data = self._read_s3_object(agent_key) @@ -236,7 +236,7 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio message_key = 
self._get_message_path(session_id, agent_id, message_id) self._write_s3_object(message_key, message_dict) - def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]: + def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> SessionMessage | None: """Read message data from S3.""" message_key = self._get_message_path(session_id, agent_id, message_id) message_data = self._read_s3_object(message_key) @@ -257,8 +257,8 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio self._write_s3_object(message_key, session_message.to_dict()) def list_messages( - self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any - ) -> List[SessionMessage]: + self, session_id: str, agent_id: str, limit: int | None = None, offset: int = 0, **kwargs: Any + ) -> list[SessionMessage]: """List messages for an agent with pagination from S3.""" messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/" try: @@ -288,7 +288,7 @@ def list_messages( message_keys = message_keys[offset:] # Load only the required message objects - messages: List[SessionMessage] = [] + messages: list[SessionMessage] = [] for key in message_keys: message_data = self._read_s3_object(key) if message_data: @@ -312,7 +312,7 @@ def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **k session_data = multi_agent.serialize_state() self._write_s3_object(multi_agent_key, session_data) - def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: + def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> dict[str, Any] | None: """Read multi-agent state from S3.""" multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent_id)}multi_agent.json" return self._read_s3_object(multi_agent_key) diff --git a/src/strands/session/session_repository.py 
b/src/strands/session/session_repository.py index 3f5476bdf..0b6f2c705 100644 --- a/src/strands/session/session_repository.py +++ b/src/strands/session/session_repository.py @@ -1,7 +1,7 @@ """Session repository interface for agent session management.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ..types.session import Session, SessionAgent, SessionMessage @@ -17,7 +17,7 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: """Create a new Session.""" @abstractmethod - def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: + def read_session(self, session_id: str, **kwargs: Any) -> Session | None: """Read a Session.""" @abstractmethod @@ -25,7 +25,7 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A """Create a new Agent in a Session.""" @abstractmethod - def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: + def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> SessionAgent | None: """Read an Agent.""" @abstractmethod @@ -37,7 +37,7 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio """Create a new Message for the Agent.""" @abstractmethod - def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]: + def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> SessionMessage | None: """Read a Message.""" @abstractmethod @@ -49,7 +49,7 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio @abstractmethod def list_messages( - self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any + self, session_id: str, agent_id: str, limit: int | None = None, offset: int = 0, **kwargs: Any ) -> list[SessionMessage]: """List Messages from an Agent with pagination.""" @@ -57,7 
+57,7 @@ def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **k """Create a new MultiAgent state for the Session.""" raise NotImplementedError("MultiAgent is not implemented for this repository") - def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: + def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> dict[str, Any] | None: """Read the MultiAgent state for the Session.""" raise NotImplementedError("MultiAgent is not implemented for this repository") diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py index 8f3ee1ea1..163df803a 100644 --- a/src/strands/telemetry/metrics.py +++ b/src/strands/telemetry/metrics.py @@ -3,8 +3,9 @@ import logging import time import uuid +from collections.abc import Iterable from dataclasses import dataclass, field -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Optional import opentelemetry.metrics as metrics_api from opentelemetry.metrics import Counter, Histogram, Meter @@ -23,11 +24,11 @@ class Trace: def __init__( self, name: str, - parent_id: Optional[str] = None, - start_time: Optional[float] = None, - raw_name: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - message: Optional[Message] = None, + parent_id: str | None = None, + start_time: float | None = None, + raw_name: str | None = None, + metadata: dict[str, Any] | None = None, + message: Message | None = None, ) -> None: """Initialize a new trace. 
@@ -42,15 +43,15 @@ def __init__( """ self.id: str = str(uuid.uuid4()) self.name: str = name - self.raw_name: Optional[str] = raw_name - self.parent_id: Optional[str] = parent_id + self.raw_name: str | None = raw_name + self.parent_id: str | None = parent_id self.start_time: float = start_time if start_time is not None else time.time() - self.end_time: Optional[float] = None - self.children: List["Trace"] = [] - self.metadata: Dict[str, Any] = metadata or {} - self.message: Optional[Message] = message + self.end_time: float | None = None + self.children: list[Trace] = [] + self.metadata: dict[str, Any] = metadata or {} + self.message: Message | None = message - def end(self, end_time: Optional[float] = None) -> None: + def end(self, end_time: float | None = None) -> None: """Mark the trace as complete with the given or current timestamp. Args: @@ -67,7 +68,7 @@ def add_child(self, child: "Trace") -> None: """ self.children.append(child) - def duration(self) -> Optional[float]: + def duration(self) -> float | None: """Calculate the duration of this trace. Returns: @@ -83,7 +84,7 @@ def add_message(self, message: Message) -> None: """ self.message = message - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Convert the trace to a dictionary representation. Returns: @@ -127,7 +128,7 @@ def add_call( duration: float, success: bool, metrics_client: "MetricsClient", - attributes: Optional[Dict[str, Any]] = None, + attributes: dict[str, Any] | None = None, ) -> None: """Record a new tool call with its outcome. @@ -207,7 +208,7 @@ def _metrics_client(self) -> "MetricsClient": return MetricsClient() @property - def latest_agent_invocation(self) -> Optional[AgentInvocation]: + def latest_agent_invocation(self) -> AgentInvocation | None: """Get the most recent agent invocation. 
Returns: @@ -217,8 +218,8 @@ def latest_agent_invocation(self) -> Optional[AgentInvocation]: def start_cycle( self, - attributes: Dict[str, Any], - ) -> Tuple[float, Trace]: + attributes: dict[str, Any], + ) -> tuple[float, Trace]: """Start a new event loop cycle and create a trace for it. Args: @@ -243,7 +244,7 @@ def start_cycle( return start_time, cycle_trace - def end_cycle(self, start_time: float, cycle_trace: Trace, attributes: Optional[Dict[str, Any]] = None) -> None: + def end_cycle(self, start_time: float, cycle_trace: Trace, attributes: dict[str, Any] | None = None) -> None: """End the current event loop cycle and record its duration. Args: @@ -358,7 +359,7 @@ def update_metrics(self, metrics: Metrics) -> None: self._metrics_client.model_time_to_first_token.record(metrics["timeToFirstByteMs"]) self.accumulated_metrics["latencyMs"] += metrics["latencyMs"] - def get_summary(self) -> Dict[str, Any]: + def get_summary(self) -> dict[str, Any]: """Generate a comprehensive summary of all collected metrics. Returns: @@ -404,7 +405,7 @@ def get_summary(self) -> Dict[str, Any]: return summary -def _metrics_summary_to_lines(event_loop_metrics: EventLoopMetrics, allowed_names: Set[str]) -> Iterable[str]: +def _metrics_summary_to_lines(event_loop_metrics: EventLoopMetrics, allowed_names: set[str]) -> Iterable[str]: """Convert event loop metrics to a series of formatted text lines. Args: @@ -465,7 +466,7 @@ def _metrics_summary_to_lines(event_loop_metrics: EventLoopMetrics, allowed_name yield from _trace_to_lines(trace.to_dict(), allowed_names=allowed_names, indent=1) -def _trace_to_lines(trace: Dict, allowed_names: Set[str], indent: int) -> Iterable[str]: +def _trace_to_lines(trace: dict, allowed_names: set[str], indent: int) -> Iterable[str]: """Convert a trace to a series of formatted text lines. 
Args: @@ -497,7 +498,7 @@ def _trace_to_lines(trace: Dict, allowed_names: Set[str], indent: int) -> Iterab yield from _trace_to_lines(child, allowed_names, indent + 1) -def metrics_to_string(event_loop_metrics: EventLoopMetrics, allowed_names: Optional[Set[str]] = None) -> str: +def metrics_to_string(event_loop_metrics: EventLoopMetrics, allowed_names: set[str] | None = None) -> str: """Convert event loop metrics to a human-readable string representation. Args: diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index d16b37fc8..d73ea3c39 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -7,8 +7,9 @@ import json import logging import os +from collections.abc import Mapping from datetime import date, datetime, timezone -from typing import Any, Dict, Mapping, Optional, cast +from typing import Any, cast import opentelemetry.trace as trace_api from opentelemetry.instrumentation.threading import ThreadingInstrumentor @@ -89,7 +90,7 @@ class Tracer: def __init__(self) -> None: """Initialize the tracer.""" self.service_name = __name__ - self.tracer_provider: Optional[trace_api.TracerProvider] = None + self.tracer_provider: trace_api.TracerProvider | None = None self.tracer_provider = trace_api.get_tracer_provider() self.tracer = self.tracer_provider.get_tracer(self.service_name) ThreadingInstrumentor().instrument() @@ -112,8 +113,8 @@ def _parse_semconv_opt_in(self) -> set[str]: def _start_span( self, span_name: str, - parent_span: Optional[Span] = None, - attributes: Optional[Dict[str, AttributeValue]] = None, + parent_span: Span | None = None, + attributes: dict[str, AttributeValue] | None = None, span_kind: trace_api.SpanKind = trace_api.SpanKind.INTERNAL, ) -> Span: """Generic helper method to start a span with common attributes. 
@@ -145,7 +146,7 @@ def _start_span( return span - def _set_attributes(self, span: Span, attributes: Dict[str, AttributeValue]) -> None: + def _set_attributes(self, span: Span, attributes: dict[str, AttributeValue]) -> None: """Set attributes on a span, handling different value types appropriately. Args: @@ -159,7 +160,7 @@ def _set_attributes(self, span: Span, attributes: Dict[str, AttributeValue]) -> span.set_attribute(key, value) def _add_optional_usage_and_metrics_attributes( - self, attributes: Dict[str, AttributeValue], usage: Usage, metrics: Metrics + self, attributes: dict[str, AttributeValue], usage: Usage, metrics: Metrics ) -> None: """Add optional usage and metrics attributes if they have values. @@ -183,8 +184,8 @@ def _add_optional_usage_and_metrics_attributes( def _end_span( self, span: Span, - attributes: Optional[Dict[str, AttributeValue]] = None, - error: Optional[Exception] = None, + attributes: dict[str, AttributeValue] | None = None, + error: Exception | None = None, ) -> None: """Generic helper method to end a span. @@ -221,7 +222,7 @@ def _end_span( except Exception as e: logger.warning("error=<%s> | failed to force flush tracer provider", e) - def end_span_with_error(self, span: Span, error_message: str, exception: Optional[Exception] = None) -> None: + def end_span_with_error(self, span: Span, error_message: str, exception: Exception | None = None) -> None: """End a span with error status. Args: @@ -235,7 +236,7 @@ def end_span_with_error(self, span: Span, error_message: str, exception: Optiona error = exception or Exception(error_message) self._end_span(span, error=error) - def _add_event(self, span: Optional[Span], event_name: str, event_attributes: Attributes) -> None: + def _add_event(self, span: Span | None, event_name: str, event_attributes: Attributes) -> None: """Add an event with attributes to a span. 
Args: @@ -275,9 +276,9 @@ def _get_event_name_for_message(self, message: Message) -> str: def start_model_invoke_span( self, messages: Messages, - parent_span: Optional[Span] = None, - model_id: Optional[str] = None, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + parent_span: Span | None = None, + model_id: str | None = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, **kwargs: Any, ) -> Span: """Start a new span for a model invocation. @@ -292,7 +293,7 @@ def start_model_invoke_span( Returns: The created span, or None if tracing is not enabled. """ - attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation_name="chat") + attributes: dict[str, AttributeValue] = self._get_common_attributes(operation_name="chat") if custom_trace_attributes: attributes.update(custom_trace_attributes) @@ -315,7 +316,7 @@ def end_model_invoke_span( usage: Usage, metrics: Metrics, stop_reason: StopReason, - error: Optional[Exception] = None, + error: Exception | None = None, ) -> None: """End a model invocation span with results and metrics. @@ -327,7 +328,7 @@ def end_model_invoke_span( stop_reason (StopReason): The reason the model stopped generating. error: Optional exception if the model call failed. """ - attributes: Dict[str, AttributeValue] = { + attributes: dict[str, AttributeValue] = { "gen_ai.usage.prompt_tokens": usage["inputTokens"], "gen_ai.usage.input_tokens": usage["inputTokens"], "gen_ai.usage.completion_tokens": usage["outputTokens"], @@ -366,8 +367,8 @@ def end_model_invoke_span( def start_tool_call_span( self, tool: ToolUse, - parent_span: Optional[Span] = None, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + parent_span: Span | None = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, **kwargs: Any, ) -> Span: """Start a new span for a tool call. 
@@ -381,7 +382,7 @@ def start_tool_call_span( Returns: The created span, or None if tracing is not enabled. """ - attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation_name="execute_tool") + attributes: dict[str, AttributeValue] = self._get_common_attributes(operation_name="execute_tool") attributes.update( { "gen_ai.tool.name": tool["name"], @@ -432,9 +433,7 @@ def start_tool_call_span( return span - def end_tool_call_span( - self, span: Span, tool_result: Optional[ToolResult], error: Optional[Exception] = None - ) -> None: + def end_tool_call_span(self, span: Span, tool_result: ToolResult | None, error: Exception | None = None) -> None: """End a tool call span with results. Args: @@ -442,7 +441,7 @@ def end_tool_call_span( tool_result: The result from the tool execution. error: Optional exception if the tool call failed. """ - attributes: Dict[str, AttributeValue] = {} + attributes: dict[str, AttributeValue] = {} if tool_result is not None: status = tool_result.get("status") status_str = str(status) if status is not None else "" @@ -490,10 +489,10 @@ def start_event_loop_cycle_span( self, invocation_state: Any, messages: Messages, - parent_span: Optional[Span] = None, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + parent_span: Span | None = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, **kwargs: Any, - ) -> Optional[Span]: + ) -> Span | None: """Start a new span for an event loop cycle. 
Args: @@ -509,7 +508,7 @@ def start_event_loop_cycle_span( event_loop_cycle_id = str(invocation_state.get("event_loop_cycle_id")) parent_span = parent_span if parent_span else invocation_state.get("event_loop_parent_span") - attributes: Dict[str, AttributeValue] = { + attributes: dict[str, AttributeValue] = { "event_loop.cycle_id": event_loop_cycle_id, } @@ -532,8 +531,8 @@ def end_event_loop_cycle_span( self, span: Span, message: Message, - tool_result_message: Optional[Message] = None, - error: Optional[Exception] = None, + tool_result_message: Message | None = None, + error: Exception | None = None, ) -> None: """End an event loop cycle span with results. @@ -543,8 +542,8 @@ def end_event_loop_cycle_span( tool_result_message: Optional tool result message if a tool was called. error: Optional exception if the cycle failed. """ - attributes: Dict[str, AttributeValue] = {} - event_attributes: Dict[str, AttributeValue] = {"message": serialize(message["content"])} + attributes: dict[str, AttributeValue] = {} + event_attributes: dict[str, AttributeValue] = {"message": serialize(message["content"])} if tool_result_message: event_attributes["tool.result"] = serialize(tool_result_message["content"]) @@ -572,10 +571,10 @@ def start_agent_span( self, messages: Messages, agent_name: str, - model_id: Optional[str] = None, - tools: Optional[list] = None, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, - tools_config: Optional[dict] = None, + model_id: str | None = None, + tools: list | None = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, + tools_config: dict | None = None, **kwargs: Any, ) -> Span: """Start a new span for an agent invocation. @@ -592,7 +591,7 @@ def start_agent_span( Returns: The created span, or None if tracing is not enabled. 
""" - attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation_name="invoke_agent") + attributes: dict[str, AttributeValue] = self._get_common_attributes(operation_name="invoke_agent") attributes.update( { "gen_ai.agent.name": agent_name, @@ -630,8 +629,8 @@ def start_agent_span( def end_agent_span( self, span: Span, - response: Optional[AgentResult] = None, - error: Optional[Exception] = None, + response: AgentResult | None = None, + error: Exception | None = None, ) -> None: """End an agent span with results and metrics. @@ -640,7 +639,7 @@ def end_agent_span( response: The response from the agent. error: Any error that occurred. """ - attributes: Dict[str, AttributeValue] = {} + attributes: dict[str, AttributeValue] = {} if response: if self.use_latest_genai_conventions: @@ -702,11 +701,11 @@ def start_multiagent_span( self, task: MultiAgentInput, instance: str, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, ) -> Span: """Start a new span for swarm invocation.""" operation = f"invoke_{instance}" - attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation) + attributes: dict[str, AttributeValue] = self._get_common_attributes(operation) attributes.update( { "gen_ai.agent.name": instance, @@ -741,7 +740,7 @@ def start_multiagent_span( def end_swarm_span( self, span: Span, - result: Optional[str] = None, + result: str | None = None, ) -> None: """End a swarm span with results.""" if result: @@ -770,7 +769,7 @@ def end_swarm_span( def _get_common_attributes( self, operation_name: str, - ) -> Dict[str, AttributeValue]: + ) -> dict[str, AttributeValue]: """Returns a dictionary of common attributes based on the convention version used. 
Args: diff --git a/src/strands/tools/_caller.py b/src/strands/tools/_caller.py index bfec5886d..8ca6138fc 100644 --- a/src/strands/tools/_caller.py +++ b/src/strands/tools/_caller.py @@ -9,7 +9,8 @@ import json import random -from typing import TYPE_CHECKING, Any, Callable +from collections.abc import Callable +from typing import TYPE_CHECKING, Any from .._async import run_async from ..tools.executors._executor import ToolExecutor diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 8dc933f51..f64c17ee9 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -44,16 +44,13 @@ def my_tool(param1: str, param2: int = 42) -> dict: import functools import inspect import logging +from collections.abc import Callable from typing import ( Annotated, Any, - Callable, Generic, - Optional, ParamSpec, - Type, TypeVar, - Union, cast, get_args, get_origin, @@ -183,7 +180,7 @@ def _validate_signature(self) -> None: # Found the parameter, no need to check further break - def _create_input_model(self) -> Type[BaseModel]: + def _create_input_model(self) -> type[BaseModel]: """Create a Pydantic model from function signature for input validation. This method analyzes the function's signature, type hints, and docstring to create a Pydantic model that can @@ -463,7 +460,7 @@ def __init__( functools.update_wrapper(wrapper=self, wrapped=self._tool_func) - def __get__(self, instance: Any, obj_type: Optional[Type] = None) -> "DecoratedFunctionTool[P, R]": + def __get__(self, instance: Any, obj_type: type | None = None) -> "DecoratedFunctionTool[P, R]": """Descriptor protocol implementation for proper method binding. This method enables the decorated function to work correctly when used as a class method. @@ -666,20 +663,20 @@ def tool(__func: Callable[P, R]) -> DecoratedFunctionTool[P, R]: ... 
# Handle @decorator() @overload def tool( - description: Optional[str] = None, - inputSchema: Optional[JSONSchema] = None, - name: Optional[str] = None, + description: str | None = None, + inputSchema: JSONSchema | None = None, + name: str | None = None, context: bool | str = False, ) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ... # Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the # call site, but the actual implementation handles that and it's not representable via the type-system def tool( # type: ignore - func: Optional[Callable[P, R]] = None, - description: Optional[str] = None, - inputSchema: Optional[JSONSchema] = None, - name: Optional[str] = None, + func: Callable[P, R] | None = None, + description: str | None = None, + inputSchema: JSONSchema | None = None, + name: str | None = None, context: bool | str = False, -) -> Union[DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]]: +) -> DecoratedFunctionTool[P, R] | Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: """Decorator that transforms a Python function into a Strands tool. This decorator seamlessly enables a function to be called both as a regular Python function and as a Strands tool. 
diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 5d01c5d48..6d58c5c75 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -7,7 +7,8 @@ import abc import logging import time -from typing import TYPE_CHECKING, Any, AsyncGenerator, cast +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any, cast from opentelemetry import trace as trace_api diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py index 216eee379..7fa34eff0 100644 --- a/src/strands/tools/executors/concurrent.py +++ b/src/strands/tools/executors/concurrent.py @@ -1,7 +1,8 @@ """Concurrent tool executor implementation.""" import asyncio -from typing import TYPE_CHECKING, Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any from typing_extensions import override diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py index f78e60872..dc5b9a5d9 100644 --- a/src/strands/tools/executors/sequential.py +++ b/src/strands/tools/executors/sequential.py @@ -1,6 +1,7 @@ """Sequential tool executor implementation.""" -from typing import TYPE_CHECKING, Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any from typing_extensions import override diff --git a/src/strands/tools/loader.py b/src/strands/tools/loader.py index 6f745b728..2115cdee8 100644 --- a/src/strands/tools/loader.py +++ b/src/strands/tools/loader.py @@ -9,7 +9,7 @@ from pathlib import Path from posixpath import expanduser from types import ModuleType -from typing import List, cast +from typing import cast from ..types.tools import AgentTool from .decorator import DecoratedFunctionTool @@ -20,7 +20,7 @@ _TOOL_MODULE_PREFIX = "_strands_tool_" -def load_tool_from_string(tool_string: str) -> List[AgentTool]: +def 
load_tool_from_string(tool_string: str) -> list[AgentTool]: """Load tools follows strands supported input string formats. This function can load a tool based on a string in the following ways: @@ -42,7 +42,7 @@ def load_tool_from_string(tool_string: str) -> List[AgentTool]: return load_tools_from_module_path(tool_string) -def load_tools_from_file_path(tool_path: str) -> List[AgentTool]: +def load_tools_from_file_path(tool_path: str) -> list[AgentTool]: """Load module from specified path, and then load tools from that module. This function attempts to load the passed in path as a python module, and if it succeeds, @@ -116,7 +116,7 @@ def load_tools_from_module(module: ModuleType, module_name: str) -> list[AgentTo # Try and see if any of the attributes in the module are function-based tools decorated with @tool # This means that there may be more than one tool available in this module, so we load them all - function_tools: List[AgentTool] = [] + function_tools: list[AgentTool] = [] # Function tools will appear as attributes in the module for attr_name in dir(module): attr = getattr(module, attr_name) @@ -153,7 +153,7 @@ class ToolLoader: """Handles loading of tools from different sources.""" @staticmethod - def load_python_tools(tool_path: str, tool_name: str) -> List[AgentTool]: + def load_python_tools(tool_path: str, tool_name: str) -> list[AgentTool]: """DEPRECATED: Load a Python tool module and return all discovered function-based tools as a list. This method always returns a list of AgentTool (possibly length 1). 
It is the @@ -206,7 +206,7 @@ def load_python_tools(tool_path: str, tool_name: str) -> List[AgentTool]: spec.loader.exec_module(module) # Collect function-based tools decorated with @tool - function_tools: List[AgentTool] = [] + function_tools: list[AgentTool] = [] for attr_name in dir(module): attr = getattr(module, attr_name) if isinstance(attr, DecoratedFunctionTool): diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index c36811c17..1aff22a1e 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -14,10 +14,12 @@ import threading import uuid from asyncio import AbstractEventLoop +from collections.abc import Callable, Coroutine, Sequence from concurrent import futures from datetime import timedelta +from re import Pattern from types import TracebackType -from typing import Any, Callable, Coroutine, Dict, Optional, Pattern, Sequence, TypeVar, Union, cast +from typing import Any, TypeVar, cast import anyio from mcp import ClientSession, ListToolsResult @@ -71,7 +73,7 @@ class ToolFilters(TypedDict, total=False): rejected: list[_ToolMatcher] -MIME_TO_FORMAT: Dict[str, ImageFormat] = { +MIME_TO_FORMAT: dict[str, ImageFormat] = { "image/jpeg": "jpeg", "image/jpg": "jpeg", "image/png": "png", @@ -117,7 +119,7 @@ def __init__( startup_timeout: int = 30, tool_filters: ToolFilters | None = None, prefix: str | None = None, - elicitation_callback: Optional[ElicitationFnT] = None, + elicitation_callback: ElicitationFnT | None = None, ) -> None: """Initialize a new MCP Server connection. 
@@ -300,9 +302,7 @@ def remove_consumer(self, consumer_id: Any, **kwargs: Any) -> None: # MCP-specific methods - def stop( - self, exc_type: Optional[BaseException], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] - ) -> None: + def stop(self, exc_type: BaseException | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> None: """Signals the background thread to stop and waits for it to complete, ensuring proper cleanup of all resources. This method is defensive and can handle partial initialization states that may occur @@ -415,7 +415,7 @@ async def _list_tools_async() -> ListToolsResult: self._log_debug_with_thread("successfully adapted %d MCP tools", len(mcp_tools)) return PaginatedList[MCPAgentTool](mcp_tools, token=list_tools_response.nextCursor) - def list_prompts_sync(self, pagination_token: Optional[str] = None) -> ListPromptsResult: + def list_prompts_sync(self, pagination_token: str | None = None) -> ListPromptsResult: """Synchronously retrieves the list of available prompts from the MCP server. This method calls the asynchronous list_prompts method on the MCP session @@ -463,7 +463,7 @@ async def _get_prompt_async() -> GetPromptResult: return get_prompt_result - def list_resources_sync(self, pagination_token: Optional[str] = None) -> ListResourcesResult: + def list_resources_sync(self, pagination_token: str | None = None) -> ListResourcesResult: """Synchronously retrieves the list of available resources from the MCP server. This method calls the asynchronous list_resources method on the MCP session @@ -510,7 +510,7 @@ async def _read_resource_async() -> ReadResourceResult: return read_resource_result - def list_resource_templates_sync(self, pagination_token: Optional[str] = None) -> ListResourceTemplatesResult: + def list_resource_templates_sync(self, pagination_token: str | None = None) -> ListResourceTemplatesResult: """Synchronously retrieves the list of available resource templates from the MCP server. 
Resource templates define URI patterns that can be used to access resources dynamically. @@ -739,7 +739,7 @@ def _background_task(self) -> None: def _map_mcp_content_to_tool_result_content( self, content: MCPTextContent | MCPImageContent | MCPEmbeddedResource | Any, - ) -> Union[ToolResultContent, None]: + ) -> ToolResultContent | None: """Maps MCP content types to tool result content types. This method converts MCP-specific content types to the generic @@ -859,7 +859,7 @@ def _should_include_tool(self, tool: MCPAgentTool) -> bool: """Check if a tool should be included based on constructor filters.""" return self._should_include_tool_with_filters(tool, self._tool_filters) - def _should_include_tool_with_filters(self, tool: MCPAgentTool, filters: Optional[ToolFilters]) -> bool: + def _should_include_tool_with_filters(self, tool: MCPAgentTool, filters: ToolFilters | None) -> bool: """Check if a tool should be included based on provided filters.""" if not filters: return True diff --git a/src/strands/tools/mcp/mcp_instrumentation.py b/src/strands/tools/mcp/mcp_instrumentation.py index f8ab3bc80..d1750daa3 100644 --- a/src/strands/tools/mcp/mcp_instrumentation.py +++ b/src/strands/tools/mcp/mcp_instrumentation.py @@ -9,9 +9,10 @@ Related issue: https://github.com/modelcontextprotocol/modelcontextprotocol/issues/246 """ +from collections.abc import AsyncGenerator, Callable from contextlib import _AsyncGeneratorContextManager, asynccontextmanager from dataclasses import dataclass -from typing import Any, AsyncGenerator, Callable, Tuple +from typing import Any from mcp.shared.message import SessionMessage from mcp.types import JSONRPCMessage, JSONRPCRequest @@ -129,7 +130,7 @@ def transport_wrapper() -> Callable[ @asynccontextmanager async def traced_method( wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any - ) -> AsyncGenerator[Tuple[Any, Any], None]: + ) -> AsyncGenerator[tuple[Any, Any], None]: async with wrapped(*args, **kwargs) as result: try: 
read_stream, write_stream = result @@ -139,7 +140,7 @@ async def traced_method( return traced_method - def session_init_wrapper() -> Callable[[Any, Any, Tuple[Any, ...], dict[str, Any]], None]: + def session_init_wrapper() -> Callable[[Any, Any, tuple[Any, ...], dict[str, Any]], None]: """Create a wrapper for MCP session initialization. Wraps session message streams to enable bidirectional context flow. @@ -151,7 +152,7 @@ def session_init_wrapper() -> Callable[[Any, Any, Tuple[Any, ...], dict[str, Any """ def traced_method( - wrapped: Callable[..., Any], instance: Any, args: Tuple[Any, ...], kwargs: dict[str, Any] + wrapped: Callable[..., Any], instance: Any, args: tuple[Any, ...], kwargs: dict[str, Any] ) -> None: wrapped(*args, **kwargs) reader = getattr(instance, "_incoming_message_stream_reader", None) diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 2547aabcc..f9787a182 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -10,12 +10,13 @@ import sys import uuid import warnings +from collections.abc import Iterable, Sequence from importlib import import_module, util from os.path import expanduser from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Sequence +from typing import Any, cast -from typing_extensions import TypedDict, cast +from typing_extensions import TypedDict from .._async import run_async from ..experimental.tools import ToolProvider @@ -35,13 +36,13 @@ class ToolRegistry: def __init__(self) -> None: """Initialize the tool registry.""" - self.registry: Dict[str, AgentTool] = {} - self.dynamic_tools: Dict[str, AgentTool] = {} - self.tool_config: Optional[Dict[str, Any]] = None - self._tool_providers: List[ToolProvider] = [] + self.registry: dict[str, AgentTool] = {} + self.dynamic_tools: dict[str, AgentTool] = {} + self.tool_config: dict[str, Any] | None = None + self._tool_providers: list[ToolProvider] = [] self._registry_id = str(uuid.uuid4()) - def 
process_tools(self, tools: List[Any]) -> List[str]: + def process_tools(self, tools: list[Any]) -> list[str]: """Process tools list. Process list of tools that can contain local file path string, module import path string, @@ -186,7 +187,7 @@ def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None: logger.exception("tool_name=<%s> | failed to load tool", tool_name) raise ValueError(f"Failed to load tool {tool_name}: {exception_str}") from e - def get_all_tools_config(self) -> Dict[str, Any]: + def get_all_tools_config(self) -> dict[str, Any]: """Dynamically generate tool configuration by combining built-in and dynamic tools. Returns: @@ -305,7 +306,7 @@ def replace(self, new_tool: AgentTool) -> None: elif tool_name in self.dynamic_tools: del self.dynamic_tools[tool_name] - def get_tools_dirs(self) -> List[Path]: + def get_tools_dirs(self) -> list[Path]: """Get all tool directory paths. Returns: @@ -325,7 +326,7 @@ def get_tools_dirs(self) -> List[Path]: return tool_dirs - def discover_tool_modules(self) -> Dict[str, Path]: + def discover_tool_modules(self) -> dict[str, Path]: """Discover available tool modules in all tools directories. Returns: @@ -568,7 +569,7 @@ def get_all_tool_specs(self) -> list[ToolSpec]: A list of ToolSpecs. """ all_tools = self.get_all_tools_config() - tools: List[ToolSpec] = [tool_spec for tool_spec in all_tools.values()] + tools: list[ToolSpec] = [tool_spec for tool_spec in all_tools.values()] return tools def register_dynamic_tool(self, tool: AgentTool) -> None: @@ -645,7 +646,7 @@ class NewToolDict(TypedDict): spec: ToolSpec - def _update_tool_config(self, tool_config: Dict[str, Any], new_tool: NewToolDict) -> None: + def _update_tool_config(self, tool_config: dict[str, Any], new_tool: NewToolDict) -> None: """Update tool configuration with a new tool. 
Args: @@ -682,7 +683,7 @@ def _update_tool_config(self, tool_config: Dict[str, Any], new_tool: NewToolDict tool_config["tools"].append(new_tool_entry) logger.debug("tool_name=<%s> | added new tool", new_tool_name) - def _scan_module_for_tools(self, module: Any) -> List[AgentTool]: + def _scan_module_for_tools(self, module: Any) -> list[AgentTool]: """Scan a module for function-based tools. Args: @@ -691,7 +692,7 @@ def _scan_module_for_tools(self, module: Any) -> List[AgentTool]: Returns: List of FunctionTool instances found in the module. """ - tools: List[AgentTool] = [] + tools: list[AgentTool] = [] for name, obj in inspect.getmembers(module): if isinstance(obj, DecoratedFunctionTool): diff --git a/src/strands/tools/structured_output/_structured_output_context.py b/src/strands/tools/structured_output/_structured_output_context.py index f33a06915..2f8dd8ca0 100644 --- a/src/strands/tools/structured_output/_structured_output_context.py +++ b/src/strands/tools/structured_output/_structured_output_context.py @@ -1,7 +1,7 @@ """Context management for structured output in the event loop.""" import logging -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from pydantic import BaseModel @@ -17,20 +17,20 @@ class StructuredOutputContext: """Per-invocation context for structured output execution.""" - def __init__(self, structured_output_model: Type[BaseModel] | None = None): + def __init__(self, structured_output_model: type[BaseModel] | None = None): """Initialize a new structured output context. Args: structured_output_model: Optional Pydantic model type for structured output. 
""" self.results: dict[str, BaseModel] = {} - self.structured_output_model: Type[BaseModel] | None = structured_output_model + self.structured_output_model: type[BaseModel] | None = structured_output_model self.structured_output_tool: StructuredOutputTool | None = None self.forced_mode: bool = False self.force_attempted: bool = False self.tool_choice: ToolChoice | None = None self.stop_loop: bool = False - self.expected_tool_name: Optional[str] = None + self.expected_tool_name: str | None = None if structured_output_model: self.structured_output_tool = StructuredOutputTool(structured_output_model) @@ -91,7 +91,7 @@ def has_structured_output_tool(self, tool_uses: list[ToolUse]) -> bool: return False return any(tool_use.get("name") == self.expected_tool_name for tool_use in tool_uses) - def get_tool_spec(self) -> Optional[ToolSpec]: + def get_tool_spec(self) -> ToolSpec | None: """Get the tool specification for structured output. Returns: diff --git a/src/strands/tools/structured_output/structured_output_tool.py b/src/strands/tools/structured_output/structured_output_tool.py index 25173d048..fa20f526c 100644 --- a/src/strands/tools/structured_output/structured_output_tool.py +++ b/src/strands/tools/structured_output/structured_output_tool.py @@ -6,7 +6,7 @@ import logging from copy import deepcopy -from typing import TYPE_CHECKING, Any, Type +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ValidationError from typing_extensions import override @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) -_TOOL_SPEC_CACHE: dict[Type[BaseModel], ToolSpec] = {} +_TOOL_SPEC_CACHE: dict[type[BaseModel], ToolSpec] = {} if TYPE_CHECKING: from ._structured_output_context import StructuredOutputContext @@ -26,7 +26,7 @@ class StructuredOutputTool(AgentTool): """Tool implementation for structured output validation.""" - def __init__(self, structured_output_model: Type[BaseModel]) -> None: + def __init__(self, structured_output_model: type[BaseModel]) -> None: 
"""Initialize a structured output tool. Args: @@ -43,7 +43,7 @@ def __init__(self, structured_output_model: Type[BaseModel]) -> None: self._tool_name = self._tool_spec.get("name", "StructuredOutputTool") @classmethod - def _get_tool_spec(cls, structured_output_model: Type[BaseModel]) -> ToolSpec: + def _get_tool_spec(cls, structured_output_model: type[BaseModel]) -> ToolSpec: """Get a cached tool spec for the given output type. Args: @@ -84,7 +84,7 @@ def tool_type(self) -> str: return "structured_output" @property - def structured_output_model(self) -> Type[BaseModel]: + def structured_output_model(self) -> type[BaseModel]: """Get the Pydantic model type for this tool. Returns: diff --git a/src/strands/tools/structured_output/structured_output_utils.py b/src/strands/tools/structured_output/structured_output_utils.py index 093d67f7c..a78ec6195 100644 --- a/src/strands/tools/structured_output/structured_output_utils.py +++ b/src/strands/tools/structured_output/structured_output_utils.py @@ -1,13 +1,13 @@ """Tools for converting Pydantic models to Bedrock tools.""" -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Union from pydantic import BaseModel from ...types.tools import ToolSpec -def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: +def _flatten_schema(schema: dict[str, Any]) -> dict[str, Any]: """Flattens a JSON schema by removing $defs and resolving $ref references. Handles required vs optional fields properly. @@ -80,11 +80,11 @@ def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: def _process_property( - prop: Dict[str, Any], - defs: Dict[str, Any], + prop: dict[str, Any], + defs: dict[str, Any], is_required: bool = False, fully_expand: bool = True, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Process a property in a schema, resolving any references. 
Args: @@ -174,8 +174,8 @@ def _process_property( def _process_schema_object( - schema_obj: Dict[str, Any], defs: Dict[str, Any], fully_expand: bool = True -) -> Dict[str, Any]: + schema_obj: dict[str, Any], defs: dict[str, Any], fully_expand: bool = True +) -> dict[str, Any]: """Process a schema object, typically from $defs, to resolve all nested properties. Args: @@ -218,7 +218,7 @@ def _process_schema_object( return result -def _process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, Any]: +def _process_nested_dict(d: dict[str, Any], defs: dict[str, Any]) -> dict[str, Any]: """Recursively processes nested dictionaries and resolves $ref references. Args: @@ -228,7 +228,7 @@ def _process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, A Returns: Processed dictionary """ - result: Dict[str, Any] = {} + result: dict[str, Any] = {} # Handle direct reference if "$ref" in d: @@ -258,8 +258,8 @@ def _process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, A def convert_pydantic_to_tool_spec( - model: Type[BaseModel], - description: Optional[str] = None, + model: type[BaseModel], + description: str | None = None, ) -> ToolSpec: """Converts a Pydantic model to a tool description for the Amazon Bedrock Converse API. @@ -302,7 +302,7 @@ def convert_pydantic_to_tool_spec( ) -def _expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> None: +def _expand_nested_properties(schema: dict[str, Any], model: type[BaseModel]) -> None: """Expand the properties of nested models in the schema to include their full structure. This updates the schema in place. 
@@ -348,7 +348,7 @@ def _expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> schema["properties"][prop_name] = expanded_object -def _process_referenced_models(schema: Dict[str, Any], model: Type[BaseModel]) -> None: +def _process_referenced_models(schema: dict[str, Any], model: type[BaseModel]) -> None: """Process referenced models to ensure their docstrings are included. This updates the schema in place. @@ -388,7 +388,7 @@ def _process_referenced_models(schema: Dict[str, Any], model: Type[BaseModel]) - _process_properties(ref_def, field_type) -def _process_properties(schema_def: Dict[str, Any], model: Type[BaseModel]) -> None: +def _process_properties(schema_def: dict[str, Any], model: type[BaseModel]) -> None: """Process properties in a schema definition to add descriptions from field metadata. Args: diff --git a/src/strands/tools/watcher.py b/src/strands/tools/watcher.py index 44f2ed512..c7f50fccd 100644 --- a/src/strands/tools/watcher.py +++ b/src/strands/tools/watcher.py @@ -6,7 +6,7 @@ import logging from pathlib import Path -from typing import Any, Dict, Set +from typing import Any from watchdog.events import FileSystemEventHandler from watchdog.observers import Observer @@ -25,9 +25,9 @@ class ToolWatcher: # design pattern avoids conflicts when multiple tool registries are watching the same directories. _shared_observer = None - _watched_dirs: Set[str] = set() + _watched_dirs: set[str] = set() _observer_started = False - _registry_handlers: Dict[str, Dict[int, "ToolWatcher.ToolChangeHandler"]] = {} + _registry_handlers: dict[str, dict[int, "ToolWatcher.ToolChangeHandler"]] = {} def __init__(self, tool_registry: ToolRegistry) -> None: """Initialize a tool watcher for the given tool registry. diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index d64357cf8..0896d48e1 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -5,7 +5,8 @@ agent lifecycle. 
""" -from typing import TYPE_CHECKING, Any, Sequence, cast +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, cast from pydantic import BaseModel from typing_extensions import override diff --git a/src/strands/types/citations.py b/src/strands/types/citations.py index 623f6ddc7..2b3714ce1 100644 --- a/src/strands/types/citations.py +++ b/src/strands/types/citations.py @@ -3,7 +3,7 @@ These types are modeled after the Bedrock API. """ -from typing import List, Literal, Union +from typing import Literal from typing_extensions import TypedDict @@ -120,13 +120,13 @@ class WebLocation(TypedDict, total=False): WebLocationDict = dict[Literal["web"], WebLocation] # Union type for citation locations - tagged union format matching AWS Bedrock API -CitationLocation = Union[ - DocumentCharLocationDict, - DocumentPageLocationDict, - DocumentChunkLocationDict, - SearchResultLocationDict, - WebLocationDict, -] +CitationLocation = ( + DocumentCharLocationDict + | DocumentPageLocationDict + | DocumentChunkLocationDict + | SearchResultLocationDict + | WebLocationDict +) class CitationSourceContent(TypedDict, total=False): @@ -178,7 +178,7 @@ class Citation(TypedDict, total=False): """ location: CitationLocation - sourceContent: List[CitationSourceContent] + sourceContent: list[CitationSourceContent] title: str @@ -196,5 +196,5 @@ class CitationsContentBlock(TypedDict, total=False): citations. 
""" - citations: List[Citation] - content: List[CitationGeneratedContent] + citations: list[Citation] + content: list[CitationGeneratedContent] diff --git a/src/strands/types/collections.py b/src/strands/types/collections.py index df857ace0..28b4a1891 100644 --- a/src/strands/types/collections.py +++ b/src/strands/types/collections.py @@ -1,6 +1,6 @@ """Generic collection types for the Strands SDK.""" -from typing import Generic, List, Optional, TypeVar +from typing import Generic, TypeVar T = TypeVar("T") @@ -12,7 +12,7 @@ class PaginatedList(list, Generic[T]): so existing code that expects List[T] will continue to work. """ - def __init__(self, data: List[T], token: Optional[str] = None): + def __init__(self, data: list[T], token: str | None = None): """Initialize a PaginatedList with data and an optional pagination token. Args: diff --git a/src/strands/types/content.py b/src/strands/types/content.py index 4d0bbe412..d75dbb87f 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -6,7 +6,7 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Dict, List, Literal, Optional +from typing import Literal from typing_extensions import TypedDict @@ -23,7 +23,7 @@ class GuardContentText(TypedDict): text: The input text details to be evaluated by the guardrail. """ - qualifiers: List[Literal["grounding_source", "query", "guard_content"]] + qualifiers: list[Literal["grounding_source", "query", "guard_content"]] text: str @@ -45,7 +45,7 @@ class ReasoningTextBlock(TypedDict, total=False): text: The reasoning that the model used to return the output. 
""" - signature: Optional[str] + signature: str | None text: str @@ -120,7 +120,7 @@ class DeltaContent(TypedDict, total=False): """ text: str - toolUse: Dict[Literal["input"], str] + toolUse: dict[Literal["input"], str] class ContentBlockStartToolUse(TypedDict): @@ -142,7 +142,7 @@ class ContentBlockStart(TypedDict, total=False): toolUse: Information about a tool that the model is requesting to use. """ - toolUse: Optional[ContentBlockStartToolUse] + toolUse: ContentBlockStartToolUse | None class ContentBlockDelta(TypedDict): @@ -183,9 +183,9 @@ class Message(TypedDict): role: The role of the message sender. """ - content: List[ContentBlock] + content: list[ContentBlock] role: Role -Messages = List[Message] +Messages = list[Message] """A list of messages representing a conversation.""" diff --git a/src/strands/types/guardrails.py b/src/strands/types/guardrails.py index c15ba1bea..70a7aedd5 100644 --- a/src/strands/types/guardrails.py +++ b/src/strands/types/guardrails.py @@ -5,7 +5,7 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Dict, List, Literal, Optional +from typing import Literal from typing_extensions import TypedDict @@ -22,7 +22,7 @@ class GuardrailConfig(TypedDict, total=False): guardrailIdentifier: str guardrailVersion: str - streamProcessingMode: Optional[Literal["sync", "async"]] + streamProcessingMode: Literal["sync", "async"] | None trace: Literal["enabled", "disabled"] @@ -47,7 +47,7 @@ class TopicPolicy(TypedDict): topics: The topics in the assessment. """ - topics: List[Topic] + topics: list[Topic] class ContentFilter(TypedDict): @@ -71,7 +71,7 @@ class ContentPolicy(TypedDict): filters: List of content filters to apply. """ - filters: List[ContentFilter] + filters: list[ContentFilter] class CustomWord(TypedDict): @@ -108,8 +108,8 @@ class WordPolicy(TypedDict): managedWordLists: List of managed word lists to filter. 
""" - customWords: List[CustomWord] - managedWordLists: List[ManagedWord] + customWords: list[CustomWord] + managedWordLists: list[ManagedWord] class PIIEntity(TypedDict): @@ -182,8 +182,8 @@ class SensitiveInformationPolicy(TypedDict): regexes: The regex queries in the assessment. """ - piiEntities: List[PIIEntity] - regexes: List[Regex] + piiEntities: list[PIIEntity] + regexes: list[Regex] class ContextualGroundingFilter(TypedDict): @@ -209,7 +209,7 @@ class ContextualGroundingPolicy(TypedDict): filters: The filter details for the guardrails contextual grounding filter. """ - filters: List[ContextualGroundingFilter] + filters: list[ContextualGroundingFilter] class GuardrailAssessment(TypedDict): @@ -239,9 +239,9 @@ class GuardrailTrace(TypedDict): outputAssessments: Assessments of output content against guardrail policies, keyed by output identifier. """ - inputAssessment: Dict[str, GuardrailAssessment] - modelOutput: List[str] - outputAssessments: Dict[str, List[GuardrailAssessment]] + inputAssessment: dict[str, GuardrailAssessment] + modelOutput: list[str] + outputAssessments: dict[str, list[GuardrailAssessment]] class Trace(TypedDict): diff --git a/src/strands/types/media.py b/src/strands/types/media.py index 69cd60cf3..462d8af34 100644 --- a/src/strands/types/media.py +++ b/src/strands/types/media.py @@ -5,7 +5,7 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Literal, Optional +from typing import Literal from typing_extensions import TypedDict @@ -37,8 +37,8 @@ class DocumentContent(TypedDict, total=False): format: Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"] name: str source: DocumentSource - citations: Optional[CitationsConfig] - context: Optional[str] + citations: CitationsConfig | None + context: str | None ImageFormat = Literal["png", "jpeg", "gif", "webp"] diff --git a/src/strands/types/session.py b/src/strands/types/session.py index 
5da3dcde8..29453f4b7 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -5,7 +5,7 @@ from dataclasses import asdict, dataclass, field from datetime import datetime, timezone from enum import Enum -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ..interrupt import _InterruptState from .content import Message @@ -69,7 +69,7 @@ class SessionMessage: message: Message message_id: int - redact_message: Optional[Message] = None + redact_message: Message | None = None created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) diff --git a/src/strands/types/streaming.py b/src/strands/types/streaming.py index dcfd541a8..8ec2e8d7b 100644 --- a/src/strands/types/streaming.py +++ b/src/strands/types/streaming.py @@ -5,8 +5,6 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Optional, Union - from typing_extensions import TypedDict from .citations import CitationLocation @@ -34,7 +32,7 @@ class ContentBlockStartEvent(TypedDict, total=False): start: Information about the content block being started. """ - contentBlockIndex: Optional[int] + contentBlockIndex: int | None start: ContentBlockStart @@ -102,9 +100,9 @@ class ReasoningContentBlockDelta(TypedDict, total=False): text: The reasoning that the model used to return the output. """ - redactedContent: Optional[bytes] - signature: Optional[str] - text: Optional[str] + redactedContent: bytes | None + signature: str | None + text: str | None class ContentBlockDelta(TypedDict, total=False): @@ -131,7 +129,7 @@ class ContentBlockDeltaEvent(TypedDict, total=False): delta: The incremental content update for the content block. 
""" - contentBlockIndex: Optional[int] + contentBlockIndex: int | None delta: ContentBlockDelta @@ -143,7 +141,7 @@ class ContentBlockStopEvent(TypedDict, total=False): This is optional to accommodate different model providers. """ - contentBlockIndex: Optional[int] + contentBlockIndex: int | None class MessageStopEvent(TypedDict, total=False): @@ -154,7 +152,7 @@ class MessageStopEvent(TypedDict, total=False): stopReason: The reason why the model stopped generating content. """ - additionalModelResponseFields: Optional[Union[dict, list, int, float, str, bool, None]] + additionalModelResponseFields: dict | list | int | float | str | bool | None | None stopReason: StopReason @@ -168,7 +166,7 @@ class MetadataEvent(TypedDict, total=False): """ metrics: Metrics - trace: Optional[Trace] + trace: Trace | None usage: Usage @@ -203,8 +201,8 @@ class RedactContentEvent(TypedDict, total=False): """ - redactUserContentMessage: Optional[str] - redactAssistantContentMessage: Optional[str] + redactUserContentMessage: str | None + redactAssistantContentMessage: str | None class StreamEvent(TypedDict, total=False): diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 8f4dba6b1..6fc0d703c 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -7,8 +7,9 @@ import uuid from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator, Awaitable, Callable from dataclasses import dataclass -from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union +from typing import Any, Literal, Protocol from typing_extensions import NotRequired, TypedDict @@ -164,11 +165,7 @@ def _interrupt_id(self, name: str) -> str: ToolChoiceAnyDict = dict[Literal["any"], ToolChoiceAny] ToolChoiceToolDict = dict[Literal["tool"], ToolChoiceTool] -ToolChoice = Union[ - ToolChoiceAutoDict, - ToolChoiceAnyDict, - ToolChoiceToolDict, -] +ToolChoice = ToolChoiceAutoDict | ToolChoiceAnyDict | ToolChoiceToolDict """ Configuration for 
how the model should choose tools. @@ -201,12 +198,7 @@ class ToolFunc(Protocol): __name__: str - def __call__( - self, *args: Any, **kwargs: Any - ) -> Union[ - ToolResult, - Awaitable[ToolResult], - ]: + def __call__(self, *args: Any, **kwargs: Any) -> ToolResult | Awaitable[ToolResult]: """Function signature for Python decorated and module based tools. Returns: diff --git a/src/strands/types/traces.py b/src/strands/types/traces.py index af6188adb..c5c3aaa64 100644 --- a/src/strands/types/traces.py +++ b/src/strands/types/traces.py @@ -1,20 +1,20 @@ """Tracing type definitions for the SDK.""" -from typing import List, Mapping, Optional, Sequence, Union +from collections.abc import Mapping, Sequence -AttributeValue = Union[ - str, - bool, - float, - int, - List[str], - List[bool], - List[float], - List[int], - Sequence[str], - Sequence[bool], - Sequence[int], - Sequence[float], -] +AttributeValue = ( + str + | bool + | float + | int + | list[str] + | list[bool] + | list[float] + | list[int] + | Sequence[str] + | Sequence[bool] + | Sequence[int] + | Sequence[float] +) -Attributes = Optional[Mapping[str, AttributeValue]] +Attributes = Mapping[str, AttributeValue] | None diff --git a/tests/fixtures/mock_hook_provider.py b/tests/fixtures/mock_hook_provider.py index 091f44d06..cf17bb470 100644 --- a/tests/fixtures/mock_hook_provider.py +++ b/tests/fixtures/mock_hook_provider.py @@ -1,4 +1,5 @@ -from typing import Iterator, Literal, Tuple, Type +from collections.abc import Iterator +from typing import Literal from strands import Agent from strands.hooks import ( @@ -17,7 +18,7 @@ class MockHookProvider(HookProvider): - def __init__(self, event_types: list[Type] | Literal["all"]): + def __init__(self, event_types: list[type] | Literal["all"]): if event_types == "all": event_types = [ AgentInitializedEvent, @@ -37,7 +38,7 @@ def __init__(self, event_types: list[Type] | Literal["all"]): def event_types_received(self): return [type(event) for event in self.events_received] 
- def get_events(self) -> Tuple[int, Iterator[HookEvent]]: + def get_events(self) -> tuple[int, Iterator[HookEvent]]: return len(self.events_received), iter(self.events_received) def register_hooks(self, registry: HookRegistry) -> None: diff --git a/tests/fixtures/mock_multiagent_hook_provider.py b/tests/fixtures/mock_multiagent_hook_provider.py index 727d28a48..4d18297a2 100644 --- a/tests/fixtures/mock_multiagent_hook_provider.py +++ b/tests/fixtures/mock_multiagent_hook_provider.py @@ -1,4 +1,5 @@ -from typing import Iterator, Literal, Tuple, Type +from collections.abc import Iterator +from typing import Literal from strands.experimental.hooks.multiagent.events import ( AfterMultiAgentInvocationEvent, @@ -14,7 +15,7 @@ class MockMultiAgentHookProvider(HookProvider): - def __init__(self, event_types: list[Type] | Literal["all"]): + def __init__(self, event_types: list[type] | Literal["all"]): if event_types == "all": event_types = [ MultiAgentInitializedEvent, @@ -30,7 +31,7 @@ def __init__(self, event_types: list[Type] | Literal["all"]): def event_types_received(self): return [type(event) for event in self.events_received] - def get_events(self) -> Tuple[int, Iterator[HookEvent]]: + def get_events(self) -> tuple[int, Iterator[HookEvent]]: return len(self.events_received), iter(self.events_received) def register_hooks(self, registry: HookRegistry) -> None: diff --git a/tests/fixtures/mocked_model_provider.py b/tests/fixtures/mocked_model_provider.py index 24de958bc..f1c5cae77 100644 --- a/tests/fixtures/mocked_model_provider.py +++ b/tests/fixtures/mocked_model_provider.py @@ -1,5 +1,6 @@ import json -from typing import Any, AsyncGenerator, Iterable, Optional, Sequence, Type, TypedDict, TypeVar, Union +from collections.abc import AsyncGenerator, Iterable, Sequence +from typing import Any, TypedDict, TypeVar from pydantic import BaseModel @@ -25,7 +26,7 @@ class MockedModelProvider(Model): to stream mock responses as events. 
""" - def __init__(self, agent_responses: Sequence[Union[Message, RedactionMessage]]): + def __init__(self, agent_responses: Sequence[Message | RedactionMessage]): self.agent_responses = [*agent_responses] self.index = 0 @@ -33,7 +34,7 @@ def format_chunk(self, event: Any) -> StreamEvent: return event def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> Any: return None @@ -45,9 +46,9 @@ def update_config(self, **model_config: Any) -> None: async def structured_output( self, - output_model: Type[T], + output_model: type[T], prompt: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, **kwargs: Any, ) -> AsyncGenerator[Any, None]: pass @@ -55,9 +56,9 @@ async def structured_output( async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, - tool_choice: Optional[Any] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + tool_choice: Any | None = None, *, system_prompt_content=None, **kwargs: Any, @@ -68,7 +69,7 @@ async def stream( self.index += 1 - def map_agent_message_to_events(self, agent_message: Union[Message, RedactionMessage]) -> Iterable[dict[str, Any]]: + def map_agent_message_to_events(self, agent_message: Message | RedactionMessage) -> Iterable[dict[str, Any]]: stop_reason: StopReason = "end_turn" yield {"messageStart": {"role": "assistant"}} if agent_message.get("redactedAssistantContent"): diff --git a/tests/strands/agent/hooks/test_hook_registry.py b/tests/strands/agent/hooks/test_hook_registry.py index ad1415f22..12b5af42c 100644 --- a/tests/strands/agent/hooks/test_hook_registry.py +++ b/tests/strands/agent/hooks/test_hook_registry.py @@ -1,6 +1,5 @@ import unittest.mock from dataclasses import dataclass -from typing import List 
from unittest.mock import MagicMock, Mock import pytest @@ -139,7 +138,7 @@ async def test_invoke_callbacks_async_no_registered_callbacks(hook_registry, nor @pytest.mark.asyncio async def test_invoke_callbacks_async_after_event(hook_registry, after_event): """Test that invoke_callbacks_async calls callbacks in reverse order for after events.""" - call_order: List[str] = [] + call_order: list[str] = [] def callback1(_event): call_order.append("callback1") diff --git a/tests/strands/agent/test_agent_result.py b/tests/strands/agent/test_agent_result.py index 5d1f02089..1ec0a8407 100644 --- a/tests/strands/agent/test_agent_result.py +++ b/tests/strands/agent/test_agent_result.py @@ -1,5 +1,5 @@ import unittest.mock -from typing import Optional, cast +from typing import cast import pytest from pydantic import BaseModel @@ -150,7 +150,7 @@ class StructuredOutputModel(BaseModel): name: str value: int - optional_field: Optional[str] = None + optional_field: str | None = None def test__init__with_structured_output(mock_metrics, simple_message: Message): diff --git a/tests/strands/agent/test_agent_structured_output.py b/tests/strands/agent/test_agent_structured_output.py index b679faed0..7341c714e 100644 --- a/tests/strands/agent/test_agent_structured_output.py +++ b/tests/strands/agent/test_agent_structured_output.py @@ -1,6 +1,5 @@ """Tests for Agent structured output functionality.""" -from typing import Optional from unittest import mock from unittest.mock import Mock, patch @@ -28,7 +27,7 @@ class ProductModel(BaseModel): title: str price: float - description: Optional[str] = None + description: str | None = None @pytest.fixture diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py index 72ebf01c6..5d6d6869a 100644 --- a/tests/strands/models/test_sagemaker.py +++ b/tests/strands/models/test_sagemaker.py @@ -2,7 +2,7 @@ import json import unittest.mock -from typing import Any, Dict, List +from typing import Any import boto3 import 
pytest @@ -32,7 +32,7 @@ def sagemaker_client(boto_session): @pytest.fixture -def endpoint_config() -> Dict[str, Any]: +def endpoint_config() -> dict[str, Any]: """Default endpoint configuration for tests.""" return { "endpoint_name": "test-endpoint", @@ -42,7 +42,7 @@ def endpoint_config() -> Dict[str, Any]: @pytest.fixture -def payload_config() -> Dict[str, Any]: +def payload_config() -> dict[str, Any]: """Default payload configuration for tests.""" return { "max_tokens": 1024, @@ -64,7 +64,7 @@ def messages() -> Messages: @pytest.fixture -def tool_specs() -> List[ToolSpec]: +def tool_specs() -> list[ToolSpec]: """Sample tool specifications for testing.""" return [ { @@ -405,8 +405,8 @@ async def test_stream_with_partial_json(self, sagemaker_client, model, messages, # Mock the response from SageMaker with split JSON mock_response = { "Body": [ - {"PayloadPart": {"Bytes": '{"choices": [{"delta": {"content": "Paris is'.encode("utf-8")}}, - {"PayloadPart": {"Bytes": ' the capital of France."}, "finish_reason": "stop"}]}'.encode("utf-8")}}, + {"PayloadPart": {"Bytes": b'{"choices": [{"delta": {"content": "Paris is'}}, + {"PayloadPart": {"Bytes": b' the capital of France."}, "finish_reason": "stop"}]}'}}, ] } sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response @@ -444,8 +444,8 @@ async def test_tool_choice_not_supported_warns(self, sagemaker_client, model, me # Mock the response from SageMaker with split JSON mock_response = { "Body": [ - {"PayloadPart": {"Bytes": '{"choices": [{"delta": {"content": "Paris is'.encode("utf-8")}}, - {"PayloadPart": {"Bytes": ' the capital of France."}, "finish_reason": "stop"}]}'.encode("utf-8")}}, + {"PayloadPart": {"Bytes": b'{"choices": [{"delta": {"content": "Paris is'}}, + {"PayloadPart": {"Bytes": b' the capital of France."}, "finish_reason": "stop"}]}'}}, ] } sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response diff --git a/tests/strands/models/test_writer.py 
b/tests/strands/models/test_writer.py index 8cf64a39a..963904002 100644 --- a/tests/strands/models/test_writer.py +++ b/tests/strands/models/test_writer.py @@ -1,5 +1,5 @@ import unittest.mock -from typing import Any, List +from typing import Any import pytest @@ -266,7 +266,7 @@ def test_format_request_with_unsupported_type(model, content, content_type): class AsyncStreamWrapper: - def __init__(self, items: List[Any]): + def __init__(self, items: list[Any]): self.items = items def __aiter__(self): @@ -277,7 +277,7 @@ async def _generator(self): yield item -async def mock_streaming_response(items: List[Any]): +async def mock_streaming_response(items: list[Any]): return AsyncStreamWrapper(items) diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index 7e28be998..8e14c9adc 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -82,7 +82,7 @@ def test_create_session(file_manager, sample_session): assert os.path.exists(session_file) # Verify content - with open(session_file, "r") as f: + with open(session_file) as f: data = json.load(f) assert data["session_id"] == sample_session.session_id assert data["session_type"] == sample_session.session_type @@ -144,7 +144,7 @@ def test_create_agent(file_manager, sample_session, sample_agent): assert os.path.exists(agent_file) # Verify content - with open(agent_file, "r") as f: + with open(agent_file) as f: data = json.load(f) assert data["agent_id"] == sample_agent.agent_id assert data["state"] == sample_agent.state @@ -210,7 +210,7 @@ def test_create_message(file_manager, sample_session, sample_agent, sample_messa assert os.path.exists(message_path) # Verify content - with open(message_path, "r") as f: + with open(message_path) as f: data = json.load(f) assert data["message_id"] == sample_message.message_id @@ -439,7 +439,7 @@ def test_create_multi_agent(multi_agent_manager, sample_session, 
mock_multi_agen assert os.path.exists(multi_agent_file) # Verify content - with open(multi_agent_file, "r") as f: + with open(multi_agent_file) as f: data = json.load(f) assert data["id"] == mock_multi_agent.id assert data["state"] == mock_multi_agent.state diff --git a/tests/strands/tools/structured_output/test_structured_output_context.py b/tests/strands/tools/structured_output/test_structured_output_context.py index a7eb27ca5..0f1c7ffff 100644 --- a/tests/strands/tools/structured_output/test_structured_output_context.py +++ b/tests/strands/tools/structured_output/test_structured_output_context.py @@ -1,7 +1,5 @@ """Tests for StructuredOutputContext class.""" -from typing import Optional - from pydantic import BaseModel, Field from strands.tools.structured_output._structured_output_context import StructuredOutputContext @@ -13,7 +11,7 @@ class SampleModel(BaseModel): name: str = Field(..., description="Name field") age: int = Field(..., description="Age field", ge=0) - email: Optional[str] = Field(None, description="Optional email field") + email: str | None = Field(None, description="Optional email field") class AnotherSampleModel(BaseModel): diff --git a/tests/strands/tools/structured_output/test_structured_output_tool.py b/tests/strands/tools/structured_output/test_structured_output_tool.py index 66f1d465d..784a508bd 100644 --- a/tests/strands/tools/structured_output/test_structured_output_tool.py +++ b/tests/strands/tools/structured_output/test_structured_output_tool.py @@ -1,6 +1,5 @@ """Tests for StructuredOutputTool class.""" -from typing import List, Optional from unittest.mock import MagicMock import pytest @@ -23,8 +22,8 @@ class ComplexModel(BaseModel): title: str = Field(..., description="Title field") count: int = Field(..., ge=0, le=100, description="Count between 0 and 100") - tags: List[str] = Field(default_factory=list, description="List of tags") - metadata: Optional[dict] = Field(None, description="Optional metadata") + tags: list[str] = 
Field(default_factory=list, description="List of tags") + metadata: dict | None = Field(None, description="Optional metadata") class ValidationTestModel(BaseModel): diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index a2a4c6213..4757e5587 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -3,7 +3,8 @@ """ from asyncio import Queue -from typing import Annotated, Any, AsyncGenerator, Dict, List, Optional, Union +from collections.abc import AsyncGenerator +from typing import Annotated, Any from unittest.mock import MagicMock import pytest @@ -267,7 +268,7 @@ async def test_tool_with_optional_params(alist): """Test tool decorator with optional parameters.""" @strands.tool - def test_tool(required: str, optional: Optional[int] = None) -> str: + def test_tool(required: str, optional: int | None = None) -> str: """Test with optional param. Args: @@ -864,7 +865,7 @@ def int_return_tool(param: str) -> int: # Define tool with Union return type @strands.tool - def union_return_tool(param: str) -> Union[Dict[str, Any], str, None]: + def union_return_tool(param: str) -> dict[str, Any] | str | None: """Tool with Union return type. Args: @@ -936,7 +937,7 @@ async def test_complex_parameter_types(alist): """Test handling of complex parameter types like nested dictionaries.""" @strands.tool - def complex_type_tool(config: Dict[str, Any]) -> str: + def complex_type_tool(config: dict[str, Any]) -> str: """Tool with complex parameter type. Args: @@ -965,7 +966,7 @@ async def test_custom_tool_result_handling(alist): """Test that a function returning a properly formatted tool result dictionary is handled correctly.""" @strands.tool - def custom_result_tool(param: str) -> Dict[str, Any]: + def custom_result_tool(param: str) -> dict[str, Any]: """Tool that returns a custom tool result dictionary. 
Args: @@ -1079,11 +1080,11 @@ def validation_tool(str_param: str, int_param: int, bool_param: bool) -> str: @pytest.mark.asyncio async def test_tool_complex_validation_edge_cases(alist): """Test validation of complex schema edge cases.""" - from typing import Any, Dict, Union + from typing import Any # Define a tool with a complex anyOf type that could trigger edge case handling @strands.tool - def edge_case_tool(param: Union[Dict[str, Any], None]) -> str: + def edge_case_tool(param: dict[str, Any] | None) -> str: """Tool with complex anyOf structure. Args: @@ -1236,10 +1237,10 @@ def failing_tool(param: str) -> str: @pytest.mark.asyncio async def test_tool_with_complex_anyof_schema(alist): """Test handling of complex anyOf structures in the schema.""" - from typing import Any, Dict, List, Union + from typing import Any @strands.tool - def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None]) -> str: + def complex_schema_tool(union_param: list[int] | dict[str, Any] | str | None) -> str: """Tool with a complex Union type that creates anyOf in schema. 
Args: @@ -1680,7 +1681,7 @@ def test_tool_decorator_annotated_optional_type(): @strands.tool def optional_annotated_tool( - required: Annotated[str, "Required parameter"], optional: Annotated[Optional[str], "Optional parameter"] = None + required: Annotated[str, "Required parameter"], optional: Annotated[str | None, "Optional parameter"] = None ) -> str: """Tool with optional annotated parameter.""" return f"{required}, {optional}" @@ -1702,7 +1703,7 @@ def test_tool_decorator_annotated_complex_types(): @strands.tool def complex_annotated_tool( - tags: Annotated[List[str], "List of tag strings"], config: Annotated[Dict[str, Any], "Configuration dictionary"] + tags: Annotated[list[str], "List of tag strings"], config: Annotated[dict[str, Any], "Configuration dictionary"] ) -> str: """Tool with complex annotated types.""" return f"Tags: {len(tags)}, Config: {len(config)}" diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py index fe9b55334..72a53bfe6 100644 --- a/tests/strands/tools/test_structured_output.py +++ b/tests/strands/tools/test_structured_output.py @@ -1,4 +1,4 @@ -from typing import List, Literal, Optional +from typing import Literal, Optional import pytest from pydantic import BaseModel, Field @@ -27,7 +27,7 @@ class TwoUsersWithPlanet(BaseModel): """Two users model with planet.""" user1: UserWithPlanet = Field(description="The first user") - user2: Optional[UserWithPlanet] = Field(description="The second user", default=None) + user2: UserWithPlanet | None = Field(description="The second user", default=None) # Test model with list of same type fields @@ -250,8 +250,8 @@ class NodeWithCircularRef(BaseModel): def test_conversion_works_with_fields_that_are_not_marked_as_optional_but_have_a_default_value_which_makes_them_optional(): # noqa E501 class Family(BaseModel): - ages: List[str] = Field(default_factory=list) - names: List[str] = Field(default_factory=list) + ages: list[str] = 
Field(default_factory=list) + names: list[str] = Field(default_factory=list) converted_output = convert_pydantic_to_tool_spec(Family) expected_output = { @@ -281,8 +281,8 @@ class Family(BaseModel): def test_marks_fields_as_optional_for_model_w_fields_that_are_not_marked_as_optional_but_have_a_default_value_which_makes_them_optional(): # noqa E501 class Family(BaseModel): - ages: List[str] = Field(default_factory=list) - names: List[str] = Field(default_factory=list) + ages: list[str] = Field(default_factory=list) + names: list[str] = Field(default_factory=list) converted_output = convert_pydantic_to_tool_spec(Family) assert "null" in converted_output["inputSchema"]["json"]["properties"]["ages"]["type"] @@ -312,14 +312,14 @@ def test_convert_pydantic_with_items_refs(): """Test that no $refs exist after lists of different components.""" class Address(BaseModel): - postal_code: Optional[str] = None + postal_code: str | None = None class Person(BaseModel): """Complete person information.""" list_of_items: list[Address] - list_of_items_nullable: Optional[list[Address]] - list_of_item_or_nullable: list[Optional[Address]] + list_of_items_nullable: list[Address] | None + list_of_item_or_nullable: list[Address | None] tool_spec = convert_pydantic_to_tool_spec(Person) @@ -378,7 +378,7 @@ class Address(BaseModel): street: str city: str country: str - postal_code: Optional[str] = None + postal_code: str | None = None class Contact(BaseModel): address: Address diff --git a/tests_integ/mcp/echo_server.py b/tests_integ/mcp/echo_server.py index 151f913d6..8fa1fb2b2 100644 --- a/tests_integ/mcp/echo_server.py +++ b/tests_integ/mcp/echo_server.py @@ -84,9 +84,7 @@ def get_weather(location: Literal["New York", "London", "Tokyo"] = "New York"): resource=BlobResourceContents( uri="https://weather.api/data/london.json", mimeType="application/json", - blob=base64.b64encode( - '{"temperature": 18, "condition": "rainy", "humidity": 85}'.encode() - ).decode(), + 
blob=base64.b64encode(b'{"temperature": 18, "condition": "rainy", "humidity": 85}').decode(), ), ) ] diff --git a/tests_integ/mcp/test_mcp_client.py b/tests_integ/mcp/test_mcp_client.py index 5c3baeba8..298272df5 100644 --- a/tests_integ/mcp/test_mcp_client.py +++ b/tests_integ/mcp/test_mcp_client.py @@ -3,7 +3,7 @@ import os import threading import time -from typing import List, Literal +from typing import Literal import pytest from mcp import StdioServerParameters, stdio_client @@ -47,7 +47,7 @@ def generate_custom_image() -> MCPImageContent: encoded_image = base64.b64encode(image_file.read()) return MCPImageContent(type="image", data=encoded_image, mimeType="image/png") except Exception as e: - print("Error while generating custom image: {}".format(e)) + print(f"Error while generating custom image: {e}") # Prompts @mcp.prompt(description="A greeting prompt template") @@ -366,7 +366,7 @@ def test_mcp_client_embedded_resources_with_agent(): assert any(["72" in response_text, "partly cloudy" in response_text, "weather" in response_text]) -def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]: +def _messages_to_content_blocks(messages: list[Message]) -> list[ToolUse]: return [block["toolUse"] for message in messages for block in message["content"] if "toolUse" in block] diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index 75cc58f74..57614b97f 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -3,7 +3,7 @@ """ import os -from typing import Callable, Optional +from collections.abc import Callable import requests from pytest import mark @@ -26,7 +26,7 @@ def __init__( self, id: str, factory: Callable[[], Model], - environment_variable: Optional[str] = None, + environment_variable: str | None = None, ) -> None: self.id = id self.model_factory = factory diff --git a/tests_integ/test_function_tools.py b/tests_integ/test_function_tools.py index 835dccf5d..6c72bdddb 100644 --- 
a/tests_integ/test_function_tools.py +++ b/tests_integ/test_function_tools.py @@ -4,7 +4,6 @@ """ import logging -from typing import Optional from strands import Agent, tool @@ -25,7 +24,7 @@ def word_counter(text: str) -> str: @tool(name="count_chars", description="Count characters in text") -def count_chars(text: str, include_spaces: Optional[bool] = True) -> str: +def count_chars(text: str, include_spaces: bool | None = True) -> str: """ Count characters in text. diff --git a/tests_integ/test_multiagent_graph.py b/tests_integ/test_multiagent_graph.py index 08343a554..b80a0f82d 100644 --- a/tests_integ/test_multiagent_graph.py +++ b/tests_integ/test_multiagent_graph.py @@ -1,4 +1,5 @@ -from typing import Any, AsyncIterator +from collections.abc import AsyncIterator +from typing import Any from unittest.mock import patch from uuid import uuid4 diff --git a/tests_integ/test_structured_output_agent_loop.py b/tests_integ/test_structured_output_agent_loop.py index 390bd3cff..01d3c80b2 100644 --- a/tests_integ/test_structured_output_agent_loop.py +++ b/tests_integ/test_structured_output_agent_loop.py @@ -2,8 +2,6 @@ Comprehensive integration tests for structured output passed into the agent functionality. 
""" -from typing import List, Optional - import pytest from pydantic import BaseModel, Field, field_validator @@ -42,7 +40,7 @@ class Contact(BaseModel): """Contact information.""" email: str - phone: Optional[str] = None + phone: str | None = None preferred_method: str = "email" @@ -54,7 +52,7 @@ class Employee(BaseModel): department: str address: Address contact: Contact - skills: List[str] + skills: list[str] hire_date: str salary_range: str @@ -65,7 +63,7 @@ class ProductReview(BaseModel): product_name: str rating: int = Field(ge=1, le=5, description="Rating from 1-5 stars") sentiment: str = Field(pattern="^(positive|negative|neutral)$") - key_points: List[str] + key_points: list[str] would_recommend: bool @@ -84,7 +82,7 @@ class TaskList(BaseModel): """Task management structure.""" project_name: str - tasks: List[str] + tasks: list[str] priority: str = Field(pattern="^(high|medium|low)$") due_date: str estimated_hours: int @@ -102,7 +100,7 @@ class Company(BaseModel): name: str = Field(description="Company name") address: Address = Field(description="Company address") - employees: List[Person] = Field(description="list of persons") + employees: list[Person] = Field(description="list of persons") class Task(BaseModel): From c23090f013d475bd93f49ae290899fcda5538f53 Mon Sep 17 00:00:00 2001 From: Masashi Tomooka Date: Fri, 16 Jan 2026 01:41:15 +0900 Subject: [PATCH 063/279] fix(agent): extract text from citationsContent in AgentResult.__str__ (#1489) AgentResult.__str__ now correctly extracts text from citationsContent blocks. Previously, only plain text blocks were processed, causing citation responses to return empty strings when converted to str(). 
Co-authored-by: Claude Opus 4.5 --- src/strands/agent/agent_result.py | 11 ++++- tests/strands/agent/test_agent_result.py | 58 ++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index 2ab95e5b5..8f9241a67 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -49,8 +49,15 @@ def __str__(self) -> str: result = "" for item in content_array: - if isinstance(item, dict) and "text" in item: - result += item.get("text", "") + "\n" + if isinstance(item, dict): + if "text" in item: + result += item.get("text", "") + "\n" + elif "citationsContent" in item: + citations_block = item["citationsContent"] + if "content" in citations_block: + for content in citations_block["content"]: + if isinstance(content, dict) and "text" in content: + result += content.get("text", "") + "\n" if not result and self.structured_output: result = self.structured_output.model_dump_json() diff --git a/tests/strands/agent/test_agent_result.py b/tests/strands/agent/test_agent_result.py index 1ec0a8407..6e4c2c91a 100644 --- a/tests/strands/agent/test_agent_result.py +++ b/tests/strands/agent/test_agent_result.py @@ -225,3 +225,61 @@ def test__str__empty_message_with_structured_output(mock_metrics, empty_message: assert "example" in message_string assert "123" in message_string assert "optional" in message_string + + +@pytest.fixture +def citations_message(): + """Message with citationsContent block.""" + return { + "role": "assistant", + "content": [ + { + "citationsContent": { + "citations": [ + { + "title": "Source Document", + "location": {"document": {"pageNumber": 1}}, + "sourceContent": [{"text": "source text"}], + } + ], + "content": [{"text": "This is cited text from the document."}], + } + } + ], + } + + +@pytest.fixture +def mixed_text_and_citations_message(): + """Message with both plain text and citationsContent blocks.""" + return { + "role": 
"assistant", + "content": [ + {"text": "Introduction paragraph"}, + { + "citationsContent": { + "citations": [{"title": "Doc", "location": {}, "sourceContent": []}], + "content": [{"text": "Cited content here."}], + } + }, + {"text": "Conclusion paragraph"}, + ], + } + + +def test__str__with_citations_content(mock_metrics, citations_message: Message): + """Test that str() extracts text from citationsContent blocks.""" + result = AgentResult(stop_reason="end_turn", message=citations_message, metrics=mock_metrics, state={}) + + message_string = str(result) + assert message_string == "This is cited text from the document.\n" + + +def test__str__mixed_text_and_citations_content(mock_metrics, mixed_text_and_citations_message: Message): + """Test that str() works with both plain text and citationsContent blocks.""" + result = AgentResult( + stop_reason="end_turn", message=mixed_text_and_citations_message, metrics=mock_metrics, state={} + ) + + message_string = str(result) + assert message_string == "Introduction paragraph\nCited content here.\nConclusion paragraph\n" From dfe3ec75d7e414b27a13798edcb51edff1e82f21 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Thu, 15 Jan 2026 14:52:35 -0500 Subject: [PATCH 064/279] Expose input messages to BeforeInvocationEvent hook (#1474) * feat(hooks): expose input messages to BeforeInvocationEvent Add messages attribute to BeforeInvocationEvent to enable input-side guardrails for PII detection, content moderation, and prompt attack prevention. Hooks can now inspect and modify messages before they are added to the agent's conversation history. 
- Add writable messages attribute to BeforeInvocationEvent (None default) - Pass messages parameter from _run_loop() to BeforeInvocationEvent - Add unit tests for new messages attribute and writability - Add integration tests for message modification use case - Update docs/HOOKS.md with input guardrails documentation Resolves #8 * refactor: address review feedback - Remove detailed Input Guardrails section from docs/HOOKS.md - Simplify BeforeInvocationEvent docstring per review - Remove backward compatibility note from messages attribute - Remove no-op test for messages initialization * refactor: simplify test assertions per review Use concise equality comparison for BeforeInvocationEvent assertions instead of verbose instance checks and property assertions. * Use overwritten messages array for the agent * Fix mypy issue --------- Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> --- src/strands/agent/agent.py | 5 +- src/strands/hooks/events.py | 11 ++- tests/strands/agent/hooks/test_events.py | 38 +++++++++- tests/strands/agent/test_agent_hooks.py | 96 +++++++++++++++++++++++- 4 files changed, 143 insertions(+), 7 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index b58b55f24..7b9e9c914 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -637,7 +637,10 @@ async def _run_loop( Yields: Events from the event loop cycle. 
""" - await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self)) + before_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async( + BeforeInvocationEvent(agent=self, messages=messages) + ) + messages = before_invocation_event.messages if before_invocation_event.messages is not None else messages agent_result: AgentResult | None = None try: diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 340b6d3d2..8aa8a68d6 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from ..agent.agent_result import AgentResult -from ..types.content import Message +from ..types.content import Message, Messages from ..types.interrupt import _Interruptible from ..types.streaming import StopReason from ..types.tools import AgentTool, ToolResult, ToolUse @@ -43,9 +43,16 @@ class BeforeInvocationEvent(HookEvent): - Agent.__call__ - Agent.stream_async - Agent.structured_output + + Attributes: + messages: The input messages for this invocation. Can be modified by hooks + to redact or transform content before processing. 
""" - pass + messages: Messages | None = None + + def _can_write(self, name: str) -> bool: + return name == "messages" @dataclass diff --git a/tests/strands/agent/hooks/test_events.py b/tests/strands/agent/hooks/test_events.py index 9203478b2..83cb1af24 100644 --- a/tests/strands/agent/hooks/test_events.py +++ b/tests/strands/agent/hooks/test_events.py @@ -11,7 +11,7 @@ BeforeToolCallEvent, MessageAddedEvent, ) -from strands.types.content import Message +from strands.types.content import Message, Messages from strands.types.tools import ToolResult, ToolUse @@ -20,6 +20,11 @@ def agent(): return Mock() +@pytest.fixture +def sample_messages() -> Messages: + return [{"role": "user", "content": [{"text": "Hello, agent!"}]}] + + @pytest.fixture def tool(): tool = Mock() @@ -52,6 +57,11 @@ def start_request_event(agent): return BeforeInvocationEvent(agent=agent) +@pytest.fixture +def start_request_event_with_messages(agent, sample_messages): + return BeforeInvocationEvent(agent=agent, messages=sample_messages) + + @pytest.fixture def messaged_added_event(agent): return MessageAddedEvent(agent=agent, message=Mock()) @@ -159,3 +169,29 @@ def test_after_invocation_event_properties_not_writable(agent): with pytest.raises(AttributeError, match="Property agent is not writable"): event.agent = Mock() + + +def test_before_invocation_event_messages_default_none(agent): + """Test that BeforeInvocationEvent.messages defaults to None for backward compatibility.""" + event = BeforeInvocationEvent(agent=agent) + assert event.messages is None + + +def test_before_invocation_event_messages_writable(agent, sample_messages): + """Test that BeforeInvocationEvent.messages can be modified in-place for guardrail redaction.""" + event = BeforeInvocationEvent(agent=agent, messages=sample_messages) + + # Should be able to modify the messages list in-place + event.messages[0]["content"] = [{"text": "[REDACTED]"}] + assert event.messages[0]["content"] == [{"text": "[REDACTED]"}] + + # Should be 
able to reassign messages entirely + new_messages: Messages = [{"role": "user", "content": [{"text": "Different message"}]}] + event.messages = new_messages + assert event.messages == new_messages + + +def test_before_invocation_event_agent_not_writable(start_request_event_with_messages): + """Test that BeforeInvocationEvent.agent is not writable.""" + with pytest.raises(AttributeError, match="Property agent is not writable"): + start_request_event_with_messages.agent = Mock() diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 00b9d368a..be71b5fcf 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -160,7 +160,7 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u assert length == 12 - assert next(events) == BeforeInvocationEvent(agent=agent) + assert next(events) == BeforeInvocationEvent(agent=agent, messages=agent.messages[0:1]) assert next(events) == MessageAddedEvent( agent=agent, message=agent.messages[0], @@ -214,7 +214,11 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m """Verify that the correct hook events are emitted as part of stream_async.""" iterator = agent.stream_async("test message") await anext(iterator) - assert hook_provider.events_received == [BeforeInvocationEvent(agent=agent)] + + # Verify first event is BeforeInvocationEvent with messages + assert len(hook_provider.events_received) == 1 + assert hook_provider.events_received[0].messages is not None + assert hook_provider.events_received[0].messages[0]["role"] == "user" # iterate the rest result = None @@ -226,7 +230,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m assert length == 12 - assert next(events) == BeforeInvocationEvent(agent=agent) + assert next(events) == BeforeInvocationEvent(agent=agent, messages=agent.messages[0:1]) assert next(events) == MessageAddedEvent( agent=agent, 
message=agent.messages[0], @@ -596,3 +600,89 @@ async def handle_after_model_call(event: AfterModelCallEvent): # Should succeed after: custom retry + 2 throttle retries assert result.stop_reason == "end_turn" assert result.message["content"][0]["text"] == "Success after mixed retries" + + +def test_before_invocation_event_message_modification(): + """Test that hooks can modify messages in BeforeInvocationEvent for input guardrails.""" + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "I received your redacted message"}], + }, + ] + ) + + modified_content = None + + async def input_guardrail_hook(event: BeforeInvocationEvent): + """Simulates a guardrail that redacts sensitive content.""" + nonlocal modified_content + if event.messages is not None: + for message in event.messages: + if message.get("role") == "user": + content = message.get("content", []) + for block in content: + if "text" in block and "SECRET" in block["text"]: + # Redact sensitive content in-place + block["text"] = block["text"].replace("SECRET", "[REDACTED]") + modified_content = event.messages[0]["content"][0]["text"] + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeInvocationEvent, input_guardrail_hook) + + agent("My password is SECRET123") + + # Verify the message was modified before being processed + assert modified_content == "My password is [REDACTED]123" + # Verify the modified message was added to agent's conversation history + assert agent.messages[0]["content"][0]["text"] == "My password is [REDACTED]123" + + +def test_before_invocation_event_message_overwrite(): + """Test that hooks can overwrite messages in BeforeInvocationEvent.""" + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "I received your message message"}], + }, + ] + ) + + async def overwrite_input_hook(event: BeforeInvocationEvent): + event.messages = [{"role": "user", "content": [{"text": "GOODBYE"}]}] + + agent = 
Agent(model=mock_provider) + agent.hooks.add_callback(BeforeInvocationEvent, overwrite_input_hook) + + agent("HELLO") + + # Verify the message was overwritten to agent's conversation history + assert agent.messages[0]["content"][0]["text"] == "GOODBYE" + + +@pytest.mark.asyncio +async def test_before_invocation_event_messages_none_in_structured_output(agenerator): + """Test that BeforeInvocationEvent.messages is None when called from deprecated structured_output.""" + + class Person(BaseModel): + name: str + age: int + + mock_provider = MockedModelProvider([]) + mock_provider.structured_output = Mock(return_value=agenerator([{"output": Person(name="Test", age=30)}])) + + received_messages = "not_set" + + async def capture_messages_hook(event: BeforeInvocationEvent): + nonlocal received_messages + received_messages = event.messages + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeInvocationEvent, capture_messages_hook) + + await agent.structured_output_async(Person, "Test prompt") + + # structured_output_async uses deprecated path that doesn't pass messages + assert received_messages is None From 058c03a487e53e4162136376232cbe5724e7c90c Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 15 Jan 2026 14:59:47 -0500 Subject: [PATCH 065/279] interrupts - graph - hook based (#1478) --- src/strands/multiagent/graph.py | 135 +++++++++++-- tests/strands/multiagent/test_graph.py | 88 ++++++++- .../interrupts/multiagent/test_hook.py | 187 +++++++++++++++++- .../interrupts/multiagent/test_session.py | 98 ++++++++- 4 files changed, 475 insertions(+), 33 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 19504ad73..97435ad4a 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -35,11 +35,13 @@ MultiAgentInitializedEvent, ) from ..hooks import HookProvider, HookRegistry +from ..interrupt import Interrupt, _InterruptState from ..session import SessionManager from ..telemetry 
import get_tracer from ..types._events import ( MultiAgentHandoffEvent, MultiAgentNodeCancelEvent, + MultiAgentNodeInterruptEvent, MultiAgentNodeStartEvent, MultiAgentNodeStopEvent, MultiAgentNodeStreamEvent, @@ -64,10 +66,15 @@ class GraphState: status: Current execution status of the graph. completed_nodes: Set of nodes that have completed execution. failed_nodes: Set of nodes that failed during execution. + interrupted_nodes: Set of nodes that user interrupted during execution. execution_order: List of nodes in the order they were executed. task: The original input prompt/query provided to the graph execution. This represents the actual work to be performed by the graph as a whole. Entry point nodes receive this task as their input if they have no dependencies. + start_time: Timestamp when the current invocation started. + Resets on each invocation, even when resuming from interrupt. + execution_time: Execution time of current invocation in milliseconds. + Excludes time spent waiting for interrupt responses. 
""" # Task (with default empty string) @@ -77,6 +84,7 @@ class GraphState: status: Status = Status.PENDING completed_nodes: set["GraphNode"] = field(default_factory=set) failed_nodes: set["GraphNode"] = field(default_factory=set) + interrupted_nodes: set["GraphNode"] = field(default_factory=set) execution_order: list["GraphNode"] = field(default_factory=list) start_time: float = field(default_factory=time.time) @@ -109,7 +117,7 @@ def should_continue( # Check timeout (only if set) if execution_timeout is not None: - elapsed = time.time() - self.start_time + elapsed = self.execution_time / 1000 + time.time() - self.start_time if elapsed > execution_timeout: return False, f"Execution timed out: {execution_timeout}s" @@ -123,6 +131,7 @@ class GraphResult(MultiAgentResult): total_nodes: int = 0 completed_nodes: int = 0 failed_nodes: int = 0 + interrupted_nodes: int = 0 execution_order: list["GraphNode"] = field(default_factory=list) edges: list[tuple["GraphNode", "GraphNode"]] = field(default_factory=list) entry_points: list["GraphNode"] = field(default_factory=list) @@ -149,13 +158,7 @@ def should_traverse(self, state: GraphState) -> bool: @dataclass class GraphNode: - """Represents a node in the graph. 
- - The execution_status tracks the node's lifecycle within graph orchestration: - - PENDING: Node hasn't started executing yet - - EXECUTING: Node is currently running - - COMPLETED/FAILED: Node finished executing (regardless of result quality) - """ + """Represents a node in the graph.""" node_id: str executor: Agent | MultiAgentBase @@ -446,6 +449,7 @@ def __init__( self.node_timeout = node_timeout self.reset_on_revisit = reset_on_revisit self.state = GraphState() + self._interrupt_state = _InterruptState() self.tracer = get_tracer() self.trace_attributes: dict[str, AttributeValue] = self._parse_trace_attributes(trace_attributes) self.session_manager = session_manager @@ -520,6 +524,8 @@ async def stream_async( - multi_agent_node_stop: When a node stops execution - result: Final graph result """ + self._interrupt_state.resume(task) + if invocation_state is None: invocation_state = {} @@ -529,7 +535,7 @@ async def stream_async( # Initialize state start_time = time.time() - if not self._resume_from_session: + if not self._resume_from_session and not self._interrupt_state.activated: # Initialize state self.state = GraphState( status=Status.EXECUTING, @@ -545,6 +551,8 @@ async def stream_async( span = self.tracer.start_multiagent_span(task, "graph", custom_trace_attributes=self.trace_attributes) with trace_api.use_span(span, end_on_exit=True): + interrupts = [] + try: logger.debug( "max_node_executions=<%s>, execution_timeout=<%s>s, node_timeout=<%s>s | graph execution config", @@ -554,6 +562,9 @@ async def stream_async( ) async for event in self._execute_graph(invocation_state): + if isinstance(event, MultiAgentNodeInterruptEvent): + interrupts.extend(event.interrupts) + yield event.as_dict() # Set final status based on execution results @@ -565,7 +576,7 @@ async def stream_async( logger.debug("status=<%s> | graph execution completed", self.state.status) # Yield final result (consistent with Agent's AgentResultEvent format) - result = self._build_result() + result 
= self._build_result(interrupts) # Use the same event format as Agent for consistency yield MultiAgentResultEvent(result=result).as_dict() @@ -575,7 +586,7 @@ async def stream_async( self.state.status = Status.FAILED raise finally: - self.state.execution_time = round((time.time() - start_time) * 1000) + self.state.execution_time += round((time.time() - start_time) * 1000) await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self)) self._resume_from_session = False self._resume_next_nodes.clear() @@ -592,9 +603,41 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: # Validate Agent-specific constraints for each node _validate_node_executor(node.executor) + def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> MultiAgentNodeInterruptEvent: + """Activate the interrupt state. + + Args: + node: The interrupted node. + interrupts: The interrupts raised by the user. + + Returns: + MultiAgentNodeInterruptEvent + """ + logger.debug("node=<%s> | node interrupted", node.node_id) + + node.execution_status = Status.INTERRUPTED + + self.state.status = Status.INTERRUPTED + self.state.interrupted_nodes.add(node) + + self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts}) + self._interrupt_state.activate() + + return MultiAgentNodeInterruptEvent(node.node_id, interrupts) + async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: """Execute graph and yield TypedEvent objects.""" - ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points) + if self._interrupt_state.activated: + ready_nodes = [self.nodes[node_id] for node_id in self._interrupt_state.context["completed_nodes"]] + ready_nodes.extend(self.state.interrupted_nodes) + + self.state.interrupted_nodes.clear() + + elif self._resume_from_session: + ready_nodes = self._resume_next_nodes + + else: + ready_nodes = list(self.entry_points) while ready_nodes: # Check 
execution limits before continuing @@ -614,6 +657,14 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato async for event in self._execute_nodes_parallel(current_batch, invocation_state): yield event + if self.state.status == Status.INTERRUPTED: + self._interrupt_state.context["completed_nodes"] = [ + node.node_id for node in current_batch if node.execution_status == Status.COMPLETED + ] + return + + self._interrupt_state.deactivate() + # Find newly ready nodes after batch execution # We add all nodes in current batch as completed batch, # because a failure would throw exception and code would not make it here @@ -642,6 +693,9 @@ async def _execute_nodes_parallel( Uses a shared queue where each node's stream runs independently and pushes events as they occur, enabling true real-time event propagation without round-robin delays. """ + if self._interrupt_state.activated: + nodes = [node for node in nodes if node.execution_status == Status.INTERRUPTED] + event_queue: asyncio.Queue[Any | None | Exception] = asyncio.Queue() # Start all node streams as independent tasks @@ -798,12 +852,16 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) ) yield start_event - before_event, _ = await self.hooks.invoke_callbacks_async( + before_event, interrupts = await self.hooks.invoke_callbacks_async( BeforeNodeCallEvent(self, node.node_id, invocation_state) ) start_time = time.time() try: + if interrupts: + yield self._activate_interrupt(node, interrupts) + return + if before_event.cancel_node: cancel_message = ( before_event.cancel_node if isinstance(before_event.cancel_node, str) else "node cancelled by user" @@ -831,6 +889,13 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) if multi_agent_result is None: raise ValueError(f"Node '{node.node_id}' did not produce a result event") + if multi_agent_result.status == Status.INTERRUPTED: + raise NotImplementedError( + 
f"node_id=<{node.node_id}>, " + "issue= " + "| user raised interrupt from a multi agent node" + ) + node_result = NodeResult( result=multi_agent_result, execution_time=multi_agent_result.execution_time, @@ -855,12 +920,15 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) if agent_response is None: raise ValueError(f"Node '{node.node_id}' did not produce a result event") - # Check for interrupt (from main branch) if agent_response.stop_reason == "interrupt": node.executor.messages.pop() # remove interrupted tool use message node.executor._interrupt_state.deactivate() - raise RuntimeError("user raised interrupt from agent | interrupts are not yet supported in graphs") + raise NotImplementedError( + f"node_id=<{node.node_id}>, " + "issue= " + "| user raised interrupt from an agent node" + ) # Extract metrics with defaults response_metrics = getattr(agent_response, "metrics", None) @@ -1007,8 +1075,15 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: return node_input - def _build_result(self) -> GraphResult: - """Build graph result from current state.""" + def _build_result(self, interrupts: list[Interrupt]) -> GraphResult: + """Build graph result from current state. + + Args: + interrupts: List of interrupts collected during execution. + + Returns: + GraphResult with current state. 
+ """ return GraphResult( status=self.state.status, results=self.state.results, @@ -1019,9 +1094,11 @@ def _build_result(self) -> GraphResult: total_nodes=self.state.total_nodes, completed_nodes=len(self.state.completed_nodes), failed_nodes=len(self.state.failed_nodes), + interrupted_nodes=len(self.state.interrupted_nodes), execution_order=self.state.execution_order, edges=self.state.edges, entry_points=self.state.entry_points, + interrupts=interrupts, ) def serialize_state(self) -> dict[str, Any]: @@ -1034,10 +1111,14 @@ def serialize_state(self) -> dict[str, Any]: "status": self.state.status.value, "completed_nodes": [n.node_id for n in self.state.completed_nodes], "failed_nodes": [n.node_id for n in self.state.failed_nodes], + "interrupted_nodes": [n.node_id for n in self.state.interrupted_nodes], "node_results": {k: v.to_dict() for k, v in (self.state.results or {}).items()}, "next_nodes_to_execute": next_nodes, "current_task": self.state.task, "execution_order": [n.node_id for n in self.state.execution_order], + "_internal_state": { + "interrupt_state": self._interrupt_state.to_dict(), + }, } def deserialize_state(self, payload: dict[str, Any]) -> None: @@ -1053,6 +1134,10 @@ def deserialize_state(self, payload: dict[str, Any]) -> None: payload: Dictionary containing persisted state data including status, completed nodes, results, and next nodes to execute. 
""" + if "_internal_state" in payload: + internal_state = payload["_internal_state"] + self._interrupt_state = _InterruptState.from_dict(internal_state["interrupt_state"]) + if not payload.get("next_nodes_to_execute"): # Reset all nodes for node in self.nodes.values(): @@ -1099,10 +1184,20 @@ def _from_dict(self, payload: dict[str, Any]) -> None: self.state.failed_nodes = set( self.nodes[node_id] for node_id in (payload.get("failed_nodes") or []) if node_id in self.nodes ) + for node in self.state.failed_nodes: + node.execution_status = Status.FAILED - # Restore completed nodes from persisted data - completed_node_ids = payload.get("completed_nodes") or [] - self.state.completed_nodes = {self.nodes[node_id] for node_id in completed_node_ids if node_id in self.nodes} + self.state.interrupted_nodes = set( + self.nodes[node_id] for node_id in (payload.get("interrupted_nodes") or []) if node_id in self.nodes + ) + for node in self.state.interrupted_nodes: + node.execution_status = Status.INTERRUPTED + + self.state.completed_nodes = set( + self.nodes[node_id] for node_id in (payload.get("completed_nodes") or []) if node_id in self.nodes + ) + for node in self.state.completed_nodes: + node.execution_status = Status.COMPLETED # Execution order (only nodes that still exist) order_node_ids = payload.get("execution_order") or [] diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 4875d1bec..ab2d86e70 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -1,6 +1,6 @@ import asyncio import time -from unittest.mock import AsyncMock, MagicMock, Mock, call, patch +from unittest.mock import ANY, AsyncMock, MagicMock, Mock, call, patch import pytest @@ -9,6 +9,7 @@ from strands.experimental.hooks.multiagent import BeforeNodeCallEvent from strands.hooks import AgentInitializedEvent from strands.hooks.registry import HookProvider, HookRegistry +from strands.interrupt import Interrupt, 
_InterruptState from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult from strands.multiagent.graph import Graph, GraphBuilder, GraphEdge, GraphNode, GraphResult, GraphState, Status from strands.session.file_session_manager import FileSessionManager @@ -2004,6 +2005,9 @@ async def test_graph_persisted(mock_strands_tracer, mock_use_span): state = graph.serialize_state() assert state["type"] == "graph" assert state["id"] == "default_graph" + assert state["_internal_state"] == { + "interrupt_state": {"activated": False, "context": {}, "interrupts": {}}, + } assert "status" in state assert "completed_nodes" in state assert "node_results" in state @@ -2013,14 +2017,33 @@ async def test_graph_persisted(mock_strands_tracer, mock_use_span): "status": "executing", "completed_nodes": [], "failed_nodes": [], + "interrupted_nodes": [], "node_results": {}, "current_task": "persisted task", "execution_order": [], "next_nodes_to_execute": ["test_node"], + "_internal_state": { + "interrupt_state": { + "activated": False, + "context": {"a": 1}, + "interrupts": { + "i1": { + "id": "i1", + "name": "test_name", + "reason": "test_reason", + }, + }, + }, + }, } graph.deserialize_state(persisted_state) assert graph.state.task == "persisted task" + assert graph._interrupt_state == _InterruptState( + activated=False, + context={"a": 1}, + interrupts={"i1": Interrupt(id="i1", name="test_name", reason="test_reason")}, + ) # Execute graph to test persistence integration result = await graph.invoke_async("Test persistence") @@ -2068,3 +2091,66 @@ def cancel_callback(event): tru_status = graph.state.status exp_status = Status.FAILED assert tru_status == exp_status + + +def test_graph_interrupt_on_before_node_call_event(interrupt_hook): + agent = create_mock_agent("test_agent", "Task completed") + + builder = GraphBuilder() + builder.add_node(agent, "test_agent") + builder.set_hook_providers([interrupt_hook]) + graph = builder.build() + + multiagent_result = 
graph("Test task") + + first_execution_time = multiagent_result.execution_time + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_node_ids = [node.node_id for node in graph.state.interrupted_nodes] + exp_node_ids = ["test_agent"] + assert tru_node_ids == exp_node_ids + + tru_interrupts = multiagent_result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_name", + reason="test_reason", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "test_response", + }, + }, + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 1 + agent_result = multiagent_result.results["test_agent"] + + tru_message = agent_result.result.message["content"][0]["text"] + exp_message = "Task completed" + assert tru_message == exp_message + + assert multiagent_result.execution_time >= first_execution_time diff --git a/tests_integ/interrupts/multiagent/test_hook.py b/tests_integ/interrupts/multiagent/test_hook.py index be7682082..9350b3535 100644 --- a/tests_integ/interrupts/multiagent/test_hook.py +++ b/tests_integ/interrupts/multiagent/test_hook.py @@ -7,7 +7,7 @@ from strands.experimental.hooks.multiagent import BeforeNodeCallEvent from strands.hooks import HookProvider from strands.interrupt import Interrupt -from strands.multiagent import Swarm +from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import 
Status @@ -18,16 +18,34 @@ def register_hooks(self, registry): registry.add_callback(BeforeNodeCallEvent, self.interrupt) def interrupt(self, event): - if event.node_id == "info": + if event.node_id == "info" or event.node_id == "time": return - response = event.interrupt("test_interrupt", reason="need approval") + response = event.interrupt(f"{event.node_id}_interrupt", reason="need approval") if response != "APPROVE": event.cancel_node = "node rejected" return Hook() +@pytest.fixture +def day_tool(): + @tool(name="day_tool") + def func(): + return "monday" + + return func + + +@pytest.fixture +def time_tool(): + @tool(name="time_tool") + def func(): + return "12:01" + + return func + + @pytest.fixture def weather_tool(): @tool(name="weather_tool") @@ -38,13 +56,49 @@ def func(): @pytest.fixture -def swarm(interrupt_hook, weather_tool): - info_agent = Agent(name="info") - weather_agent = Agent(name="weather", tools=[weather_tool]) +def info_agent(): + return Agent(name="info") + +@pytest.fixture +def day_agent(day_tool): + return Agent(name="day", tools=[day_tool]) + + +@pytest.fixture +def time_agent(time_tool): + return Agent(name="time", tools=[time_tool]) + + +@pytest.fixture +def weather_agent(weather_tool): + return Agent(name="weather", tools=[weather_tool]) + + +@pytest.fixture +def swarm(interrupt_hook, info_agent, weather_agent): return Swarm([info_agent, weather_agent], hooks=[interrupt_hook]) +@pytest.fixture +def graph(interrupt_hook, info_agent, day_agent, time_agent, weather_agent): + builder = GraphBuilder() + + builder.add_node(info_agent, "info") + builder.add_node(day_agent, "day") + builder.add_node(time_agent, "time") + builder.add_node(weather_agent, "weather") + + builder.add_edge("info", "day") + builder.add_edge("info", "time") + builder.add_edge("info", "weather") + + builder.set_entry_point("info") + builder.set_hook_providers([interrupt_hook]) + + return builder.build() + + def test_swarm_interrupt(swarm): multiagent_result = 
swarm("What is the weather?") @@ -56,7 +110,7 @@ def test_swarm_interrupt(swarm): exp_interrupts = [ Interrupt( id=ANY, - name="test_interrupt", + name="weather_interrupt", reason="need approval", ), ] @@ -97,7 +151,7 @@ async def test_swarm_interrupt_reject(swarm): exp_interrupts = [ Interrupt( id=ANY, - name="test_interrupt", + name="weather_interrupt", reason="need approval", ), ] @@ -131,3 +185,120 @@ async def test_swarm_interrupt_reject(swarm): tru_node_id = multiagent_result.node_history[0].node_id exp_node_id = "info" assert tru_node_id == exp_node_id + + +def test_graph_interrupt(graph): + multiagent_result = graph("What is the day, time, and weather?") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_node_ids = sorted([node.node_id for node in graph.state.interrupted_nodes]) + exp_node_ids = ["day", "weather"] + assert tru_node_ids == exp_node_ids + + tru_interrupts = sorted(multiagent_result.interrupts, key=lambda interrupt: interrupt.name) + exp_interrupts = [ + Interrupt( + id=ANY, + name="day_interrupt", + reason="need approval", + ), + Interrupt( + id=ANY, + name="weather_interrupt", + reason="need approval", + ), + ] + assert tru_interrupts == exp_interrupts + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "APPROVE", + }, + } + for interrupt in multiagent_result.interrupts + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 4 + + day_message = 
json.dumps(multiagent_result.results["day"].result.message).lower() + time_message = json.dumps(multiagent_result.results["time"].result.message).lower() + weather_message = json.dumps(multiagent_result.results["weather"].result.message).lower() + assert "monday" in day_message + assert "12:01" in time_message + assert "sunny" in weather_message + + +@pytest.mark.asyncio +async def test_graph_interrupt_reject(graph): + multiagent_result = graph("What is the day, time, and weather?") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_interrupts = sorted(multiagent_result.interrupts, key=lambda interrupt: interrupt.name) + exp_interrupts = [ + Interrupt( + id=ANY, + name="day_interrupt", + reason="need approval", + ), + Interrupt( + id=ANY, + name="weather_interrupt", + reason="need approval", + ), + ] + assert tru_interrupts == exp_interrupts + + responses = [ + { + "interruptResponse": { + "interruptId": tru_interrupts[0].id, + "response": "APPROVE", + }, + }, + { + "interruptResponse": { + "interruptId": tru_interrupts[1].id, + "response": "REJECT", + }, + }, + ] + + try: + async for event in graph.stream_async(responses): + if event.get("type") == "multiagent_node_cancel": + tru_cancel_id = event["node_id"] + + except RuntimeError as e: + assert "node rejected" in str(e) + + exp_cancel_id = "weather" + assert tru_cancel_id == exp_cancel_id + + tru_state_status = graph.state.status + exp_state_status = Status.FAILED + assert tru_state_status == exp_state_status diff --git a/tests_integ/interrupts/multiagent/test_session.py b/tests_integ/interrupts/multiagent/test_session.py index d6e8cdbf8..bab4b428f 100644 --- a/tests_integ/interrupts/multiagent/test_session.py +++ b/tests_integ/interrupts/multiagent/test_session.py @@ -4,13 +4,30 @@ import 
pytest from strands import Agent, tool +from strands.experimental.hooks.multiagent import BeforeNodeCallEvent +from strands.hooks import HookProvider from strands.interrupt import Interrupt -from strands.multiagent import Swarm +from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import Status from strands.session import FileSessionManager from strands.types.tools import ToolContext +@pytest.fixture +def interrupt_hook(): + class Hook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeNodeCallEvent, self.interrupt) + + def interrupt(self, event): + if event.node_id == "time": + response = event.interrupt("test_interrupt", reason="need approval") + if response != "APPROVE": + event.cancel_node = "node rejected" + + return Hook() + + @pytest.fixture def weather_tool(): @tool(name="weather_tool", context=True) @@ -22,9 +39,12 @@ def func(tool_context: ToolContext) -> str: @pytest.fixture -def swarm(weather_tool): - weather_agent = Agent(name="weather", tools=[weather_tool]) - return Swarm([weather_agent]) +def time_tool(): + @tool(name="time_tool") + def func(): + return "12:01" + + return func def test_swarm_interrupt_session(weather_tool, tmpdir): @@ -75,3 +95,73 @@ def test_swarm_interrupt_session(weather_tool, tmpdir): summarizer_message = json.dumps(summarizer_result.result.message).lower() assert "sunny" in summarizer_message + + +def test_graph_interrupt_session(interrupt_hook, time_tool, tmpdir): + time_agent = Agent(name="time", tools=[time_tool]) + summarizer_agent = Agent(name="summarizer") + session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + + builder = GraphBuilder() + builder.add_node(time_agent, "time") + builder.add_node(summarizer_agent, "summarizer") + builder.add_edge("time", "summarizer") + builder.set_hook_providers([interrupt_hook]) + builder.set_session_manager(session_manager) + graph = builder.build() + + multiagent_result = graph("Can you 
check the time and then summarize the results?") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_interrupts = multiagent_result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_interrupt", + reason="need approval", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + time_agent = Agent(name="time", tools=[time_tool]) + summarizer_agent = Agent(name="summarizer") + session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + + builder = GraphBuilder() + builder.add_node(time_agent, "time") + builder.add_node(summarizer_agent, "summarizer") + builder.add_edge("time", "summarizer") + builder.set_hook_providers([interrupt_hook]) + builder.set_session_manager(session_manager) + graph = builder.build() + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "APPROVE", + }, + }, + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 2 + summarizer_message = json.dumps(multiagent_result.results["summarizer"].result.message).lower() + assert "12:01" in summarizer_message From bb3052b20f4534cb523219fb7146810227e2ea21 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Thu, 15 Jan 2026 15:34:15 -0500 Subject: [PATCH 066/279] fix: Swap sleeps with explicit signaling (#1497) So that unit tests are determistic Co-authored-by: Mackenzie Zastrow --- tests/strands/agent/test_agent.py | 
59 +++++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 15 deletions(-) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 81ce65989..eb039185c 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1,14 +1,13 @@ -import asyncio import copy import importlib import json import os import textwrap import threading -import time import unittest.mock import warnings -from typing import Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import Any from uuid import uuid4 import pytest @@ -193,11 +192,25 @@ class User(BaseModel): return User(name="Jane Doe", age=30, email="jane@doe.com") -class SlowMockedModel(MockedModelProvider): +class SyncEventMockedModel(MockedModelProvider): + """A mock model that uses events to synchronize concurrent threads. + + This model signals when it starts streaming and waits for a proceed signal, + allowing deterministic testing of concurrent behavior without relying on sleeps. 
+ """ + + def __init__(self, agent_responses): + super().__init__(agent_responses) + self.started_event = threading.Event() + self.proceed_event = threading.Event() + async def stream( self, messages, tool_specs=None, system_prompt=None, tool_choice=None, **kwargs ) -> AsyncGenerator[Any, None]: - await asyncio.sleep(0.15) # Add async delay to ensure concurrency + # Signal that streaming has started + self.started_event.set() + # Wait for signal to proceed + self.proceed_event.wait() async for event in super().stream(messages, tool_specs, system_prompt, tool_choice, **kwargs): yield event @@ -2212,7 +2225,7 @@ def test_agent_skips_fix_for_valid_conversation(mock_model, agenerator): def test_agent_concurrent_call_raises_exception(): """Test that concurrent __call__() calls raise ConcurrencyException.""" - model = SlowMockedModel( + model = SyncEventMockedModel( [ {"role": "assistant", "content": [{"text": "hello"}]}, {"role": "assistant", "content": [{"text": "world"}]}, @@ -2233,12 +2246,20 @@ def invoke(): with lock: errors.append(e) - # Create two threads that will try to invoke concurrently + # Start first thread and wait for it to begin streaming t1 = threading.Thread(target=invoke) - t2 = threading.Thread(target=invoke) - t1.start() + model.started_event.wait() # Wait until first thread is in the model.stream() + + # Start second thread while first is still running + t2 = threading.Thread(target=invoke) t2.start() + + # Give second thread time to attempt invocation and fail + t2.join(timeout=1.0) + + # Now let first thread complete + model.proceed_event.set() t1.join() t2.join() @@ -2254,11 +2275,12 @@ def test_agent_concurrent_structured_output_raises_exception(): Note: This test validates that the sync invocation path is protected. The concurrent __call__() test already validates the core functionality. 
""" - model = SlowMockedModel( + # Events for synchronization + model = SyncEventMockedModel( [ {"role": "assistant", "content": [{"text": "response1"}]}, {"role": "assistant", "content": [{"text": "response2"}]}, - ] + ], ) agent = Agent(model=model) @@ -2275,13 +2297,20 @@ def invoke(): with lock: errors.append(e) - # Create two threads that will try to invoke concurrently + # Start first thread and wait for it to begin streaming t1 = threading.Thread(target=invoke) - t2 = threading.Thread(target=invoke) - t1.start() - time.sleep(0.05) # Small delay to ensure first thread acquires lock + model.started_event.wait() # Wait until first thread is in the model.stream() + + # Start second thread while first is still running + t2 = threading.Thread(target=invoke) t2.start() + + # Give second thread time to attempt invocation and fail + t2.join(timeout=1.0) + + # Now let first thread complete + model.proceed_event.set() t1.join() t2.join() From 25c46a1011cf296342d9a5855f3bf631147601b5 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Thu, 15 Jan 2026 17:41:03 -0500 Subject: [PATCH 067/279] Fix PEP 563 incompatibility with @tool decorated tools (#1494) Fixes the incompatibility between strands-agents 1.16.0+ and Pydantic 2.12+ when tools use modules with from __future__ import annotations (PEP 563) which causes type annotations to be strings --------- Co-authored-by: strands-coder Co-authored-by: Mackenzie Zastrow --- src/strands/tools/decorator.py | 16 ++- tests/strands/tools/test_decorator_pep563.py | 142 +++++++++++++++++++ 2 files changed, 154 insertions(+), 4 deletions(-) create mode 100644 tests/strands/tools/test_decorator_pep563.py diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index f64c17ee9..f72a8ccf1 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -98,7 +98,7 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) - """ 
self.func = func self.signature = inspect.signature(func) - self.type_hints = get_type_hints(func) + self.type_hints = get_type_hints(func, include_extras=True) self._context_param = context_param self._validate_signature() @@ -198,9 +198,17 @@ def _create_input_model(self) -> type[BaseModel]: if self._is_special_parameter(name): continue - # Use param.annotation directly to get the raw type hint. Using get_type_hints() - # can cause inconsistent behavior across Python versions for complex Annotated types. - param_type = param.annotation + # Handle PEP 563 (from __future__ import annotations): + # - When PEP 563 is active, param.annotation is a string literal that needs resolution + # - When PEP 563 is not active, param.annotation is the actual type object (may include Annotated) + # We check if param.annotation is a string to determine if we need type hint resolution. + # This preserves Annotated metadata correctly in both cases and is consistent across Python versions. + if isinstance(param.annotation, str): + # PEP 563 active: resolve string annotation + param_type = self.type_hints.get(name, param.annotation) + else: + # PEP 563 not active: use the actual type object directly + param_type = param.annotation if param_type is inspect.Parameter.empty: param_type = Any default = ... if param.default is inspect.Parameter.empty else param.default diff --git a/tests/strands/tools/test_decorator_pep563.py b/tests/strands/tools/test_decorator_pep563.py new file mode 100644 index 000000000..07ec8f2ba --- /dev/null +++ b/tests/strands/tools/test_decorator_pep563.py @@ -0,0 +1,142 @@ +"""Tests for PEP 563 (from __future__ import annotations) compatibility. + +This module tests that the @tool decorator works correctly when modules use +`from __future__ import annotations` (PEP 563), which causes all annotations +to be stored as string literals rather than evaluated types. 
+ +This is a regression test for issue #1208: +https://github.com/strands-agents/sdk-python/issues/1208 +""" + +from __future__ import annotations + +from typing import Any + +import pytest +from typing_extensions import Literal, TypedDict + +from strands import tool + +# Define types at module level (simulating nova-act pattern) +CLICK_TYPE = Literal["left", "right", "middle", "double"] +EXTRA_TYPE = Literal["extra"] + + +class ClickOptions(TypedDict): + """Options for click operation.""" + + blur_field: bool | None + + +@tool +def simple_literal_tool(click_type: CLICK_TYPE) -> dict[str, Any]: + return {"status": "success", "content": [{"text": f"Clicked: {click_type}"}]} + + +@tool +def complex_literal_tool( + box: str, + extra: EXTRA_TYPE, + click_type: CLICK_TYPE | None = None, + click_options: ClickOptions | None = None, +) -> Any: + return "Done" + + +@tool +def union_literal_tool(mode: Literal["fast", "slow"] | None = None) -> str: + return f"Mode: {mode}" + + +def test_simple_literal_type_tool_spec(): + """Test that simple Literal type parameters work with __future__ annotations.""" + spec = simple_literal_tool.tool_spec + assert spec["name"] == "simple_literal_tool" + + schema = spec["inputSchema"]["json"] + assert "click_type" in schema["properties"] + # Verify Literal values are present in schema + click_type_schema = schema["properties"]["click_type"] + assert "enum" in click_type_schema or "anyOf" in click_type_schema + + +def test_complex_literal_type_tool_spec(): + """Test that complex type hints with Literal work with __future__ annotations.""" + spec = complex_literal_tool.tool_spec + assert spec["name"] == "complex_literal_tool" + + schema = spec["inputSchema"]["json"] + # Ensure schema is correct and contains the expected shape + assert schema == { + "$defs": { + "ClickOptions": { + "description": "Options for click operation.", + "properties": {"blur_field": {"anyOf": [{"type": "boolean"}, {"type": "null"}], "title": "Blur Field"}}, + 
"required": ["blur_field"], + "title": "ClickOptions", + "type": "object", + } + }, + "properties": { + "box": {"description": "Parameter box", "type": "string"}, + "click_options": { + "$ref": "#/$defs/ClickOptions", + "default": None, + "description": "Parameter click_options", + }, + "click_type": { + "default": None, + "description": "Parameter click_type", + "enum": ["left", "right", "middle", "double"], + "type": "string", + }, + "extra": {"const": "extra", "description": "Parameter extra", "type": "string"}, + }, + "required": ["box", "extra"], + "type": "object", + } + + +def test_union_literal_tool_spec(): + """Test that inline Literal in Union works with __future__ annotations.""" + spec = union_literal_tool.tool_spec + assert spec["name"] == "union_literal_tool" + + schema = spec["inputSchema"]["json"] + assert "mode" in schema["properties"] + + +def test_simple_literal_tool_invocation(): + """Test that tools with Literal types can be invoked.""" + result = simple_literal_tool(click_type="left") + assert result["status"] == "success" + assert "left" in result["content"][0]["text"] + + +def test_complex_literal_tool_invocation(): + """Test that tools with complex types can be invoked.""" + result = complex_literal_tool( + box="box1", + extra="extra", + click_type="double", + click_options={"blur_field": True}, + ) + assert result == "Done" + + +def test_tool_spec_no_pydantic_error(): + """Verify no PydanticUserError is raised when accessing tool_spec. + + This is the specific error from issue #1208: + PydanticUserError: `Agent_clickTool` is not fully defined; + you should define `EXTRA_TYPE`, then call `Agent_clickTool.model_rebuild()`. 
+ """ + # This should not raise PydanticUserError + try: + _ = simple_literal_tool.tool_spec + _ = complex_literal_tool.tool_spec + _ = union_literal_tool.tool_spec + except Exception as e: + if "not fully defined" in str(e): + pytest.fail(f"PydanticUserError raised - PEP 563 compatibility broken: {e}") + raise From 5e733ef00b5162ed97b662ad6a9ed4f9c72ced21 Mon Sep 17 00:00:00 2001 From: okamototk Date: Fri, 16 Jan 2026 06:40:39 -0800 Subject: [PATCH 068/279] feat: override service name by OTEL_SERVICE_NAME env (#1400) --- src/strands/telemetry/config.py | 6 +++++- tests/strands/telemetry/test_config.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/strands/telemetry/config.py b/src/strands/telemetry/config.py index 0509c7440..93225335d 100644 --- a/src/strands/telemetry/config.py +++ b/src/strands/telemetry/config.py @@ -5,6 +5,7 @@ """ import logging +import os from importlib.metadata import version from typing import Any @@ -29,9 +30,11 @@ def get_otel_resource() -> Resource: Returns: Resource object with standard service information. 
""" + service_name = os.getenv("OTEL_SERVICE_NAME", "strands-agents").strip() + resource = Resource.create( { - "service.name": "strands-agents", + "service.name": service_name, "service.version": version("strands-agents"), "telemetry.sdk.name": "opentelemetry", "telemetry.sdk.language": "python", @@ -56,6 +59,7 @@ class StrandsTelemetry: Environment variables are handled by the underlying OpenTelemetry SDK: - OTEL_EXPORTER_OTLP_ENDPOINT: OTLP endpoint URL - OTEL_EXPORTER_OTLP_HEADERS: Headers for OTLP requests + - OTEL_SERVICE_NAME: Overrides resource service name Examples: Quick setup with method chaining: diff --git a/tests/strands/telemetry/test_config.py b/tests/strands/telemetry/test_config.py index 658d4d08a..cc08c295c 100644 --- a/tests/strands/telemetry/test_config.py +++ b/tests/strands/telemetry/test_config.py @@ -2,6 +2,7 @@ import pytest +import strands.telemetry.config as telemetry_config from strands.telemetry import StrandsTelemetry @@ -212,3 +213,21 @@ def test_setup_otlp_exporter_exception(mock_resource, mock_tracer_provider, mock telemetry.setup_otlp_exporter() mock_otlp_exporter.assert_called_once() + + +def test_get_otel_resource_uses_default_service_name(monkeypatch): + monkeypatch.delenv("OTEL_SERVICE_NAME", raising=False) + monkeypatch.setattr(telemetry_config, "version", lambda _: "0.0.0") + + resource = telemetry_config.get_otel_resource() + + assert resource.attributes.get("service.name") == "strands-agents" + + +def test_get_otel_resource_respects_otel_service_name(monkeypatch): + monkeypatch.setenv("OTEL_SERVICE_NAME", "my-service") + monkeypatch.setattr(telemetry_config, "version", lambda _: "0.0.0") + + resource = telemetry_config.get_otel_resource() + + assert resource.attributes.get("service.name") == "my-service" From bce2464b4aaf6699eaa5fc1d0f78ac7cfcbc6e73 Mon Sep 17 00:00:00 2001 From: Strands Agent <217235299+strands-agent@users.noreply.github.com> Date: Fri, 16 Jan 2026 10:04:56 -0500 Subject: [PATCH 069/279] fix(bedrock): 
disable thinking mode when forcing tool_choice (#1495) --------- Co-authored-by: Dean Schmigelski --- src/strands/models/bedrock.py | 34 ++++++-- tests/strands/models/test_bedrock_thinking.py | 84 +++++++++++++++++++ tests_integ/models/test_model_bedrock.py | 37 ++++++++ 3 files changed, 150 insertions(+), 5 deletions(-) create mode 100644 tests/strands/models/test_bedrock_thinking.py diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index dfcd133c6..567a2e147 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -255,11 +255,7 @@ def _format_request( if tool_specs else {} ), - **( - {"additionalModelRequestFields": self.config["additional_request_fields"]} - if self.config.get("additional_request_fields") - else {} - ), + **(self._get_additional_request_fields(tool_choice)), **( {"additionalModelResponseFieldPaths": self.config["additional_response_field_paths"]} if self.config.get("additional_response_field_paths") @@ -298,6 +294,34 @@ def _format_request( ), } + def _get_additional_request_fields(self, tool_choice: ToolChoice | None) -> dict[str, Any]: + """Get additional request fields, removing thinking if tool_choice forces tool use. + + Bedrock's API does not allow thinking mode when tool_choice forces tool use. + When forcing a tool (e.g., for structured_output retry), we temporarily disable thinking. + + Args: + tool_choice: The tool choice configuration. + + Returns: + A dict containing additionalModelRequestFields if configured, or empty dict. 
+ """ + additional_fields = self.config.get("additional_request_fields") + if not additional_fields: + return {} + + # Check if tool_choice is forcing tool use ("any" or specific "tool") + is_forcing_tool = tool_choice is not None and ("any" in tool_choice or "tool" in tool_choice) + + if is_forcing_tool and "thinking" in additional_fields: + # Create a copy without the thinking key + fields_without_thinking = {k: v for k, v in additional_fields.items() if k != "thinking"} + if fields_without_thinking: + return {"additionalModelRequestFields": fields_without_thinking} + return {} + + return {"additionalModelRequestFields": additional_fields} + def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: """Format messages for Bedrock API compatibility. diff --git a/tests/strands/models/test_bedrock_thinking.py b/tests/strands/models/test_bedrock_thinking.py new file mode 100644 index 000000000..10b53cb03 --- /dev/null +++ b/tests/strands/models/test_bedrock_thinking.py @@ -0,0 +1,84 @@ +"""Tests for thinking mode behavior in BedrockModel.""" + +import pytest + +from strands.models.bedrock import BedrockModel + + +@pytest.fixture +def model_with_thinking(): + """Create a BedrockModel with thinking enabled.""" + return BedrockModel( + model_id="anthropic.claude-sonnet-4-20250514-v1:0", + additional_request_fields={"thinking": {"type": "enabled", "budget_tokens": 5000}}, + ) + + +@pytest.fixture +def model_without_thinking(): + """Create a BedrockModel without thinking.""" + return BedrockModel(model_id="anthropic.claude-sonnet-4-20250514-v1:0") + + +@pytest.fixture +def model_with_thinking_and_other_fields(): + """Create a BedrockModel with thinking and other additional fields.""" + return BedrockModel( + model_id="anthropic.claude-sonnet-4-20250514-v1:0", + additional_request_fields={ + "thinking": {"type": "enabled", "budget_tokens": 5000}, + "some_other_field": "value", + }, + ) + + +def 
test_thinking_removed_when_forcing_tool_any(model_with_thinking): + """Thinking should be removed when tool_choice forces tool use with 'any'.""" + tool_choice = {"any": {}} + result = model_with_thinking._get_additional_request_fields(tool_choice) + assert result == {} # thinking removed, no other fields + + +def test_thinking_removed_when_forcing_specific_tool(model_with_thinking): + """Thinking should be removed when tool_choice forces a specific tool.""" + tool_choice = {"tool": {"name": "structured_output_tool"}} + result = model_with_thinking._get_additional_request_fields(tool_choice) + assert result == {} # thinking removed, no other fields + + +def test_thinking_preserved_with_auto_tool_choice(model_with_thinking): + """Thinking should be preserved when tool_choice is 'auto'.""" + tool_choice = {"auto": {}} + result = model_with_thinking._get_additional_request_fields(tool_choice) + assert result == {"additionalModelRequestFields": {"thinking": {"type": "enabled", "budget_tokens": 5000}}} + + +def test_thinking_preserved_with_none_tool_choice(model_with_thinking): + """Thinking should be preserved when tool_choice is None.""" + result = model_with_thinking._get_additional_request_fields(None) + assert result == {"additionalModelRequestFields": {"thinking": {"type": "enabled", "budget_tokens": 5000}}} + + +def test_other_fields_preserved_when_thinking_removed(model_with_thinking_and_other_fields): + """Other additional fields should be preserved when thinking is removed.""" + tool_choice = {"any": {}} + result = model_with_thinking_and_other_fields._get_additional_request_fields(tool_choice) + assert result == {"additionalModelRequestFields": {"some_other_field": "value"}} + + +def test_no_fields_when_model_has_no_additional_fields(model_without_thinking): + """Should return empty dict when model has no additional_request_fields.""" + tool_choice = {"any": {}} + result = model_without_thinking._get_additional_request_fields(tool_choice) + assert result == 
{} + + +def test_fields_preserved_when_no_thinking_and_forcing_tool(): + """Additional fields without thinking should be preserved when forcing tool.""" + model = BedrockModel( + model_id="anthropic.claude-sonnet-4-20250514-v1:0", + additional_request_fields={"some_field": "value"}, + ) + tool_choice = {"any": {}} + result = model._get_additional_request_fields(tool_choice) + assert result == {"additionalModelRequestFields": {"some_field": "value"}} diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index b31f23663..0b3aa7b47 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -275,6 +275,43 @@ def test_redacted_content_handling(): assert isinstance(result.message["content"][0]["reasoningContent"]["redactedContent"], bytes) +def test_reasoning_content_in_messages_with_thinking_disabled(): + """Test that messages with reasoningContent are accepted when thinking is explicitly disabled.""" + # First, get a real reasoning response with thinking enabled + thinking_model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", + additional_request_fields={ + "thinking": { + "type": "enabled", + "budget_tokens": 1024, + } + }, + ) + agent_with_thinking = Agent(model=thinking_model) + result_with_thinking = agent_with_thinking("What is 2+2?") + + # Verify we got reasoning content + assert "reasoningContent" in result_with_thinking.message["content"][0] + + # Now create a model with thinking disabled and use the messages from the thinking session + disabled_model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", + additional_request_fields={ + "thinking": { + "type": "disabled", + } + }, + ) + + # Use the conversation history that includes reasoning content + messages = agent_with_thinking.messages + + agent_disabled = Agent(model=disabled_model, messages=messages) + result = agent_disabled("What about 3+3?") + + assert result.stop_reason 
== "end_turn" + + def test_multi_prompt_system_content(): """Test multi-prompt system content blocks.""" system_prompt_content = [ From e4bd3bc9d77b9bf40b11c42f92775b17fe0c618e Mon Sep 17 00:00:00 2001 From: Bryce Cole Date: Fri, 16 Jan 2026 13:48:51 -0500 Subject: [PATCH 070/279] fix: a2a use artifact update event (#1401) fix: update tests fix: simplify code by storing in class fix: remove uneeded code change fix: hide a2a artifact streaming under feature flag fix: use walrus operator fix: use star to signify end of unnamed fix: add check for walrus legacy fix: clarify enable_a2a_compliant_streaming parameter in StrandsA2AExecutor initialization fix: update tests refactor: streamline artifact addition logic in StrandsA2AExecutor --- src/strands/multiagent/a2a/executor.py | 84 ++++++++++++++--- src/strands/multiagent/a2a/server.py | 8 +- tests/strands/multiagent/a2a/test_executor.py | 93 +++++++++++++++++++ 3 files changed, 172 insertions(+), 13 deletions(-) diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index f02b8c6cc..58dfcc045 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -12,6 +12,8 @@ import json import logging import mimetypes +import uuid +import warnings from typing import Any, Literal from a2a.server.agent_execution import AgentExecutor, RequestContext @@ -49,13 +51,21 @@ class StrandsA2AExecutor(AgentExecutor): # Handle special cases where format differs from extension FORMAT_MAPPINGS = {"jpg": "jpeg", "htm": "html", "3gp": "three_gp", "3gpp": "three_gp", "3g2": "three_gp"} - def __init__(self, agent: SAAgent): + # A2A-compliant streaming mode + _current_artifact_id: str | None + _is_first_chunk: bool + + def __init__(self, agent: SAAgent, *, enable_a2a_compliant_streaming: bool = False): """Initialize a StrandsA2AExecutor. Args: agent: The Strands Agent instance to adapt to the A2A protocol. 
+ enable_a2a_compliant_streaming: If True, uses A2A-compliant streaming with + artifact updates. If False, uses legacy status updates streaming behavior + for backwards compatibility. Defaults to False. """ self.agent = agent + self.enable_a2a_compliant_streaming = enable_a2a_compliant_streaming async def execute( self, @@ -104,12 +114,30 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater else: raise ValueError("No content blocks available") + if not self.enable_a2a_compliant_streaming: + warnings.warn( + "The default A2A response stream implemented in the strands sdk does not conform to " + "what is expected in the A2A spec. Please set the `enable_a2a_compliant_streaming` " + "boolean to `True` on your `A2AServer` class to properly conform to the spec. " + "In the next major version release, this will be the default behavior.", + UserWarning, + stacklevel=3, + ) + + if self.enable_a2a_compliant_streaming: + self._current_artifact_id = str(uuid.uuid4()) + self._is_first_chunk = True + try: async for event in self.agent.stream_async(content_blocks): await self._handle_streaming_event(event, updater) except Exception: logger.exception("Error in streaming execution") raise + finally: + if self.enable_a2a_compliant_streaming: + self._current_artifact_id = None + self._is_first_chunk = True async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpdater) -> None: """Handle a single streaming event from the Strands Agent. 
@@ -125,28 +153,60 @@ async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpda logger.debug("Streaming event: %s", event) if "data" in event: if text_content := event["data"]: - await updater.update_status( - TaskState.working, - new_agent_text_message( - text_content, - updater.context_id, - updater.task_id, - ), - ) + if self.enable_a2a_compliant_streaming: + await updater.add_artifact( + [Part(root=TextPart(text=text_content))], + artifact_id=self._current_artifact_id, + name="agent_response", + append=not self._is_first_chunk, + ) + self._is_first_chunk = False + else: + # Legacy use update_status with agent message + await updater.update_status( + TaskState.working, + new_agent_text_message( + text_content, + updater.context_id, + updater.task_id, + ), + ) elif "result" in event: await self._handle_agent_result(event["result"], updater) async def _handle_agent_result(self, result: SAAgentResult | None, updater: TaskUpdater) -> None: """Handle the final result from the Strands Agent. - Processes the agent's final result, extracts text content from the response, - and adds it as an artifact to the task before marking the task as complete. + For A2A-compliant streaming: sends the final artifact chunk marker and marks + the task as complete. If no data chunks were previously sent, includes the + result content. + + For legacy streaming: adds the final result as a simple artifact without + artifact_id tracking. Args: result: The agent result object containing the final response, or None if no result. updater: The task updater for managing task state and adding the final artifact. 
""" - if final_content := str(result): + if self.enable_a2a_compliant_streaming: + if self._is_first_chunk: + final_content = str(result) if result else "" + parts = [Part(root=TextPart(text=final_content))] if final_content else [] + await updater.add_artifact( + parts, + artifact_id=self._current_artifact_id, + name="agent_response", + last_chunk=True, + ) + else: + await updater.add_artifact( + [], + artifact_id=self._current_artifact_id, + name="agent_response", + append=True, + last_chunk=True, + ) + elif final_content := str(result): await updater.add_artifact( [Part(root=TextPart(text=final_content))], name="agent_response", diff --git a/src/strands/multiagent/a2a/server.py b/src/strands/multiagent/a2a/server.py index a9093742f..7b4c4c73a 100644 --- a/src/strands/multiagent/a2a/server.py +++ b/src/strands/multiagent/a2a/server.py @@ -42,6 +42,7 @@ def __init__( queue_manager: QueueManager | None = None, push_config_store: PushNotificationConfigStore | None = None, push_sender: PushNotificationSender | None = None, + enable_a2a_compliant_streaming: bool = False, ): """Initialize an A2A-compatible server from a Strands agent. @@ -66,6 +67,9 @@ def __init__( no push notification configuration is used. push_sender: Custom push notification sender implementation. If None, no push notifications are sent. + enable_a2a_compliant_streaming: If True, uses A2A-compliant streaming with + artifact updates. If False, uses legacy status updates streaming behavior + for backwards compatibility. Defaults to False. 
""" self.host = host self.port = port @@ -90,7 +94,9 @@ def __init__( self.description = self.strands_agent.description self.capabilities = AgentCapabilities(streaming=True) self.request_handler = DefaultRequestHandler( - agent_executor=StrandsA2AExecutor(self.strands_agent), + agent_executor=StrandsA2AExecutor( + self.strands_agent, enable_a2a_compliant_streaming=enable_a2a_compliant_streaming + ), task_store=task_store or InMemoryTaskStore(), queue_manager=queue_manager, push_config_store=push_config_store, diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py index 1463d3f48..73ade574e 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -1020,3 +1020,96 @@ def test_default_formats_modularization(): assert executor._get_file_format_from_mime_type("", "document") == "txt" assert executor._get_file_format_from_mime_type("", "image") == "png" assert executor._get_file_format_from_mime_type("", "video") == "mp4" + + +# Tests for enable_a2a_compliant_streaming parameter + + +@pytest.mark.asyncio +async def test_legacy_mode_emits_deprecation_warning(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that legacy streaming (default) emits deprecation warning.""" + from a2a.types import TextPart + + executor = StrandsA2AExecutor(mock_strands_agent) # Default is False + + # Mock stream_async + async def mock_stream(content_blocks): + yield {"result": None} + + mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([])) + + # Mock task + mock_task = MagicMock() + mock_task.id = "test-task-id" + mock_task.context_id = "test-context-id" + mock_request_context.current_task = mock_task + + # Mock message + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = 
mock_message + + with pytest.warns(UserWarning, match="does not conform to what is expected in the A2A spec"): + await executor.execute(mock_request_context, mock_event_queue) + + +@pytest.mark.asyncio +async def test_a2a_compliant_mode_no_warning(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that A2A-compliant mode does not emit warning.""" + import warnings + + from a2a.types import TextPart + + executor = StrandsA2AExecutor(mock_strands_agent, enable_a2a_compliant_streaming=True) + + # Mock stream_async + async def mock_stream(content_blocks): + yield {"result": None} + + mock_strands_agent.stream_async = MagicMock(return_value=mock_stream([])) + + # Mock task + mock_task = MagicMock() + mock_task.id = "test-task-id" + mock_task.context_id = "test-context-id" + mock_request_context.current_task = mock_task + + # Mock message + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + with warnings.catch_warnings(): + warnings.simplefilter("error") + try: + await executor.execute(mock_request_context, mock_event_queue) + except UserWarning: + pytest.fail("Should not emit warning") + + +@pytest.mark.asyncio +async def test_a2a_compliant_mode_uses_add_artifact(mock_strands_agent): + """Test that A2A-compliant mode uses add_artifact with artifact_id.""" + executor = StrandsA2AExecutor(mock_strands_agent, enable_a2a_compliant_streaming=True) + executor._current_artifact_id = "artifact-123" + executor._is_first_chunk = True + + mock_updater = MagicMock() + mock_updater.add_artifact = AsyncMock() + mock_updater.update_status = AsyncMock() + + event = {"data": "content"} + await executor._handle_streaming_event(event, mock_updater) + + mock_updater.add_artifact.assert_called_once() + assert mock_updater.add_artifact.call_args[1]["artifact_id"] == 
"artifact-123" + assert mock_updater.add_artifact.call_args[1]["append"] is False + mock_updater.update_status.assert_not_called() From 51cbe7b6e9450f91cc8862120e69f6c1ac8bc96d Mon Sep 17 00:00:00 2001 From: Zezhen Xu <32421101+CrysisDeu@users.noreply.github.com> Date: Tue, 20 Jan 2026 07:03:37 -0800 Subject: [PATCH 071/279] Add parallel reading support to S3SessionManager.list_messages() (#1186) Co-authored-by: Jack Yuan Co-authored-by: Nicholas Clegg --- src/strands/session/s3_session_manager.py | 51 +++++++++++++++++-- tests/strands/models/test_bedrock.py | 9 ++-- .../session/test_s3_session_manager.py | 34 +++++++++++++ 3 files changed, 86 insertions(+), 8 deletions(-) diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index e5713e5b7..8d557e81c 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -2,6 +2,7 @@ import json import logging +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import TYPE_CHECKING, Any, cast import boto3 @@ -259,7 +260,21 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio def list_messages( self, session_id: str, agent_id: str, limit: int | None = None, offset: int = 0, **kwargs: Any ) -> list[SessionMessage]: - """List messages for an agent with pagination from S3.""" + """List messages for an agent with pagination from S3. + + Args: + session_id: ID of the session + agent_id: ID of the agent + limit: Optional limit on number of messages to return + offset: Optional offset for pagination + **kwargs: Additional keyword arguments + + Returns: + List of SessionMessage objects, sorted by message_id. + + Raises: + SessionException: If S3 error occurs during message retrieval. 
+ """ messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/" try: paginator = self.client.get_paginator("list_objects_v2") @@ -287,10 +302,38 @@ def list_messages( else: message_keys = message_keys[offset:] - # Load only the required message objects + # Load message objects in parallel for better performance messages: list[SessionMessage] = [] - for key in message_keys: - message_data = self._read_s3_object(key) + if not message_keys: + return messages + + # Optimize for single worker case - avoid thread pool overhead + if len(message_keys) == 1: + for key in message_keys: + message_data = self._read_s3_object(key) + if message_data: + messages.append(SessionMessage.from_dict(message_data)) + return messages + + with ThreadPoolExecutor() as executor: + # Submit all read tasks + future_to_key = {executor.submit(self._read_s3_object, key): key for key in message_keys} + + # Create a mapping from key to index to maintain order + key_to_index = {key: idx for idx, key in enumerate(message_keys)} + + # Initialize results list with None placeholders to maintain order + results: list[dict[str, Any] | None] = [None] * len(message_keys) + + # Process results as they complete + for future in as_completed(future_to_key): + key = future_to_key[future] + message_data = future.result() + # Store result at the correct index to maintain order + results[key_to_index[key]] = message_data + + # Convert results to SessionMessage objects, filtering out None values + for message_data in results: if message_data: messages.append(SessionMessage.from_dict(message_data)) diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 7697c5e03..833b14729 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -201,10 +201,11 @@ def test__init__region_precedence(mock_client_method, session_cls): def test__init__with_endpoint_url(mock_client_method): """Test that BedrockModel uses the provided 
endpoint_url for VPC endpoints.""" custom_endpoint = "https://vpce-12345-abcde.bedrock-runtime.us-west-2.vpce.amazonaws.com" - BedrockModel(endpoint_url=custom_endpoint) - mock_client_method.assert_called_with( - region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY, endpoint_url=custom_endpoint - ) + with unittest.mock.patch.object(os, "environ", {}): + BedrockModel(endpoint_url=custom_endpoint) + mock_client_method.assert_called_with( + region_name=DEFAULT_BEDROCK_REGION, config=ANY, service_name=ANY, endpoint_url=custom_endpoint + ) def test__init__with_region_and_session_raises_value_error(): diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index 719fbc2c9..c1c89da5b 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -282,6 +282,40 @@ def test_list_messages_all(s3_manager, sample_session, sample_agent): assert len(result) == 5 +def test_list_messages_single_message(s3_manager, sample_session, sample_agent): + """Test listing all messages from S3.""" + # Create session and agent + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # Create single message + message = SessionMessage( + { + "role": "user", + "content": [ContentBlock(text="Single Message")], + }, + 0, + ) + s3_manager.create_message(sample_session.session_id, sample_agent.agent_id, message) + + # List all messages + result = s3_manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + assert len(result) == 1 + + +def test_list_no_messages(s3_manager, sample_session, sample_agent): + """Test listing all messages from S3.""" + # Create session and agent + s3_manager.create_session(sample_session) + s3_manager.create_agent(sample_session.session_id, sample_agent) + + # List all messages + result = s3_manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + assert 
len(result) == 0 + + def test_list_messages_with_pagination(s3_manager, sample_session, sample_agent): """Test listing messages with pagination in S3.""" # Create session and agent From 8b7f6ccfd483c1120d22871b9fb4434d8783282c Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 20 Jan 2026 17:17:57 +0200 Subject: [PATCH 072/279] feat(steering): allow steering on AfterModelCallEvents (#1429) --- .gitignore | 1 + src/strands/experimental/steering/__init__.py | 7 +- .../experimental/steering/core/__init__.py | 4 +- .../experimental/steering/core/action.py | 51 ++-- .../experimental/steering/core/handler.py | 141 ++++++++-- .../steering/handlers/llm/llm_handler.py | 8 +- .../steering/core/test_handler.py | 245 ++++++++++++++++-- .../steering/handlers/llm/test_llm_handler.py | 12 +- tests/strands/tools/test_decorator_pep563.py | 4 +- tests_integ/steering/test_model_steering.py | 204 +++++++++++++++ ...t_llm_handler.py => test_tool_steering.py} | 10 +- 11 files changed, 597 insertions(+), 90 deletions(-) create mode 100644 tests_integ/steering/test_model_steering.py rename tests_integ/steering/{test_llm_handler.py => test_tool_steering.py} (91%) diff --git a/.gitignore b/.gitignore index 8b0fd989c..0b1375b50 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ repl_state .kiro uv.lock .audio_cache +CLAUDE.md diff --git a/src/strands/experimental/steering/__init__.py b/src/strands/experimental/steering/__init__.py index 4d0775873..be04a9ddb 100644 --- a/src/strands/experimental/steering/__init__.py +++ b/src/strands/experimental/steering/__init__.py @@ -9,7 +9,7 @@ - SteeringHandler: Base class for guidance logic with local context - SteeringContextCallback: Protocol for context update functions - SteeringContextProvider: Protocol for multi-event context providers -- SteeringAction: Proceed/Guide/Interrupt decisions +- ToolSteeringAction/ModelSteeringAction: Proceed/Guide/Interrupt decisions Usage: handler = LLMSteeringHandler(system_prompt="...") @@ -23,7 
+23,7 @@ LedgerBeforeToolCall, LedgerProvider, ) -from .core.action import Guide, Interrupt, Proceed, SteeringAction +from .core.action import Guide, Interrupt, ModelSteeringAction, Proceed, ToolSteeringAction from .core.context import SteeringContextCallback, SteeringContextProvider from .core.handler import SteeringHandler @@ -31,7 +31,8 @@ from .handlers.llm import LLMPromptMapper, LLMSteeringHandler __all__ = [ - "SteeringAction", + "ToolSteeringAction", + "ModelSteeringAction", "Proceed", "Guide", "Interrupt", diff --git a/src/strands/experimental/steering/core/__init__.py b/src/strands/experimental/steering/core/__init__.py index a3efe0dbc..cdd0d8269 100644 --- a/src/strands/experimental/steering/core/__init__.py +++ b/src/strands/experimental/steering/core/__init__.py @@ -1,6 +1,6 @@ """Core steering system interfaces and base classes.""" -from .action import Guide, Interrupt, Proceed, SteeringAction +from .action import Guide, Interrupt, ModelSteeringAction, Proceed, ToolSteeringAction from .handler import SteeringHandler -__all__ = ["SteeringAction", "Proceed", "Guide", "Interrupt", "SteeringHandler"] +__all__ = ["ToolSteeringAction", "ModelSteeringAction", "Proceed", "Guide", "Interrupt", "SteeringHandler"] diff --git a/src/strands/experimental/steering/core/action.py b/src/strands/experimental/steering/core/action.py index 8b4ec141d..b1f124b40 100644 --- a/src/strands/experimental/steering/core/action.py +++ b/src/strands/experimental/steering/core/action.py @@ -1,18 +1,18 @@ """SteeringAction types for steering evaluation results. -Defines structured outcomes from steering handlers that determine how tool calls +Defines structured outcomes from steering handlers that determine how agent actions should be handled. SteeringActions enable modular prompting by providing just-in-time feedback rather than front-loading all instructions in monolithic prompts. 
Flow: - SteeringHandler.steer() → SteeringAction → BeforeToolCallEvent handling - ↓ ↓ ↓ - Evaluate context Action type Tool execution modified + SteeringHandler.steer_*() → SteeringAction → Event handling + ↓ ↓ ↓ + Evaluate context Action type Execution modified SteeringAction types: - Proceed: Tool executes immediately (no intervention needed) - Guide: Tool cancelled, agent receives contextual feedback to explore alternatives - Interrupt: Tool execution paused for human input via interrupt system + Proceed: Allow execution to continue without intervention + Guide: Provide contextual guidance to redirect the agent + Interrupt: Pause execution for human input Extensibility: New action types can be added to the union. Always handle the default @@ -25,9 +25,9 @@ class Proceed(BaseModel): - """Allow tool to execute immediately without intervention. + """Allow execution to continue without intervention. - The tool call proceeds as planned. The reason provides context + The action proceeds as planned. The reason provides context for logging and debugging purposes. """ @@ -36,11 +36,11 @@ class Proceed(BaseModel): class Guide(BaseModel): - """Cancel tool and provide contextual feedback for agent to explore alternatives. + """Provide contextual guidance to redirect the agent. - The tool call is cancelled and the agent receives the reason as contextual - feedback to help them consider alternative approaches while maintaining - adaptive reasoning capabilities. + The agent receives the reason as contextual feedback to help guide + its behavior. The specific handling depends on the steering context + (e.g., tool call vs. model response). """ type: Literal["guide"] = "guide" @@ -48,18 +48,29 @@ class Guide(BaseModel): class Interrupt(BaseModel): - """Pause tool execution for human input via interrupt system. + """Pause execution for human input via interrupt system. 
- The tool call is paused and human input is requested through Strands' + Execution is paused and human input is requested through Strands' interrupt system. The human can approve or deny the operation, and their - decision determines whether the tool executes or is cancelled. + decision determines whether execution continues or is cancelled. """ type: Literal["interrupt"] = "interrupt" reason: str -# SteeringAction union - extensible for future action types -# IMPORTANT: Always handle the default case when pattern matching -# to maintain backward compatibility as new action types are added -SteeringAction = Annotated[Proceed | Guide | Interrupt, Field(discriminator="type")] +# Context-specific steering action types +ToolSteeringAction = Annotated[Proceed | Guide | Interrupt, Field(discriminator="type")] +"""Steering actions valid for tool steering (steer_before_tool). + +- Proceed: Allow tool execution to continue +- Guide: Cancel tool and provide feedback for alternative approaches +- Interrupt: Pause for human input before tool execution +""" + +ModelSteeringAction = Annotated[Proceed | Guide, Field(discriminator="type")] +"""Steering actions valid for model steering (steer_after_model). + +- Proceed: Accept model response without modification +- Guide: Discard model response and retry with guidance +""" diff --git a/src/strands/experimental/steering/core/handler.py b/src/strands/experimental/steering/core/handler.py index 4a0bcaa6a..fd00a27fc 100644 --- a/src/strands/experimental/steering/core/handler.py +++ b/src/strands/experimental/steering/core/handler.py @@ -2,38 +2,48 @@ Provides modular prompting through contextual guidance that appears when relevant, rather than front-loading all instructions. Handlers integrate with the Strands hook -system to intercept tool calls and provide just-in-time feedback based on local context. +system to intercept actions and provide just-in-time feedback based on local context. 
Architecture: - BeforeToolCallEvent → Context Callbacks → Update steering_context → steer() → SteeringAction - ↓ ↓ ↓ ↓ ↓ - Hook triggered Populate context Handler evaluates Handler decides Action taken + Hook Event → Context Callbacks → Update steering_context → steer_*() → SteeringAction + ↓ ↓ ↓ ↓ ↓ + Hook triggered Populate context Handler evaluates Handler decides Action taken Lifecycle: 1. Context callbacks update handler's steering_context on hook events - 2. BeforeToolCallEvent triggers steering evaluation via steer() method - 3. Handler accesses self.steering_context for guidance decisions - 4. SteeringAction determines tool execution: Proceed/Guide/Interrupt + 2. BeforeToolCallEvent triggers steer_before_tool() for tool steering + 3. AfterModelCallEvent triggers steer_after_model() for model steering + 4. Handler accesses self.steering_context for guidance decisions + 5. SteeringAction determines execution flow Implementation: - Subclass SteeringHandler and implement steer() method. - Pass context_callbacks in constructor to register context update functions. + Subclass SteeringHandler and override steer_before_tool() and/or steer_after_model(). + Both methods have default implementations that return Proceed, so you only need to + override the methods you want to customize. + Pass context_providers in constructor to register context update functions. Each handler maintains isolated steering_context that persists across calls. 
-SteeringAction handling: +SteeringAction handling for steer_before_tool: Proceed: Tool executes immediately Guide: Tool cancelled, agent receives contextual feedback to explore alternatives Interrupt: Tool execution paused for human input via interrupt system + +SteeringAction handling for steer_after_model: + Proceed: Model response accepted without modification + Guide: Discard model response and retry (message is dropped, model is called again) + Interrupt: Model response handling paused for human input via interrupt system """ import logging -from abc import ABC, abstractmethod +from abc import ABC from typing import TYPE_CHECKING, Any -from ....hooks.events import BeforeToolCallEvent +from ....hooks.events import AfterModelCallEvent, BeforeToolCallEvent from ....hooks.registry import HookProvider, HookRegistry +from ....types.content import Message +from ....types.streaming import StopReason from ....types.tools import ToolUse -from .action import Guide, Interrupt, Proceed, SteeringAction +from .action import Guide, Interrupt, ModelSteeringAction, Proceed, ToolSteeringAction from .context import SteeringContext, SteeringContextProvider if TYPE_CHECKING: @@ -73,24 +83,29 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: callback.event_type, lambda event, callback=callback: callback(event, self.steering_context) ) - # Register steering guidance - registry.add_callback(BeforeToolCallEvent, self._provide_steering_guidance) + # Register tool steering guidance + registry.add_callback(BeforeToolCallEvent, self._provide_tool_steering_guidance) + + # Register model steering guidance + registry.add_callback(AfterModelCallEvent, self._provide_model_steering_guidance) - async def _provide_steering_guidance(self, event: BeforeToolCallEvent) -> None: + async def _provide_tool_steering_guidance(self, event: BeforeToolCallEvent) -> None: """Provide steering guidance for tool call.""" tool_name = event.tool_use["name"] - logger.debug("tool_name=<%s> 
| providing steering guidance", tool_name) + logger.debug("tool_name=<%s> | providing tool steering guidance", tool_name) try: - action = await self.steer(event.agent, event.tool_use) + action = await self.steer_before_tool(agent=event.agent, tool_use=event.tool_use) except Exception as e: - logger.debug("tool_name=<%s>, error=<%s> | steering handler guidance failed", tool_name, e) + logger.debug("tool_name=<%s>, error=<%s> | tool steering handler guidance failed", tool_name, e) return - self._handle_steering_action(action, event, tool_name) + self._handle_tool_steering_action(action, event, tool_name) - def _handle_steering_action(self, action: SteeringAction, event: BeforeToolCallEvent, tool_name: str) -> None: - """Handle the steering action by modifying tool execution flow. + def _handle_tool_steering_action( + self, action: ToolSteeringAction, event: BeforeToolCallEvent, tool_name: str + ) -> None: + """Handle the steering action for tool calls by modifying tool execution flow. Proceed: Tool executes normally Guide: Tool cancelled with contextual feedback for agent to consider alternatives @@ -114,11 +129,52 @@ def _handle_steering_action(self, action: SteeringAction, event: BeforeToolCallE else: logger.debug("tool_name=<%s> | tool call approved manually", tool_name) else: - raise ValueError(f"Unknown steering action type: {action}") + raise ValueError(f"Unknown steering action type for tool call: {action}") + + async def _provide_model_steering_guidance(self, event: AfterModelCallEvent) -> None: + """Provide steering guidance for model response.""" + logger.debug("providing model steering guidance") + + # Only steer on successful model responses + if event.stop_response is None: + logger.debug("no stop response available | skipping model steering") + return + + try: + action = await self.steer_after_model( + agent=event.agent, message=event.stop_response.message, stop_reason=event.stop_response.stop_reason + ) + except Exception as e: + 
logger.debug("error=<%s> | model steering handler guidance failed", e) + return + + await self._handle_model_steering_action(action, event) + + async def _handle_model_steering_action(self, action: ModelSteeringAction, event: AfterModelCallEvent) -> None: + """Handle the steering action for model responses by modifying response handling flow. - @abstractmethod - async def steer(self, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> SteeringAction: - """Provide contextual guidance to help agent navigate complex workflows. + Proceed: Model response accepted without modification + Guide: Discard model response and retry with guidance message added to conversation + """ + if isinstance(action, Proceed): + logger.debug("model response proceeding") + elif isinstance(action, Guide): + logger.debug("model response guided (retrying): %s", action.reason) + # Set retry flag to discard current response + event.retry = True + # Add guidance message to agent's conversation so model sees it on retry + await event.agent._append_messages({"role": "user", "content": [{"text": action.reason}]}) + logger.debug("added guidance message to conversation for model retry") + else: + raise ValueError(f"Unknown steering action type for model response: {action}") + + async def steer_before_tool(self, *, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> ToolSteeringAction: + """Provide contextual guidance before tool execution. 
+ + This method is called before a tool is executed, allowing the handler to: + - Proceed: Allow tool execution to continue + - Guide: Cancel tool and provide feedback for alternative approaches + - Interrupt: Pause for human input before tool execution Args: agent: The agent instance @@ -126,9 +182,38 @@ async def steer(self, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> Steer **kwargs: Additional keyword arguments for guidance evaluation Returns: - SteeringAction indicating how to guide the agent's next action + ToolSteeringAction indicating how to guide the tool execution + + Note: + Access steering context via self.steering_context + Default implementation returns Proceed (allow tool execution) + Override this method to implement custom tool steering logic + """ + return Proceed(reason="Default implementation: allowing tool execution") + + async def steer_after_model( + self, *, agent: "Agent", message: Message, stop_reason: StopReason, **kwargs: Any + ) -> ModelSteeringAction: + """Provide contextual guidance after model response. + + This method is called after the model generates a response, allowing the handler to: + - Proceed: Accept the model response without modification + - Guide: Discard the response and retry (message is dropped, model is called again) + + Note: Interrupt is not supported for model steering as the model has already responded. + + Args: + agent: The agent instance + message: The model's generated message + stop_reason: The reason the model stopped generating + **kwargs: Additional keyword arguments for guidance evaluation + + Returns: + ModelSteeringAction indicating how to handle the model response Note: Access steering context via self.steering_context + Default implementation returns Proceed (accept response as-is) + Override this method to implement custom model steering logic """ - ... 
+ return Proceed(reason="Default implementation: accepting model response") diff --git a/src/strands/experimental/steering/handlers/llm/llm_handler.py b/src/strands/experimental/steering/handlers/llm/llm_handler.py index 4d90f46c9..379dc684a 100644 --- a/src/strands/experimental/steering/handlers/llm/llm_handler.py +++ b/src/strands/experimental/steering/handlers/llm/llm_handler.py @@ -10,7 +10,7 @@ from .....models import Model from .....types.tools import ToolUse from ...context_providers.ledger_provider import LedgerProvider -from ...core.action import Guide, Interrupt, Proceed, SteeringAction +from ...core.action import Guide, Interrupt, Proceed, ToolSteeringAction from ...core.context import SteeringContextProvider from ...core.handler import SteeringHandler from .mappers import DefaultPromptMapper, LLMPromptMapper @@ -58,7 +58,7 @@ def __init__( self.prompt_mapper = prompt_mapper or DefaultPromptMapper() self.model = model - async def steer(self, agent: Agent, tool_use: ToolUse, **kwargs: Any) -> SteeringAction: + async def steer_before_tool(self, *, agent: Agent, tool_use: ToolUse, **kwargs: Any) -> ToolSteeringAction: """Provide contextual guidance for tool usage. 
Args: @@ -67,7 +67,7 @@ async def steer(self, agent: Agent, tool_use: ToolUse, **kwargs: Any) -> Steerin **kwargs: Additional keyword arguments for steering evaluation Returns: - SteeringAction indicating how to guide the agent's next action + SteeringAction indicating how to guide the tool execution """ # Generate steering prompt prompt = self.prompt_mapper.create_steering_prompt(self.steering_context, tool_use=tool_use) @@ -91,5 +91,5 @@ async def steer(self, agent: Agent, tool_use: ToolUse, **kwargs: Any) -> Steerin case "interrupt": return Interrupt(reason=llm_result.reason) case _: - logger.warning("decision=<%s> | uŹknown llm decision, defaulting to proceed", llm_result.decision) # type: ignore[unreachable] + logger.warning("decision=<%s> | unknown llm decision, defaulting to proceed", llm_result.decision) # type: ignore[unreachable] return Proceed(reason="Unknown LLM decision, defaulting to proceed") diff --git a/tests/strands/experimental/steering/core/test_handler.py b/tests/strands/experimental/steering/core/test_handler.py index 8d5ef6884..a16208e5b 100644 --- a/tests/strands/experimental/steering/core/test_handler.py +++ b/tests/strands/experimental/steering/core/test_handler.py @@ -1,20 +1,20 @@ """Unit tests for steering handler base class.""" -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import pytest from strands.experimental.steering.core.action import Guide, Interrupt, Proceed from strands.experimental.steering.core.context import SteeringContext, SteeringContextCallback, SteeringContextProvider from strands.experimental.steering.core.handler import SteeringHandler -from strands.hooks.events import BeforeToolCallEvent +from strands.hooks.events import AfterModelCallEvent, BeforeToolCallEvent from strands.hooks.registry import HookRegistry class TestSteeringHandler(SteeringHandler): """Test implementation of SteeringHandler.""" - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, 
agent, tool_use, **kwargs): return Proceed(reason="Test proceed") @@ -31,9 +31,9 @@ def test_register_hooks(): handler.register_hooks(registry) - # Verify hooks were registered - assert registry.add_callback.call_count >= 1 - registry.add_callback.assert_any_call(BeforeToolCallEvent, handler._provide_steering_guidance) + # Verify hooks were registered (tool and model steering hooks) + assert registry.add_callback.call_count >= 2 + registry.add_callback.assert_any_call(BeforeToolCallEvent, handler._provide_tool_steering_guidance) def test_steering_context_initialization(): @@ -65,7 +65,7 @@ async def test_proceed_action_flow(): """Test complete flow with Proceed action.""" class ProceedHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Proceed(reason="Test proceed") handler = ProceedHandler() @@ -73,7 +73,7 @@ async def steer(self, agent, tool_use, **kwargs): tool_use = {"name": "test_tool"} event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) - await handler._provide_steering_guidance(event) + await handler._provide_tool_steering_guidance(event) # Should not modify event for Proceed assert not event.cancel_tool @@ -84,7 +84,7 @@ async def test_guide_action_flow(): """Test complete flow with Guide action.""" class GuideHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Guide(reason="Test guidance") handler = GuideHandler() @@ -92,7 +92,7 @@ async def steer(self, agent, tool_use, **kwargs): tool_use = {"name": "test_tool"} event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) - await handler._provide_steering_guidance(event) + await handler._provide_tool_steering_guidance(event) # Should set cancel_tool with guidance message expected_message = "Tool call cancelled given new 
guidance. Test guidance. Consider this approach and continue" @@ -104,7 +104,7 @@ async def test_interrupt_action_approved_flow(): """Test complete flow with Interrupt action when approved.""" class InterruptHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Interrupt(reason="Need approval") handler = InterruptHandler() @@ -113,7 +113,7 @@ async def steer(self, agent, tool_use, **kwargs): event.tool_use = tool_use event.interrupt = Mock(return_value=True) # Approved - await handler._provide_steering_guidance(event) + await handler._provide_tool_steering_guidance(event) event.interrupt.assert_called_once() @@ -123,7 +123,7 @@ async def test_interrupt_action_denied_flow(): """Test complete flow with Interrupt action when denied.""" class InterruptHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Interrupt(reason="Need approval") handler = InterruptHandler() @@ -132,7 +132,7 @@ async def steer(self, agent, tool_use, **kwargs): event.tool_use = tool_use event.interrupt = Mock(return_value=False) # Denied - await handler._provide_steering_guidance(event) + await handler._provide_tool_steering_guidance(event) event.interrupt.assert_called_once() assert event.cancel_tool.startswith("Manual approval denied:") @@ -143,7 +143,7 @@ async def test_unknown_action_flow(): """Test complete flow with unknown action type raises error.""" class UnknownActionHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Mock() # Not a valid SteeringAction handler = UnknownActionHandler() @@ -152,14 +152,14 @@ async def steer(self, agent, tool_use, **kwargs): event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) with pytest.raises(ValueError, 
match="Unknown steering action type"): - await handler._provide_steering_guidance(event) + await handler._provide_tool_steering_guidance(event) def test_register_steering_hooks_override(): """Test that _register_steering_hooks can be overridden.""" class CustomHandler(SteeringHandler): - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Proceed(reason="Custom") def register_hooks(self, registry, **kwargs): @@ -200,7 +200,7 @@ def __init__(self, context_callbacks=None): providers = [MockContextProvider(context_callbacks)] if context_callbacks else None super().__init__(context_providers=providers) - async def steer(self, agent, tool_use, **kwargs): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Proceed(reason="Test proceed") @@ -260,8 +260,8 @@ def test_multiple_context_callbacks_registered(): handler.register_hooks(registry) - # Should register one callback for each context provider plus steering guidance - expected_calls = 2 + 1 # 2 callbacks + 1 for steering guidance + # Should register one callback for each context provider plus tool and model steering guidance + expected_calls = 2 + 2 # 2 callbacks + 2 for steering guidance (tool and model) assert registry.add_callback.call_count >= expected_calls @@ -276,3 +276,208 @@ def test_handler_initialization_with_callbacks(): assert len(handler._context_callbacks) == 2 assert callback1 in handler._context_callbacks assert callback2 in handler._context_callbacks + + +# Model steering tests +@pytest.mark.asyncio +async def test_model_steering_proceed_action_flow(): + """Test model steering with Proceed action.""" + + class ModelProceedHandler(SteeringHandler): + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + return Proceed(reason="Model response accepted") + + handler = ModelProceedHandler() + agent = Mock() + stop_response = Mock() + stop_response.message = {"role": "assistant", "content": 
[{"text": "Hello"}]} + stop_response.stop_reason = "end_turn" + event = Mock(spec=AfterModelCallEvent) + event.agent = agent + event.stop_response = stop_response + event.retry = False + + await handler._provide_model_steering_guidance(event) + + # Should not set retry for Proceed + assert event.retry is False + + +@pytest.mark.asyncio +async def test_model_steering_guide_action_flow(): + """Test model steering with Guide action sets retry and adds message.""" + + class ModelGuideHandler(SteeringHandler): + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + return Guide(reason="Please improve your response") + + handler = ModelGuideHandler() + agent = AsyncMock() + stop_response = Mock() + stop_response.message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_response.stop_reason = "end_turn" + event = Mock(spec=AfterModelCallEvent) + event.agent = agent + event.stop_response = stop_response + event.retry = False + + await handler._provide_model_steering_guidance(event) + + # Should set retry flag + assert event.retry is True + # Should add guidance message to conversation + agent._append_messages.assert_called_once() + call_args = agent._append_messages.call_args[0][0] + assert call_args["role"] == "user" + assert "Please improve your response" in call_args["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_model_steering_skips_when_no_stop_response(): + """Test model steering skips when stop_response is None.""" + + class ModelProceedHandler(SteeringHandler): + def __init__(self): + super().__init__() + self.steer_called = False + + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + self.steer_called = True + return Proceed(reason="Should not be called") + + handler = ModelProceedHandler() + event = Mock(spec=AfterModelCallEvent) + event.stop_response = None + + await handler._provide_model_steering_guidance(event) + + # steer_after_model should not have been called + assert 
handler.steer_called is False + + +@pytest.mark.asyncio +async def test_model_steering_unknown_action_raises_error(): + """Test model steering with unknown action type raises error.""" + + class UnknownModelActionHandler(SteeringHandler): + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + return Mock() # Not a valid ModelSteeringAction + + handler = UnknownModelActionHandler() + agent = Mock() + stop_response = Mock() + stop_response.message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_response.stop_reason = "end_turn" + event = Mock(spec=AfterModelCallEvent) + event.agent = agent + event.stop_response = stop_response + + with pytest.raises(ValueError, match="Unknown steering action type for model response"): + await handler._provide_model_steering_guidance(event) + + +@pytest.mark.asyncio +async def test_model_steering_interrupt_raises_error(): + """Test model steering with Interrupt action raises error (not supported for model steering).""" + + class InterruptModelHandler(SteeringHandler): + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + return Interrupt(reason="Should not be allowed") + + handler = InterruptModelHandler() + agent = Mock() + stop_response = Mock() + stop_response.message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_response.stop_reason = "end_turn" + event = Mock(spec=AfterModelCallEvent) + event.agent = agent + event.stop_response = stop_response + + with pytest.raises(ValueError, match="Unknown steering action type for model response"): + await handler._provide_model_steering_guidance(event) + + +@pytest.mark.asyncio +async def test_model_steering_exception_handling(): + """Test model steering handles exceptions gracefully.""" + + class ExceptionModelHandler(SteeringHandler): + async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): + raise RuntimeError("Test exception") + + handler = ExceptionModelHandler() + agent = Mock() + 
stop_response = Mock() + stop_response.message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_response.stop_reason = "end_turn" + event = Mock(spec=AfterModelCallEvent) + event.agent = agent + event.stop_response = stop_response + event.retry = False + + # Should not raise, just return early + await handler._provide_model_steering_guidance(event) + + # retry should not be set since exception occurred + assert event.retry is False + + +@pytest.mark.asyncio +async def test_tool_steering_exception_handling(): + """Test tool steering handles exceptions gracefully.""" + + class ExceptionToolHandler(SteeringHandler): + async def steer_before_tool(self, *, agent, tool_use, **kwargs): + raise RuntimeError("Test exception") + + handler = ExceptionToolHandler() + agent = Mock() + tool_use = {"name": "test_tool"} + event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) + + # Should not raise, just return early + await handler._provide_tool_steering_guidance(event) + + # cancel_tool should not be set since exception occurred + assert not event.cancel_tool + + +# Default implementation tests +@pytest.mark.asyncio +async def test_default_steer_before_tool_returns_proceed(): + """Test default steer_before_tool returns Proceed.""" + handler = TestSteeringHandler() + agent = Mock() + tool_use = {"name": "test_tool"} + + # Call the parent's default implementation + result = await SteeringHandler.steer_before_tool(handler, agent=agent, tool_use=tool_use) + + assert isinstance(result, Proceed) + assert "Default implementation" in result.reason + + +@pytest.mark.asyncio +async def test_default_steer_after_model_returns_proceed(): + """Test default steer_after_model returns Proceed.""" + handler = TestSteeringHandler() + agent = Mock() + message = {"role": "assistant", "content": [{"text": "Hello"}]} + stop_reason = "end_turn" + + # Call the parent's default implementation + result = await 
SteeringHandler.steer_after_model(handler, agent=agent, message=message, stop_reason=stop_reason) + + assert isinstance(result, Proceed) + assert "Default implementation" in result.reason + + +def test_register_hooks_registers_model_steering(): + """Test that register_hooks registers model steering callback.""" + handler = TestSteeringHandler() + registry = Mock(spec=HookRegistry) + + handler.register_hooks(registry) + + # Verify model steering hook was registered + registry.add_callback.assert_any_call(AfterModelCallEvent, handler._provide_model_steering_guidance) diff --git a/tests/strands/experimental/steering/handlers/llm/test_llm_handler.py b/tests/strands/experimental/steering/handlers/llm/test_llm_handler.py index f780088b5..f10254e50 100644 --- a/tests/strands/experimental/steering/handlers/llm/test_llm_handler.py +++ b/tests/strands/experimental/steering/handlers/llm/test_llm_handler.py @@ -59,7 +59,7 @@ async def test_steer_proceed_decision(mock_agent_class): agent = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - result = await handler.steer(agent, tool_use) + result = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(result, Proceed) assert result.reason == "Tool call is safe" @@ -82,7 +82,7 @@ async def test_steer_guide_decision(mock_agent_class): agent = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - result = await handler.steer(agent, tool_use) + result = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(result, Guide) assert result.reason == "Consider security implications" @@ -105,7 +105,7 @@ async def test_steer_interrupt_decision(mock_agent_class): agent = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - result = await handler.steer(agent, tool_use) + result = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(result, Interrupt) assert result.reason == "Human approval required" @@ 
-133,7 +133,7 @@ async def test_steer_unknown_decision(mock_agent_class): agent = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - result = await handler.steer(agent, tool_use) + result = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(result, Proceed) assert "Unknown LLM decision, defaulting to proceed" in result.reason @@ -158,7 +158,7 @@ async def test_steer_uses_custom_model(mock_agent_class): agent.model = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - await handler.steer(agent, tool_use) + await handler.steer_before_tool(agent=agent, tool_use=tool_use) mock_agent_class.assert_called_once_with(system_prompt=system_prompt, model=custom_model, callback_handler=None) @@ -181,7 +181,7 @@ async def test_steer_uses_agent_model_when_no_custom_model(mock_agent_class): agent.model = Mock() tool_use = {"name": "test_tool", "input": {"param": "value"}} - await handler.steer(agent, tool_use) + await handler.steer_before_tool(agent=agent, tool_use=tool_use) mock_agent_class.assert_called_once_with(system_prompt=system_prompt, model=agent.model, callback_handler=None) diff --git a/tests/strands/tools/test_decorator_pep563.py b/tests/strands/tools/test_decorator_pep563.py index 07ec8f2ba..44d9a626a 100644 --- a/tests/strands/tools/test_decorator_pep563.py +++ b/tests/strands/tools/test_decorator_pep563.py @@ -10,10 +10,10 @@ from __future__ import annotations -from typing import Any +from typing import Any, Literal import pytest -from typing_extensions import Literal, TypedDict +from typing_extensions import TypedDict from strands import tool diff --git a/tests_integ/steering/test_model_steering.py b/tests_integ/steering/test_model_steering.py new file mode 100644 index 000000000..e867ea033 --- /dev/null +++ b/tests_integ/steering/test_model_steering.py @@ -0,0 +1,204 @@ +"""Integration tests for model steering (steer_after_model).""" + +from strands import Agent, tool +from 
strands.experimental.steering.core.action import Guide, ModelSteeringAction, Proceed +from strands.experimental.steering.core.handler import SteeringHandler +from strands.types.content import Message +from strands.types.streaming import StopReason + + +class SimpleModelSteeringHandler(SteeringHandler): + """Simple handler that steers only on model responses.""" + + def __init__(self, should_guide: bool = False, guidance_message: str = ""): + """Initialize handler. + + Args: + should_guide: If True, guide (retry) on first model response + guidance_message: The guidance message to provide on retry + """ + super().__init__() + self.should_guide = should_guide + self.guidance_message = guidance_message + self.call_count = 0 + + async def steer_after_model( + self, *, agent: Agent, message: Message, stop_reason: StopReason, **kwargs + ) -> ModelSteeringAction: + """Steer after model response.""" + self.call_count += 1 + + # On first call, guide to retry if configured + if self.should_guide and self.call_count == 1: + return Guide(reason=self.guidance_message) + + return Proceed(reason="Model response accepted") + + +def test_model_steering_proceeds_without_intervention(): + """Test that model steering can accept responses without modification.""" + handler = SimpleModelSteeringHandler(should_guide=False) + agent = Agent(hooks=[handler]) + + response = agent("What is 2+2?") + + # Handler should have been called once + assert handler.call_count >= 1 + # Response should be generated successfully + response_text = str(response) + assert response_text is not None + assert len(response_text) > 0 + + +def test_model_steering_guide_triggers_retry(): + """Test that Guide action triggers model retry.""" + handler = SimpleModelSteeringHandler(should_guide=True, guidance_message="Please provide a more detailed response.") + agent = Agent(hooks=[handler]) + + response = agent("What is the capital of France?") + + # Handler should have been called at least twice (first response + 
retry) + assert handler.call_count >= 2, "Handler should be called on initial response and retry" + + # Response should be generated successfully after retry + response_text = str(response) + assert response_text is not None + assert len(response_text) > 0 + + +def test_model_steering_guide_influences_retry_response(): + """Test that guidance message influences the retry response.""" + + class SpecificGuidanceHandler(SteeringHandler): + def __init__(self): + super().__init__() + self.retry_done = False + + async def steer_after_model( + self, *, agent: Agent, message: Message, stop_reason: StopReason, **kwargs + ) -> ModelSteeringAction: + if not self.retry_done: + self.retry_done = True + # Provide very specific guidance that should appear in retry + return Guide(reason="Please mention that Paris is also known as the 'City of Light'.") + return Proceed(reason="Response is good now") + + handler = SpecificGuidanceHandler() + agent = Agent(hooks=[handler]) + + response = agent("What is the capital of France?") + + # Verify retry happened + assert handler.retry_done, "Retry should have occurred" + + # Check that the response likely incorporated the guidance + output = str(response).lower() + assert "paris" in output, "Response should mention Paris" + + # The guidance should have influenced the retry (check for "light" or that retry happened) + # We can't guarantee the model will include it, but we verify the mechanism worked + assert handler.retry_done, "Guidance mechanism should have executed" + + +def test_model_steering_multiple_retries(): + """Test that model steering can guide multiple times before proceeding.""" + + class MultiRetryHandler(SteeringHandler): + def __init__(self): + super().__init__() + self.call_count = 0 + + async def steer_after_model( + self, *, agent: Agent, message: Message, stop_reason: StopReason, **kwargs + ) -> ModelSteeringAction: + self.call_count += 1 + + # Retry twice + if self.call_count == 1: + return Guide(reason="Please provide 
more context.") + if self.call_count == 2: + return Guide(reason="Please add specific examples.") + return Proceed(reason="Response is good now") + + handler = MultiRetryHandler() + agent = Agent(hooks=[handler]) + + response = agent("Explain machine learning.") + + # Should have been called 3 times (2 guides + 1 proceed) + assert handler.call_count >= 3, "Handler should be called multiple times for multiple retries" + + # Response should still complete successfully + assert str(response) is not None + assert len(str(response)) > 0 + + +@tool +def log_activity(activity: str) -> str: + """Log an activity for audit purposes.""" + return f"Activity logged: {activity}" + + +def test_model_steering_forces_tool_usage_on_unrelated_prompt(): + """Test that steering forces tool usage even when prompt doesn't need the tool. + + This test verifies the flow: + 1. Agent has a logging tool available + 2. User asks an unrelated question (math problem) + 3. Model tries to answer directly without using the tool + 4. Steering intercepts and forces tool usage before termination + 5. 
Model uses the tool and then completes + """ + + class ForceToolUsageHandler(SteeringHandler): + """Handler that forces a specific tool to be used before allowing termination.""" + + def __init__(self, required_tool: str): + super().__init__() + self.required_tool = required_tool + self.tool_was_used = False + self.guidance_given = False + + async def steer_after_model( + self, *, agent: Agent, message: Message, stop_reason: StopReason, **kwargs + ) -> ModelSteeringAction: + # Only check when model is trying to end the turn + if stop_reason != "end_turn": + return Proceed(reason="Model still processing") + + # Check if the required tool was used in this message + content_blocks = message.get("content", []) + for block in content_blocks: + if "toolUse" in block and block["toolUse"].get("name") == self.required_tool: + self.tool_was_used = True + return Proceed(reason="Required tool was used") + + # If tool wasn't used and we haven't guided yet, force its usage + if not self.tool_was_used and not self.guidance_given: + self.guidance_given = True + return Guide( + reason=f"Before completing your response, you MUST use the {self.required_tool} tool " + "to log this interaction. Call the tool with a brief description of what you did." 
+ ) + + # Allow completion after guidance was given (model may have used tool in retry) + return Proceed(reason="Guidance was provided") + + handler = ForceToolUsageHandler(required_tool="log_activity") + agent = Agent(tools=[log_activity], hooks=[handler]) + + # Ask a question that clearly doesn't need the logging tool + response = agent("What is 2 + 2?") + + # Verify the steering mechanism worked + assert handler.guidance_given, "Handler should have provided guidance to use the tool" + + # Verify tool was actually called by checking metrics + tool_metrics = response.metrics.tool_metrics + assert "log_activity" in tool_metrics, "log_activity tool should have been called" + assert tool_metrics["log_activity"].call_count >= 1, "log_activity should have been called at least once" + assert tool_metrics["log_activity"].success_count >= 1, "log_activity should have succeeded" + + # Verify the response still answers the original question + output = str(response).lower() + assert "4" in output, "Response should contain the answer to 2+2" diff --git a/tests_integ/steering/test_llm_handler.py b/tests_integ/steering/test_tool_steering.py similarity index 91% rename from tests_integ/steering/test_llm_handler.py rename to tests_integ/steering/test_tool_steering.py index 8a8cebea2..eced94ba0 100644 --- a/tests_integ/steering/test_llm_handler.py +++ b/tests_integ/steering/test_tool_steering.py @@ -1,4 +1,4 @@ -"""Integration tests for LLM steering handler.""" +"""Integration tests for tool steering (steer_before_tool).""" import pytest @@ -30,7 +30,7 @@ async def test_llm_steering_handler_proceed(): agent = Agent(tools=[send_notification]) tool_use = {"name": "send_notification", "input": {"recipient": "user", "message": "hello"}} - effect = await handler.steer(agent, tool_use) + effect = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(effect, Proceed) @@ -48,7 +48,7 @@ async def test_llm_steering_handler_guide(): agent = 
Agent(tools=[send_email, send_notification]) tool_use = {"name": "send_email", "input": {"recipient": "user", "message": "hello"}} - effect = await handler.steer(agent, tool_use) + effect = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(effect, Guide) @@ -64,12 +64,12 @@ async def test_llm_steering_handler_interrupt(): agent = Agent(tools=[send_email]) tool_use = {"name": "send_email", "input": {"recipient": "user", "message": "hello"}} - effect = await handler.steer(agent, tool_use) + effect = await handler.steer_before_tool(agent=agent, tool_use=tool_use) assert isinstance(effect, Interrupt) -def test_agent_with_steering_e2e(): +def test_agent_with_tool_steering_e2e(): """End-to-end test of agent with steering handler guiding tool choice.""" handler = LLMSteeringHandler( system_prompt=( From 63e58aa83dbb63fab06c405cbeb2acaa500c9e32 Mon Sep 17 00:00:00 2001 From: Qian Zhang Date: Tue, 20 Jan 2026 16:18:52 +0100 Subject: [PATCH 073/279] fix: provide unique toolUseId for gemini models (#1201) Co-authored-by: spicadust Co-authored-by: Patrick Gray --- src/strands/models/gemini.py | 28 ++++++++---- tests/strands/models/test_gemini.py | 67 ++++++++++++++++++++++++++++- 2 files changed, 86 insertions(+), 9 deletions(-) diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 52d45b649..5417f20b3 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -6,6 +6,7 @@ import json import logging import mimetypes +import secrets from collections.abc import AsyncGenerator from typing import Any, TypedDict, TypeVar, cast @@ -86,6 +87,7 @@ def __init__( self._custom_client = client self.client_args = client_args or {} + self._tool_use_id_to_name: dict[str, str] = {} # Validate gemini_tools if provided if "gemini_tools" in self.config: @@ -173,10 +175,13 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par return genai.types.Part(text=content["text"]) if "toolResult" in 
content: + tool_use_id = content["toolResult"]["toolUseId"] + function_name = self._tool_use_id_to_name.get(tool_use_id, tool_use_id) + return genai.types.Part( function_response=genai.types.FunctionResponse( - id=content["toolResult"]["toolUseId"], - name=content["toolResult"]["toolUseId"], + id=tool_use_id, + name=function_name, response={ "output": [ tool_result_content @@ -191,6 +196,12 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par ) if "toolUse" in content: + # Store the mapping from toolUseId to name for later use in toolResult formatting. + # This mapping is built as we format the request, ensuring that when we encounter + # toolResult blocks (which come after toolUse blocks in the message history), + # we can look up the function name. + self._tool_use_id_to_name[content["toolUse"]["toolUseId"]] = content["toolUse"]["name"] + return genai.types.Part( function_call=genai.types.FunctionCall( args=content["toolUse"]["input"], @@ -317,16 +328,16 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: case "content_start": match event["data_type"]: case "tool": - # Note: toolUseId is the only identifier available in a tool result. However, Gemini requires - # that name be set in the equivalent FunctionResponse type. Consequently, we assign - # function name to toolUseId in our tool use block. And another reason, function_call is - # not guaranteed to have id populated. + function_call = event["data"].function_call + # Use Gemini's provided ID or generate one if missing + tool_use_id = function_call.id or f"tooluse_{secrets.token_urlsafe(16)}" + return { "contentBlockStart": { "start": { "toolUse": { - "name": event["data"].function_call.name, - "toolUseId": event["data"].function_call.name, + "name": function_call.name, + "toolUseId": tool_use_id, }, }, }, @@ -417,6 +428,7 @@ async def stream( ModelThrottledException: If the request is throttled by Gemini. 
""" request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params")) + self._tool_use_id_to_name.clear() client = self._get_client().aio diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index 08be9188d..70f5032d8 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -360,6 +360,71 @@ async def test_stream_request_with_tool_results(gemini_client, model, model_id): gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request) +@pytest.mark.asyncio +async def test_stream_request_with_tool_results_preserving_name(gemini_client, model, model_id): + messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "t1", + "name": "tool_1", + "input": {}, + }, + }, + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "t1", + "status": "success", + "content": [{"text": "done"}], + }, + }, + ], + }, + ] + await anext(model.stream(messages)) + + exp_request = { + "config": { + "tools": [{"function_declarations": []}], + }, + "contents": [ + { + "parts": [ + { + "function_call": { + "args": {}, + "id": "t1", + "name": "tool_1", + }, + }, + ], + "role": "model", + }, + { + "parts": [ + { + "function_response": { + "id": "t1", + "name": "tool_1", + "response": {"output": [{"text": "done"}]}, + }, + }, + ], + "role": "user", + }, + ], + "model": model_id, + } + gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request) + + @pytest.mark.asyncio async def test_stream_request_with_empty_content(gemini_client, model, model_id): messages = [ @@ -459,7 +524,7 @@ async def test_stream_response_tool_use(gemini_client, model, messages, agenerat exp_chunks = [ {"messageStart": {"role": "assistant"}}, {"contentBlockStart": {"start": {}}}, - {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "calculator"}}}}, + {"contentBlockStart": {"start": {"toolUse": 
{"name": "calculator", "toolUseId": "c1"}}}}, {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, {"contentBlockStop": {}}, {"contentBlockStop": {}}, From 456b70a0c14b255eafb49442e9663905b8ba5eba Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 20 Jan 2026 13:02:31 -0500 Subject: [PATCH 074/279] gemini - tool_use_id_to_name - local (#1521) --- src/strands/models/gemini.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 5417f20b3..855e1ef5c 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -87,7 +87,6 @@ def __init__( self._custom_client = client self.client_args = client_args or {} - self._tool_use_id_to_name: dict[str, str] = {} # Validate gemini_tools if provided if "gemini_tools" in self.config: @@ -135,13 +134,19 @@ def _get_client(self) -> genai.Client: # Create a new client from client_args return genai.Client(**self.client_args) - def _format_request_content_part(self, content: ContentBlock) -> genai.types.Part: + def _format_request_content_part( + self, content: ContentBlock, tool_use_id_to_name: dict[str, str] + ) -> genai.types.Part: """Format content block into a Gemini part instance. - Docs: https://googleapis.github.io/python-genai/genai.html#genai.types.Part Args: content: Message content to format. + tool_use_id_to_name: Mapping of tool use id to tool name. + Store the mapping from toolUseId to name for later use in toolResult formatting. This mapping is built + as we format the request, ensuring that when we encounter toolResult blocks (which come after toolUse + blocks in the message history), we can look up the function name. Returns: Gemini part. 
@@ -176,7 +181,7 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par if "toolResult" in content: tool_use_id = content["toolResult"]["toolUseId"] - function_name = self._tool_use_id_to_name.get(tool_use_id, tool_use_id) + function_name = tool_use_id_to_name.get(tool_use_id, tool_use_id) return genai.types.Part( function_response=genai.types.FunctionResponse( @@ -187,7 +192,8 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par tool_result_content if "json" in tool_result_content else self._format_request_content_part( - cast(ContentBlock, tool_result_content) + cast(ContentBlock, tool_result_content), + tool_use_id_to_name, ).to_json_dict() for tool_result_content in content["toolResult"]["content"] ], @@ -196,11 +202,7 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par ) if "toolUse" in content: - # Store the mapping from toolUseId to name for later use in toolResult formatting. - # This mapping is built as we format the request, ensuring that when we encounter - # toolResult blocks (which come after toolUse blocks in the message history), - # we can look up the function name. - self._tool_use_id_to_name[content["toolUse"]["toolUseId"]] = content["toolUse"]["name"] + tool_use_id_to_name[content["toolUse"]["toolUseId"]] = content["toolUse"]["name"] return genai.types.Part( function_call=genai.types.FunctionCall( @@ -223,9 +225,15 @@ def _format_request_content(self, messages: Messages) -> list[genai.types.Conten Returns: Gemini content list. """ + # Gemini FunctionResponses are constructed from tool result blocks. Function name is required but is not + # available in tool result blocks, hence the mapping. 
+ tool_use_id_to_name: dict[str, str] = {} + return [ genai.types.Content( - parts=[self._format_request_content_part(content) for content in message["content"]], + parts=[ + self._format_request_content_part(content, tool_use_id_to_name) for content in message["content"] + ], role="user" if message["role"] == "user" else "model", ) for message in messages @@ -428,7 +436,6 @@ async def stream( ModelThrottledException: If the request is throttled by Gemini. """ request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params")) - self._tool_use_id_to_name.clear() client = self._get_client().aio From 6dcd24739d7a153eed8eb778d795bb9df6cd3fc3 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 20 Jan 2026 21:12:36 +0200 Subject: [PATCH 075/279] fix(litellm): handle missing usage attribute on ModelResponseStream (#1520) --- src/strands/models/litellm.py | 4 +- tests/strands/models/test_litellm.py | 101 +++++++++++++++++++++++ tests_integ/models/test_model_litellm.py | 21 +++++ 3 files changed, 124 insertions(+), 2 deletions(-) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index ae71cc668..ec6579c58 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -547,8 +547,8 @@ async def _handle_streaming_response(self, litellm_request: dict[str, Any]) -> A # Skip remaining events as we don't have use for anything except the final usage payload async for event in response: _ = event - if event.usage: - yield self.format_chunk({"chunk_type": "metadata", "data": event.usage}) + if usage := getattr(event, "usage", None): + yield self.format_chunk({"chunk_type": "metadata", "data": usage}) logger.debug("finished streaming response from model") diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 99df22a3f..f5e1837bf 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -711,3 +711,104 @@ def 
test_stream_switch_content_different_type_no_prev(): assert len(chunks) == 1 assert chunks[0]["contentBlockStart"] == {"start": {}} assert data_type == "text" + + +@pytest.mark.asyncio +async def test_stream_with_events_missing_usage_attribute( + litellm_acompletion, api_key, model_id, model, agenerator, alist +): + """Test streaming handles events that don't have a usage attribute. + + This test verifies the fix for a bug where ModelResponseStream objects + (which don't have a 'usage' attribute) would cause an AttributeError + when the code tried to access event.usage directly instead of using getattr. + + The bug occurred because: + 1. ModelResponse (non-streaming) has a 'usage' attribute + 2. ModelResponseStream (streaming chunks) does NOT have a 'usage' attribute + 3. The code assumed all events would have the 'usage' attribute + + Regression test for: 'ModelResponseStream' object has no attribute 'usage' + """ + + # Use spec to ensure mock objects only have specified attributes + # This mimics the real ModelResponseStream which doesn't have 'usage' + class MockStreamChunk: + """Mock that mimics ModelResponseStream - no usage attribute.""" + + def __init__(self, choices=None): + self.choices = choices or [] + + mock_delta = unittest.mock.Mock(content="Hello", tool_calls=None, reasoning_content=None) + mock_event_1 = MockStreamChunk(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + mock_event_2 = MockStreamChunk(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + # After finish_reason is received, remaining events in the stream also don't have 'usage' + mock_event_3 = MockStreamChunk(choices=[]) + mock_event_4 = MockStreamChunk(choices=[]) + + litellm_acompletion.side_effect = unittest.mock.AsyncMock( + return_value=agenerator([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) + ) + + messages = [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}] + response = model.stream(messages) + + # This should 
NOT raise AttributeError: 'MockStreamChunk' object has no attribute 'usage' + tru_events = await alist(response) + + # Verify we got the expected events (no metadata since no usage was available) + assert tru_events[0] == {"messageStart": {"role": "assistant"}} + assert {"messageStop": {"stopReason": "end_turn"}} in tru_events + # No metadata event since mock events don't have usage + assert not any("metadata" in event for event in tru_events) + + +@pytest.mark.asyncio +async def test_stream_with_usage_in_final_event(litellm_acompletion, api_key, model_id, model, agenerator, alist): + """Test streaming correctly extracts usage when it IS present in final events. + + This test ensures that when usage data IS available (e.g., with stream_options.include_usage=True), + it is correctly extracted and included in the metadata event. + """ + + class MockStreamChunkWithoutUsage: + """Mock streaming chunk without usage.""" + + def __init__(self, choices=None): + self.choices = choices or [] + + class MockStreamChunkWithUsage: + """Mock streaming chunk with usage (final event).""" + + def __init__(self, usage): + self.choices = [] + self.usage = usage + + mock_delta = unittest.mock.Mock(content="Hi", tool_calls=None, reasoning_content=None) + mock_event_1 = MockStreamChunkWithoutUsage(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + mock_event_2 = MockStreamChunkWithoutUsage(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + + # Final event with usage data + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 10 + mock_usage.completion_tokens = 5 + mock_usage.total_tokens = 15 + mock_usage.prompt_tokens_details = None + mock_usage.cache_creation_input_tokens = None + mock_event_3 = MockStreamChunkWithUsage(usage=mock_usage) + + litellm_acompletion.side_effect = unittest.mock.AsyncMock( + return_value=agenerator([mock_event_1, mock_event_2, mock_event_3]) + ) + + messages = [{"role": "user", "content": [{"type": "text", 
"text": "Hi"}]}] + response = model.stream(messages) + + tru_events = await alist(response) + + # Verify metadata event is present with correct usage + metadata_events = [e for e in tru_events if "metadata" in e] + assert len(metadata_events) == 1 + assert metadata_events[0]["metadata"]["usage"]["inputTokens"] == 10 + assert metadata_events[0]["metadata"]["usage"]["outputTokens"] == 5 + assert metadata_events[0]["metadata"]["usage"]["totalTokens"] == 15 diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py index 80e21bdfd..eb0737e0f 100644 --- a/tests_integ/models/test_model_litellm.py +++ b/tests_integ/models/test_model_litellm.py @@ -236,6 +236,27 @@ def test_structured_output_unsupported_model(model, nested_weather): mock_schema.assert_not_called() +@pytest.mark.parametrize("model_fixture", ["streaming_model", "non_streaming_model"]) +def test_streaming_returns_usage_metrics(model_fixture, request): + """Test that streaming returns usage metrics. + + This test verifies that the streaming flow correctly extracts and returns + usage data from the model response. This is a regression test for the bug + where accessing 'usage' attribute on ModelResponseStream raised AttributeError. 
+ + Regression test for: 'ModelResponseStream' object has no attribute 'usage' + """ + model = request.getfixturevalue(model_fixture) + agent = Agent(model=model) + result = agent("Say hello") + + # Verify usage metrics are returned - this would fail if streaming breaks + assert result.metrics.accumulated_usage is not None + assert result.metrics.accumulated_usage["inputTokens"] > 0 + assert result.metrics.accumulated_usage["outputTokens"] > 0 + assert result.metrics.accumulated_usage["totalTokens"] > 0 + + @pytest.mark.asyncio async def test_cache_read_tokens_multi_turn(model): """Integration test for cache read tokens in multi-turn conversation.""" From 64e1bb25e2e462d2bd42b66d048a8782d674223a Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Wed, 21 Jan 2026 09:41:08 -0500 Subject: [PATCH 076/279] feat(agent): add configurable retry_strategy for model calls (#1424) The current retry logic for handling ModelThrottledException is hardcoded in event_loop.py with fixed values (6 attempts, exponential backoff starting at 4s). This makes it impossible for users to customize retry behavior for their specific use cases, such as: This refactors the hardcoded retry logic into a `ModelRetryStrategy` class so that folks can customize the parameters. **Not Included**: This does not introduce a `RetryStrategy` base class. I started to do so, but am deferring it because: 1. It requires some additional design work to accommodate the tool-retries, which I anticipate should be accounted for in the design 2. It simplifies this review which refactors how the default retries work internally 3. 
`ModelRetryStrategy` provides enough benefit to allow folks to customize the agent loop without blocking on a more extensible design ---- Co-authored-by: Strands Agent Co-authored-by: Mackenzie Zastrow --- src/strands/__init__.py | 2 + src/strands/agent/__init__.py | 3 + src/strands/agent/agent.py | 21 +- src/strands/event_loop/_retry.py | 157 +++++++++ src/strands/event_loop/event_loop.py | 48 +-- tests/strands/agent/conftest.py | 22 ++ .../strands/agent/hooks/test_agent_events.py | 10 +- tests/strands/agent/test_agent_hooks.py | 2 +- tests/strands/agent/test_agent_retry.py | 161 +++++++++ tests/strands/agent/test_retry.py | 328 ++++++++++++++++++ tests/strands/event_loop/test_event_loop.py | 27 +- 11 files changed, 736 insertions(+), 45 deletions(-) create mode 100644 src/strands/event_loop/_retry.py create mode 100644 tests/strands/agent/conftest.py create mode 100644 tests/strands/agent/test_agent_retry.py create mode 100644 tests/strands/agent/test_retry.py diff --git a/src/strands/__init__.py b/src/strands/__init__.py index bc17497a0..6026d4240 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -3,6 +3,7 @@ from . 
import agent, models, telemetry, types from .agent.agent import Agent from .agent.base import AgentBase +from .event_loop._retry import ModelRetryStrategy from .tools.decorator import tool from .types.tools import ToolContext @@ -11,6 +12,7 @@ "AgentBase", "agent", "models", + "ModelRetryStrategy", "tool", "ToolContext", "types", diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index c00623dc2..2e40866a9 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -4,8 +4,10 @@ - Agent: The main interface for interacting with AI models and tools - ConversationManager: Classes for managing conversation history and context windows +- Retry Strategies: Configurable retry behavior for model calls """ +from ..event_loop._retry import ModelRetryStrategy from .agent import Agent from .agent_result import AgentResult from .base import AgentBase @@ -24,4 +26,5 @@ "NullConversationManager", "SlidingWindowConversationManager", "SummarizingConversationManager", + "ModelRetryStrategy", ] diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 7b9e9c914..cacc69ece 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -26,7 +26,8 @@ from .. import _identifier from .._async import run_async -from ..event_loop.event_loop import event_loop_cycle +from ..event_loop._retry import ModelRetryStrategy +from ..event_loop.event_loop import INITIAL_DELAY, MAX_ATTEMPTS, MAX_DELAY, event_loop_cycle from ..tools._tool_helpers import generate_missing_tool_result_content if TYPE_CHECKING: @@ -118,6 +119,7 @@ def __init__( hooks: list[HookProvider] | None = None, session_manager: SessionManager | None = None, tool_executor: ToolExecutor | None = None, + retry_strategy: ModelRetryStrategy | None = None, ): """Initialize the Agent with the specified configuration. @@ -167,6 +169,9 @@ def __init__( session_manager: Manager for handling agent sessions including conversation history and state. 
If provided, enables session-based persistence and state management. tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). + retry_strategy: Strategy for retrying model calls on throttling or other transient errors. + Defaults to ModelRetryStrategy with max_attempts=6, initial_delay=4s, max_delay=240s. + Implement a custom HookProvider for custom retry logic, or pass None to disable retries. Raises: ValueError: If agent id contains path separators. @@ -244,6 +249,17 @@ def __init__( # separate event loops in different threads, so asyncio.Lock wouldn't work self._invocation_lock = threading.Lock() + # In the future, we'll have a RetryStrategy base class but until + # that API is determined we only allow ModelRetryStrategy + if retry_strategy and type(retry_strategy) is not ModelRetryStrategy: + raise ValueError("retry_strategy must be an instance of ModelRetryStrategy") + + self._retry_strategy = ( + retry_strategy + if retry_strategy is not None + else ModelRetryStrategy(max_attempts=MAX_ATTEMPTS, max_delay=MAX_DELAY, initial_delay=INITIAL_DELAY) + ) + # Initialize session management functionality self._session_manager = session_manager if self._session_manager: @@ -252,6 +268,9 @@ def __init__( # Allow conversation_managers to subscribe to hooks self.hooks.add_hook(self.conversation_manager) + # Register retry strategy as a hook + self.hooks.add_hook(self._retry_strategy) + self.tool_executor = tool_executor or ConcurrentToolExecutor() if hooks: diff --git a/src/strands/event_loop/_retry.py b/src/strands/event_loop/_retry.py new file mode 100644 index 000000000..04a6101b8 --- /dev/null +++ b/src/strands/event_loop/_retry.py @@ -0,0 +1,157 @@ +"""Retry strategy implementations for handling model throttling and other retry scenarios. + +This module provides hook-based retry strategies that can be configured on the Agent +to control retry behavior for model invocations. 
Retry strategies implement the +HookProvider protocol and register callbacks for AfterModelCallEvent to determine +when and how to retry failed model calls. +""" + +import asyncio +import logging +from typing import Any + +from ..hooks.events import AfterInvocationEvent, AfterModelCallEvent +from ..hooks.registry import HookProvider, HookRegistry +from ..types._events import EventLoopThrottleEvent, TypedEvent +from ..types.exceptions import ModelThrottledException + +logger = logging.getLogger(__name__) + + +class ModelRetryStrategy(HookProvider): + """Default retry strategy for model throttling with exponential backoff. + + Retries model calls on ModelThrottledException using exponential backoff. + Delay doubles after each attempt: initial_delay, initial_delay*2, initial_delay*4, + etc., capped at max_delay. State resets after successful calls. + + With defaults (initial_delay=4, max_delay=240, max_attempts=6), delays are: + 4s → 8s → 16s → 32s → 64s (5 retries before giving up on the 6th attempt). + + Args: + max_attempts: Total model attempts before re-raising the exception. + initial_delay: Base delay in seconds; used for first two retries, then doubles. + max_delay: Upper bound in seconds for the exponential backoff. + """ + + def __init__( + self, + *, + max_attempts: int = 6, + initial_delay: int = 4, + max_delay: int = 240, + ): + """Initialize the retry strategy. + + Args: + max_attempts: Total model attempts before re-raising the exception. Defaults to 6. + initial_delay: Base delay in seconds; used for first two retries, then doubles. + Defaults to 4. + max_delay: Upper bound in seconds for the exponential backoff. Defaults to 240. 
+ """ + self._max_attempts = max_attempts + self._initial_delay = initial_delay + self._max_delay = max_delay + self._current_attempt = 0 + self._backwards_compatible_event_to_yield: TypedEvent | None = None + + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + """Register callbacks for AfterModelCallEvent and AfterInvocationEvent. + + Args: + registry: The hook registry to register callbacks with. + **kwargs: Additional keyword arguments for future extensibility. + """ + registry.add_callback(AfterModelCallEvent, self._handle_after_model_call) + registry.add_callback(AfterInvocationEvent, self._handle_after_invocation) + + def _calculate_delay(self, attempt: int) -> int: + """Calculate retry delay using exponential backoff. + + Args: + attempt: The attempt number (0-indexed) to calculate delay for. + + Returns: + Delay in seconds for the given attempt. + """ + delay: int = self._initial_delay * (2**attempt) + return min(delay, self._max_delay) + + def _reset_retry_state(self) -> None: + """Reset retry state to initial values.""" + self._current_attempt = 0 + + async def _handle_after_invocation(self, event: AfterInvocationEvent) -> None: + """Reset retry state after invocation completes. + + Args: + event: The AfterInvocationEvent signaling invocation completion. + """ + self._reset_retry_state() + + async def _handle_after_model_call(self, event: AfterModelCallEvent) -> None: + """Handle model call completion and determine if retry is needed. + + This callback is invoked after each model call. If the call failed with + a ModelThrottledException and we haven't exceeded max_attempts, it sets + event.retry to True and sleeps for the current delay before returning. + + On successful calls, it resets the retry state to prepare for future calls. + + Args: + event: The AfterModelCallEvent containing call results or exception. 
+ """ + delay = self._calculate_delay(self._current_attempt) + + self._backwards_compatible_event_to_yield = None + + # If already retrying, skip processing (another hook may have triggered retry) + if event.retry: + return + + # If model call succeeded, reset retry state + if event.stop_response is not None: + logger.debug( + "stop_reason=<%s> | model call succeeded, resetting retry state", + event.stop_response.stop_reason, + ) + self._reset_retry_state() + return + + # Check if we have an exception and reset state if no exception + if event.exception is None: + self._reset_retry_state() + return + + # Only retry on ModelThrottledException + if not isinstance(event.exception, ModelThrottledException): + return + + # Increment attempt counter first + self._current_attempt += 1 + + # Check if we've exceeded max attempts + if self._current_attempt >= self._max_attempts: + logger.debug( + "current_attempt=<%d>, max_attempts=<%d> | max retry attempts reached, not retrying", + self._current_attempt, + self._max_attempts, + ) + return + + self._backwards_compatible_event_to_yield = EventLoopThrottleEvent(delay=delay) + + # Retry the model call + logger.debug( + "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> " + "| throttling exception encountered | delaying before next retry", + delay, + self._max_attempts, + self._current_attempt, + ) + + # Sleep for current delay + await asyncio.sleep(delay) + + # Set retry flag and track that this strategy triggered it + event.retry = True diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 99c8f5179..f5d00a201 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -8,7 +8,6 @@ 4. 
Manage recursive execution cycles """ -import asyncio import logging import uuid from collections.abc import AsyncGenerator @@ -23,7 +22,6 @@ from ..tools.structured_output._structured_output_context import StructuredOutputContext from ..types._events import ( EventLoopStopEvent, - EventLoopThrottleEvent, ForceStopEvent, ModelMessageEvent, ModelStopReason, @@ -39,12 +37,12 @@ ContextWindowOverflowException, EventLoopException, MaxTokensReachedException, - ModelThrottledException, StructuredOutputException, ) from ..types.streaming import StopReason from ..types.tools import ToolResult, ToolUse from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached +from ._retry import ModelRetryStrategy from .streaming import stream_messages if TYPE_CHECKING: @@ -316,9 +314,9 @@ async def _handle_model_execution( stream_trace = Trace("stream_messages", parent_id=cycle_trace.id) cycle_trace.add_child(stream_trace) - # Retry loop for handling throttling exceptions - current_delay = INITIAL_DELAY - for attempt in range(MAX_ATTEMPTS): + # Retry loop - actual retry logic is handled by retry_strategy hook + # Hooks control when to stop retrying via the event.retry flag + while True: model_id = agent.model.config.get("model_id") if hasattr(agent.model, "config") else None model_invoke_span = tracer.start_model_invoke_span( messages=agent.messages, @@ -366,9 +364,8 @@ async def _handle_model_execution( # Check if hooks want to retry the model call if after_model_call_event.retry: logger.debug( - "stop_reason=<%s>, retry_requested=, attempt=<%d> | hook requested model retry", + "stop_reason=<%s>, retry_requested= | hook requested model retry", stop_reason, - attempt + 1, ) continue # Retry the model call @@ -389,34 +386,27 @@ async def _handle_model_execution( ) await agent.hooks.invoke_callbacks_async(after_model_call_event) + # Emit backwards-compatible events if retry strategy supports it + # (prior to making the retry strategy configurable, this is 
what we emitted) + + if ( + isinstance(agent._retry_strategy, ModelRetryStrategy) + and agent._retry_strategy._backwards_compatible_event_to_yield + ): + yield agent._retry_strategy._backwards_compatible_event_to_yield + # Check if hooks want to retry the model call if after_model_call_event.retry: logger.debug( - "exception=<%s>, retry_requested=, attempt=<%d> | hook requested model retry", + "exception=<%s>, retry_requested= | hook requested model retry", type(e).__name__, - attempt + 1, ) - continue # Retry the model call - if isinstance(e, ModelThrottledException): - if attempt + 1 == MAX_ATTEMPTS: - yield ForceStopEvent(reason=e) - raise e - - logger.debug( - "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> " - "| throttling exception encountered " - "| delaying before next retry", - current_delay, - MAX_ATTEMPTS, - attempt + 1, - ) - await asyncio.sleep(current_delay) - current_delay = min(current_delay * 2, MAX_DELAY) + continue # Retry the model call - yield EventLoopThrottleEvent(delay=current_delay) - else: - raise e + # No retry requested, raise the exception + yield ForceStopEvent(reason=e) + raise e try: # Add message in trace and mark the end of the stream messages trace diff --git a/tests/strands/agent/conftest.py b/tests/strands/agent/conftest.py new file mode 100644 index 000000000..d3af90dc8 --- /dev/null +++ b/tests/strands/agent/conftest.py @@ -0,0 +1,22 @@ +"""Fixtures for agent tests.""" + +import asyncio +from unittest.mock import AsyncMock + +import pytest + + +@pytest.fixture +def mock_sleep(monkeypatch): + """Mock asyncio.sleep to avoid delays in tests and track sleep calls.""" + sleep_calls = [] + + async def _mock_sleep(delay): + sleep_calls.append(delay) + + mock = AsyncMock(side_effect=_mock_sleep) + monkeypatch.setattr(asyncio, "sleep", mock) + + # Return both the mock and the sleep_calls list for verification + mock.sleep_calls = sleep_calls + return mock diff --git a/tests/strands/agent/hooks/test_agent_events.py 
b/tests/strands/agent/hooks/test_agent_events.py index 7b189a5c6..f511c7019 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -1,6 +1,6 @@ import asyncio import unittest.mock -from unittest.mock import ANY, MagicMock, call +from unittest.mock import ANY, AsyncMock, MagicMock, call, patch import pytest from pydantic import BaseModel @@ -34,9 +34,7 @@ async def streaming_tool(): @pytest.fixture def mock_sleep(): - with unittest.mock.patch.object( - strands.event_loop.event_loop.asyncio, "sleep", new_callable=unittest.mock.AsyncMock - ) as mock: + with patch.object(strands.event_loop._retry.asyncio, "sleep", new_callable=AsyncMock) as mock: yield mock @@ -359,8 +357,8 @@ async def test_stream_e2e_throttle_and_redact(alist, mock_sleep): {"arg1": 1013, "init_event_loop": True}, {"start": True}, {"start_event_loop": True}, + {"event_loop_throttled_delay": 4, **throttle_props}, {"event_loop_throttled_delay": 8, **throttle_props}, - {"event_loop_throttled_delay": 16, **throttle_props}, {"event": {"messageStart": {"role": "assistant"}}}, {"event": {"redactContent": {"redactUserContentMessage": "BLOCKED!"}}}, {"event": {"contentBlockStart": {"start": {}}}}, @@ -508,11 +506,11 @@ async def test_event_loop_cycle_text_response_throttling_early_end( {"init_event_loop": True, "arg1": 1013}, {"start": True}, {"start_event_loop": True}, + {"event_loop_throttled_delay": 4, **common_props}, {"event_loop_throttled_delay": 8, **common_props}, {"event_loop_throttled_delay": 16, **common_props}, {"event_loop_throttled_delay": 32, **common_props}, {"event_loop_throttled_delay": 64, **common_props}, - {"event_loop_throttled_delay": 128, **common_props}, {"force_stop": True, "force_stop_reason": "ThrottlingException | ConverseStream"}, ] diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index be71b5fcf..e8b7e5077 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ 
b/tests/strands/agent/test_agent_hooks.py @@ -104,7 +104,7 @@ class User(BaseModel): @pytest.fixture def mock_sleep(): - with patch.object(strands.event_loop.event_loop.asyncio, "sleep", new_callable=AsyncMock) as mock: + with patch.object(strands.event_loop._retry.asyncio, "sleep", new_callable=AsyncMock) as mock: yield mock diff --git a/tests/strands/agent/test_agent_retry.py b/tests/strands/agent/test_agent_retry.py new file mode 100644 index 000000000..1b3bc5e9c --- /dev/null +++ b/tests/strands/agent/test_agent_retry.py @@ -0,0 +1,161 @@ +"""Integration tests for Agent retry_strategy parameter.""" + +from unittest.mock import Mock + +import pytest + +from strands import Agent, ModelRetryStrategy +from strands.event_loop.event_loop import INITIAL_DELAY, MAX_ATTEMPTS, MAX_DELAY +from strands.hooks import AfterModelCallEvent +from strands.types.exceptions import ModelThrottledException +from tests.fixtures.mocked_model_provider import MockedModelProvider + +# Agent Retry Strategy Initialization Tests + + +def test_agent_with_default_retry_strategy(): + """Test that Agent uses ModelRetryStrategy by default when retry_strategy=None.""" + agent = Agent() + + # Should have a retry_strategy + assert agent._retry_strategy is not None + + # Should be ModelRetryStrategy with default parameters + assert isinstance(agent._retry_strategy, ModelRetryStrategy) + assert agent._retry_strategy._max_attempts == 6 + assert agent._retry_strategy._initial_delay == 4 + assert agent._retry_strategy._max_delay == 240 + + +def test_agent_with_custom_model_retry_strategy(): + """Test Agent initialization with custom ModelRetryStrategy parameters.""" + custom_strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + agent = Agent(retry_strategy=custom_strategy) + + assert agent._retry_strategy is custom_strategy + assert agent._retry_strategy._max_attempts == 3 + assert agent._retry_strategy._initial_delay == 2 + assert agent._retry_strategy._max_delay == 60 + + +def 
test_agent_rejects_invalid_retry_strategy_type(): + """Test that Agent raises ValueError for non-ModelRetryStrategy retry_strategy.""" + + class FakeRetryStrategy: + pass + + with pytest.raises(ValueError, match="retry_strategy must be an instance of ModelRetryStrategy"): + Agent(retry_strategy=FakeRetryStrategy()) + + +def test_agent_rejects_subclass_of_model_retry_strategy(): + """Test that Agent rejects subclasses of ModelRetryStrategy (strict type check).""" + + class CustomRetryStrategy(ModelRetryStrategy): + pass + + with pytest.raises(ValueError, match="retry_strategy must be an instance of ModelRetryStrategy"): + Agent(retry_strategy=CustomRetryStrategy()) + + +def test_agent_default_retry_strategy_uses_event_loop_constants(): + """Test that default retry strategy uses constants from event_loop module.""" + agent = Agent() + + assert agent._retry_strategy._max_attempts == MAX_ATTEMPTS + assert agent._retry_strategy._initial_delay == INITIAL_DELAY + assert agent._retry_strategy._max_delay == MAX_DELAY + + +def test_retry_strategy_registered_as_hook(): + """Test that retry_strategy is registered with the hook system.""" + custom_strategy = ModelRetryStrategy(max_attempts=3) + agent = Agent(retry_strategy=custom_strategy) + + # Verify retry strategy callback is registered + callbacks = list(agent.hooks.get_callbacks_for(AfterModelCallEvent(agent=agent, exception=None))) + + # Should have at least one callback (from retry strategy) + assert len(callbacks) > 0 + + # Verify one of the callbacks is from the retry strategy + assert any( + callback.__self__ is custom_strategy if hasattr(callback, "__self__") else False for callback in callbacks + ) + + +# Agent Retry Behavior Tests + + +@pytest.mark.asyncio +async def test_agent_retries_with_default_strategy(mock_sleep): + """Test that Agent retries on throttling with default ModelRetryStrategy.""" + # Create a model that fails twice with throttling, then succeeds + model = Mock() + model.stream.side_effect = [ + 
ModelThrottledException("ThrottlingException"), + ModelThrottledException("ThrottlingException"), + MockedModelProvider([{"role": "assistant", "content": [{"text": "Success after retries"}]}]).stream([]), + ] + + agent = Agent(model=model) + + result = agent.stream_async("test prompt") + events = [event async for event in result] + + # Should have succeeded after retries - just check we got events + assert len(events) > 0 + + # Should have slept twice (for two retries) + assert len(mock_sleep.sleep_calls) == 2 + # First retry: 4 seconds + assert mock_sleep.sleep_calls[0] == 4 + # Second retry: 8 seconds (exponential backoff) + assert mock_sleep.sleep_calls[1] == 8 + + +@pytest.mark.asyncio +async def test_agent_respects_max_attempts(mock_sleep): + """Test that Agent respects max_attempts in retry strategy.""" + # Create a model that always fails + model = Mock() + model.stream.side_effect = ModelThrottledException("ThrottlingException") + + # Use custom strategy with max 2 attempts + custom_strategy = ModelRetryStrategy(max_attempts=2, initial_delay=1, max_delay=60) + agent = Agent(model=model, retry_strategy=custom_strategy) + + with pytest.raises(ModelThrottledException): + result = agent.stream_async("test prompt") + _ = [event async for event in result] + + # Should have attempted max_attempts times, which means (max_attempts - 1) sleeps + # Attempt 0: fail, sleep + # Attempt 1: fail, no more attempts + assert len(mock_sleep.sleep_calls) == 1 + + +# Backwards Compatibility Tests + + +@pytest.mark.asyncio +async def test_event_loop_throttle_event_emitted(mock_sleep): + """Test that EventLoopThrottleEvent is still emitted for backwards compatibility.""" + # Create a model that fails once with throttling, then succeeds + model = Mock() + model.stream.side_effect = [ + ModelThrottledException("ThrottlingException"), + MockedModelProvider([{"role": "assistant", "content": [{"text": "Success"}]}]).stream([]), + ] + + agent = Agent(model=model) + + result = 
agent.stream_async("test prompt") + events = [event async for event in result] + + # Should have EventLoopThrottleEvent in the stream + throttle_events = [e for e in events if "event_loop_throttled_delay" in e] + assert len(throttle_events) > 0 + + # Should have the correct delay value + assert throttle_events[0]["event_loop_throttled_delay"] > 0 diff --git a/tests/strands/agent/test_retry.py b/tests/strands/agent/test_retry.py new file mode 100644 index 000000000..830c1b5b8 --- /dev/null +++ b/tests/strands/agent/test_retry.py @@ -0,0 +1,328 @@ +"""Unit tests for retry strategy implementations.""" + +from unittest.mock import Mock + +import pytest + +from strands import ModelRetryStrategy +from strands.hooks import AfterInvocationEvent, AfterModelCallEvent, HookRegistry +from strands.types._events import EventLoopThrottleEvent +from strands.types.exceptions import ModelThrottledException + +# ModelRetryStrategy Tests + + +def test_model_retry_strategy_init_with_defaults(): + """Test ModelRetryStrategy initialization with default parameters.""" + strategy = ModelRetryStrategy() + assert strategy._max_attempts == 6 + assert strategy._initial_delay == 4 + assert strategy._max_delay == 240 + assert strategy._current_attempt == 0 + + +def test_model_retry_strategy_init_with_custom_parameters(): + """Test ModelRetryStrategy initialization with custom parameters.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + assert strategy._max_attempts == 3 + assert strategy._initial_delay == 2 + assert strategy._max_delay == 60 + assert strategy._current_attempt == 0 + + +def test_model_retry_strategy_calculate_delay_with_different_attempts(): + """Test _calculate_delay returns correct exponential backoff for different attempt numbers.""" + strategy = ModelRetryStrategy(initial_delay=2, max_delay=32) + + # Test exponential backoff: 2 * (2^attempt) + assert strategy._calculate_delay(0) == 2 # 2 * 2^0 = 2 + assert strategy._calculate_delay(1) == 4 # 
2 * 2^1 = 4 + assert strategy._calculate_delay(2) == 8 # 2 * 2^2 = 8 + assert strategy._calculate_delay(3) == 16 # 2 * 2^3 = 16 + assert strategy._calculate_delay(4) == 32 # 2 * 2^4 = 32 (at max) + assert strategy._calculate_delay(5) == 32 # 2 * 2^5 = 64, capped at 32 + assert strategy._calculate_delay(10) == 32 # Large attempt, still capped + + +def test_model_retry_strategy_calculate_delay_respects_max_delay(): + """Test _calculate_delay respects max_delay cap.""" + strategy = ModelRetryStrategy(initial_delay=10, max_delay=50) + + assert strategy._calculate_delay(0) == 10 # 10 * 2^0 = 10 + assert strategy._calculate_delay(1) == 20 # 10 * 2^1 = 20 + assert strategy._calculate_delay(2) == 40 # 10 * 2^2 = 40 + assert strategy._calculate_delay(3) == 50 # 10 * 2^3 = 80, capped at 50 + assert strategy._calculate_delay(4) == 50 # 10 * 2^4 = 160, capped at 50 + + +def test_model_retry_strategy_register_hooks(): + """Test that ModelRetryStrategy registers AfterModelCallEvent and AfterInvocationEvent callbacks.""" + strategy = ModelRetryStrategy() + registry = HookRegistry() + + strategy.register_hooks(registry) + + # Verify AfterModelCallEvent callback was registered + assert AfterModelCallEvent in registry._registered_callbacks + assert len(registry._registered_callbacks[AfterModelCallEvent]) == 1 + + # Verify AfterInvocationEvent callback was registered + assert AfterInvocationEvent in registry._registered_callbacks + assert len(registry._registered_callbacks[AfterInvocationEvent]) == 1 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_retry_on_throttle_exception_first_attempt(mock_sleep): + """Test retry behavior on first ModelThrottledException.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + + await strategy._handle_after_model_call(event) + + # Should set retry to True + assert event.retry 
is True + # Should sleep for initial_delay (attempt 0: 2 * 2^0 = 2) + assert mock_sleep.sleep_calls == [2] + assert mock_sleep.sleep_calls[0] == strategy._calculate_delay(0) + # Should increment attempt + assert strategy._current_attempt == 1 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_exponential_backoff(mock_sleep): + """Test exponential backoff calculation.""" + strategy = ModelRetryStrategy(max_attempts=5, initial_delay=2, max_delay=16) + mock_agent = Mock() + + # Simulate multiple retries + for _ in range(4): + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event) + assert event.retry is True + + # Verify exponential backoff with max_delay cap + # attempt 0: 2*2^0=2, attempt 1: 2*2^1=4, attempt 2: 2*2^2=8, attempt 3: 2*2^3=16 (capped) + assert mock_sleep.sleep_calls == [2, 4, 8, 16] + for i, sleep_delay in enumerate(mock_sleep.sleep_calls): + assert sleep_delay == strategy._calculate_delay(i) + + +@pytest.mark.asyncio +async def test_model_retry_strategy_no_retry_after_max_attempts(mock_sleep): + """Test that retry is not set after reaching max_attempts.""" + strategy = ModelRetryStrategy(max_attempts=2, initial_delay=2, max_delay=60) + mock_agent = Mock() + + # First attempt + event1 = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event1) + assert event1.retry is True + assert strategy._current_attempt == 1 + + # Second attempt (at max_attempts) + event2 = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event2) + # Should NOT retry after reaching max_attempts + assert event2.retry is False + assert strategy._current_attempt == 2 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_no_retry_on_non_throttle_exception(): + """Test that retry is not set for 
non-throttling exceptions.""" + strategy = ModelRetryStrategy() + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ValueError("Some other error"), + ) + + await strategy._handle_after_model_call(event) + + # Should not retry on non-throttling exceptions + assert event.retry is False + assert strategy._current_attempt == 0 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_no_retry_on_success(): + """Test that retry is not set when model call succeeds.""" + strategy = ModelRetryStrategy() + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + stop_response=AfterModelCallEvent.ModelStopResponse( + message={"role": "assistant", "content": [{"text": "Success"}]}, + stop_reason="end_turn", + ), + ) + + await strategy._handle_after_model_call(event) + + # Should not retry on success + assert event.retry is False + + +@pytest.mark.asyncio +async def test_model_retry_strategy_reset_on_success(mock_sleep): + """Test that strategy resets attempt counter on successful call.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + # First failure + event1 = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + await strategy._handle_after_model_call(event1) + assert event1.retry is True + assert strategy._current_attempt == 1 + # Should sleep for initial_delay (attempt 0: 2 * 2^0 = 2) + assert mock_sleep.sleep_calls == [2] + assert mock_sleep.sleep_calls[0] == strategy._calculate_delay(0) + + # Success - should reset + event2 = AfterModelCallEvent( + agent=mock_agent, + stop_response=AfterModelCallEvent.ModelStopResponse( + message={"role": "assistant", "content": [{"text": "Success"}]}, + stop_reason="end_turn", + ), + ) + await strategy._handle_after_model_call(event2) + assert event2.retry is False + # Should reset to initial state + assert strategy._current_attempt == 0 + assert strategy._calculate_delay(0) == 2 
+ + +@pytest.mark.asyncio +async def test_model_retry_strategy_skips_if_already_retrying(): + """Test that strategy skips processing if event.retry is already True.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + # Simulate another hook already set retry to True + event.retry = True + + await strategy._handle_after_model_call(event) + + # Should not modify state since another hook already triggered retry + assert strategy._current_attempt == 0 + assert event.retry is True + + +@pytest.mark.asyncio +async def test_model_retry_strategy_reset_on_after_invocation(): + """Test that strategy resets state on AfterInvocationEvent.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + # Simulate some retry attempts + strategy._current_attempt = 3 + + event = AfterInvocationEvent(agent=mock_agent, result=Mock()) + await strategy._handle_after_invocation(event) + + # Should reset to initial state + assert strategy._current_attempt == 0 + + +@pytest.mark.asyncio +async def test_model_retry_strategy_backwards_compatible_event_set_on_retry(mock_sleep): + """Test that _backwards_compatible_event_to_yield is set when retrying.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + + await strategy._handle_after_model_call(event) + + # Should have set the backwards compatible event + assert strategy._backwards_compatible_event_to_yield is not None + assert isinstance(strategy._backwards_compatible_event_to_yield, EventLoopThrottleEvent) + assert strategy._backwards_compatible_event_to_yield["event_loop_throttled_delay"] == 2 + + +@pytest.mark.asyncio +async def 
test_model_retry_strategy_backwards_compatible_event_cleared_on_success(): + """Test that _backwards_compatible_event_to_yield is cleared on success.""" + strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) + mock_agent = Mock() + + # Set a previous backwards compatible event + strategy._backwards_compatible_event_to_yield = EventLoopThrottleEvent(delay=2) + + event = AfterModelCallEvent( + agent=mock_agent, + stop_response=AfterModelCallEvent.ModelStopResponse( + message={"role": "assistant", "content": [{"text": "Success"}]}, + stop_reason="end_turn", + ), + ) + + await strategy._handle_after_model_call(event) + + # Should have cleared the backwards compatible event + assert strategy._backwards_compatible_event_to_yield is None + + +@pytest.mark.asyncio +async def test_model_retry_strategy_backwards_compatible_event_not_set_on_max_attempts(mock_sleep): + """Test that _backwards_compatible_event_to_yield is not set when max attempts reached.""" + strategy = ModelRetryStrategy(max_attempts=1, initial_delay=2, max_delay=60) + mock_agent = Mock() + + event = AfterModelCallEvent( + agent=mock_agent, + exception=ModelThrottledException("Throttled"), + ) + + await strategy._handle_after_model_call(event) + + # Should not have set the backwards compatible event since max attempts reached + assert strategy._backwards_compatible_event_to_yield is None + assert event.retry is False + + +@pytest.mark.asyncio +async def test_model_retry_strategy_no_retry_when_no_exception_and_no_stop_response(): + """Test that retry is not set when there's no exception and no stop_response.""" + strategy = ModelRetryStrategy() + mock_agent = Mock() + + # Event with neither exception nor stop_response + event = AfterModelCallEvent( + agent=mock_agent, + exception=None, + stop_response=None, + ) + + await strategy._handle_after_model_call(event) + + # Should not retry and should reset state + assert event.retry is False + assert strategy._current_attempt == 0 diff --git 
a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 639e60ea0..d4afd579b 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -1,3 +1,4 @@ +import asyncio import concurrent import unittest.mock from unittest.mock import ANY, AsyncMock, MagicMock, call, patch @@ -7,6 +8,7 @@ import strands import strands.telemetry from strands import Agent +from strands.event_loop._retry import ModelRetryStrategy from strands.hooks import ( AfterModelCallEvent, BeforeModelCallEvent, @@ -31,9 +33,7 @@ @pytest.fixture def mock_sleep(): - with unittest.mock.patch.object( - strands.event_loop.event_loop.asyncio, "sleep", new_callable=unittest.mock.AsyncMock - ) as mock: + with patch.object(strands.event_loop._retry.asyncio, "sleep", new_callable=AsyncMock) as mock: yield mock @@ -116,7 +116,11 @@ def tool_stream(tool): @pytest.fixture def hook_registry(): - return HookRegistry() + registry = HookRegistry() + # Register default retry strategy + retry_strategy = ModelRetryStrategy() + retry_strategy.register_hooks(registry) + return registry @pytest.fixture @@ -147,6 +151,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock.tool_executor = tool_executor mock._interrupt_state = _InterruptState() mock.trace_attributes = {} + mock.retry_strategy = ModelRetryStrategy() return mock @@ -693,7 +698,7 @@ async def test_event_loop_tracing_with_throttling_exception( ] # Mock the time.sleep function to speed up the test - with patch("strands.event_loop.event_loop.asyncio.sleep", new_callable=unittest.mock.AsyncMock): + with patch.object(asyncio, "sleep", new_callable=unittest.mock.AsyncMock): stream = strands.event_loop.event_loop.event_loop_cycle( agent=agent, invocation_state={}, @@ -856,15 +861,21 @@ async def test_event_loop_cycle_exception_model_hooks(mock_sleep, agent, model, # 1st call - throttled assert next(events) == BeforeModelCallEvent(agent=agent) 
- assert next(events) == AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after.retry = True + assert next(events) == expected_after # 2nd call - throttled assert next(events) == BeforeModelCallEvent(agent=agent) - assert next(events) == AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after.retry = True + assert next(events) == expected_after # 3rd call - throttled assert next(events) == BeforeModelCallEvent(agent=agent) - assert next(events) == AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + expected_after.retry = True + assert next(events) == expected_after # 4th call - successful assert next(events) == BeforeModelCallEvent(agent=agent) From 7604e98bece0fe3fb0e0fcb5baa8055d69dcc422 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 21 Jan 2026 09:48:05 -0500 Subject: [PATCH 077/279] fix(swarm): accumulate execution_time across interrupt/resume cycles (#1502) Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> --- src/strands/multiagent/swarm.py | 4 ++-- tests/strands/multiagent/test_swarm.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 6c1149624..8368f5936 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -199,7 +199,7 @@ def should_continue( return False, f"Max iterations reached: {max_iterations}" # Check timeout - elapsed = time.time() - self.start_time + elapsed = self.execution_time / 1000 + time.time() - self.start_time if elapsed > execution_timeout: return False, f"Execution timed out: {execution_timeout}s" @@ -406,7 
+406,7 @@ async def stream_async( self.state.completion_status = Status.FAILED raise finally: - self.state.execution_time = round((time.time() - self.state.start_time) * 1000) + self.state.execution_time += round((time.time() - self.state.start_time) * 1000) await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self, invocation_state)) self._resume_from_session = False diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index f2abed9f7..aae11b709 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -1243,6 +1243,8 @@ def test_swarm_interrupt_on_before_node_call_event(interrupt_hook): multiagent_result = swarm("Test task") + first_execution_time = multiagent_result.execution_time + tru_status = multiagent_result.status exp_status = Status.INTERRUPTED assert tru_status == exp_status @@ -1279,6 +1281,8 @@ def test_swarm_interrupt_on_before_node_call_event(interrupt_hook): exp_message = "Task completed" assert tru_message == exp_message + assert multiagent_result.execution_time >= first_execution_time + def test_swarm_interrupt_on_agent(agenerator): exp_interrupts = [ From 2e23d755ecd438c92082103ec941200e011cadc8 Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Wed, 21 Jan 2026 10:45:48 -0500 Subject: [PATCH 078/279] Feat: graduate multiagent hook events from experimental (#1498) --- .../experimental/hooks/multiagent/__init__.py | 4 +- .../experimental/hooks/multiagent/events.py | 138 +++--------------- src/strands/hooks/__init__.py | 11 ++ src/strands/hooks/events.py | 106 +++++++++++++- src/strands/multiagent/graph.py | 4 +- src/strands/multiagent/swarm.py | 4 +- src/strands/session/session_manager.py | 6 +- .../fixtures/mock_multiagent_hook_provider.py | 6 +- .../experimental/hooks/multiagent/__init__.py | 0 .../hooks/multiagent => hooks}/test_events.py | 4 +- .../test_multi_agent_hooks.py | 2 +- 
tests/strands/multiagent/conftest.py | 3 +- tests/strands/multiagent/test_graph.py | 3 +- tests/strands/multiagent/test_swarm.py | 2 +- tests_integ/hooks/multiagent/test_cancel.py | 3 +- tests_integ/hooks/multiagent/test_events.py | 4 +- .../interrupts/multiagent/test_hook.py | 3 +- .../interrupts/multiagent/test_session.py | 3 +- tests_integ/test_multiagent_swarm.py | 2 +- 19 files changed, 164 insertions(+), 144 deletions(-) delete mode 100644 tests/strands/experimental/hooks/multiagent/__init__.py rename tests/strands/{experimental/hooks/multiagent => hooks}/test_events.py (97%) rename tests/strands/{experimental/hooks/multiagent => hooks}/test_multi_agent_hooks.py (98%) diff --git a/src/strands/experimental/hooks/multiagent/__init__.py b/src/strands/experimental/hooks/multiagent/__init__.py index d059d0da5..6755db7e4 100644 --- a/src/strands/experimental/hooks/multiagent/__init__.py +++ b/src/strands/experimental/hooks/multiagent/__init__.py @@ -1,6 +1,6 @@ -"""Multi-agent hook events and utilities. +"""Multi-agent hook events. -Provides event classes for hooking into multi-agent orchestrator lifecycle. +Deprecated: Use strands.hooks instead. """ from .events import ( diff --git a/src/strands/experimental/hooks/multiagent/events.py b/src/strands/experimental/hooks/multiagent/events.py index fa881bf32..2c65c53e3 100644 --- a/src/strands/experimental/hooks/multiagent/events.py +++ b/src/strands/experimental/hooks/multiagent/events.py @@ -1,118 +1,28 @@ """Multi-agent execution lifecycle events for hook system integration. -These events are fired by orchestrators (Graph/Swarm) at key points so -hooks can persist, monitor, or debug execution. No intermediate state model -is used—hooks read from the orchestrator directly. +Deprecated: Use strands.hooks instead. 
""" -import uuid -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any - -from typing_extensions import override - -from ....hooks import BaseHookEvent -from ....types.interrupt import _Interruptible - -if TYPE_CHECKING: - from ....multiagent.base import MultiAgentBase - - -@dataclass -class MultiAgentInitializedEvent(BaseHookEvent): - """Event triggered when multi-agent orchestrator initialized. - - Attributes: - source: The multi-agent orchestrator instance - invocation_state: Configuration that user passes in - """ - - source: "MultiAgentBase" - invocation_state: dict[str, Any] | None = None - - -@dataclass -class BeforeNodeCallEvent(BaseHookEvent, _Interruptible): - """Event triggered before individual node execution starts. - - Attributes: - source: The multi-agent orchestrator instance - node_id: ID of the node about to execute - invocation_state: Configuration that user passes in - cancel_node: A user defined message that when set, will cancel the node execution with status FAILED. - The message will be emitted under a MultiAgentNodeCancel event. If set to `True`, Strands will cancel the - node using a default cancel message. - """ - - source: "MultiAgentBase" - node_id: str - invocation_state: dict[str, Any] | None = None - cancel_node: bool | str = False - - def _can_write(self, name: str) -> bool: - return name in ["cancel_node"] - - @override - def _interrupt_id(self, name: str) -> str: - """Unique id for the interrupt. - - Args: - name: User defined name for the interrupt. - - Returns: - Interrupt id. - """ - node_id = uuid.uuid5(uuid.NAMESPACE_OID, self.node_id) - call_id = uuid.uuid5(uuid.NAMESPACE_OID, name) - return f"v1:before_node_call:{node_id}:{call_id}" - - -@dataclass -class AfterNodeCallEvent(BaseHookEvent): - """Event triggered after individual node execution completes. 
- - Attributes: - source: The multi-agent orchestrator instance - node_id: ID of the node that just completed execution - invocation_state: Configuration that user passes in - """ - - source: "MultiAgentBase" - node_id: str - invocation_state: dict[str, Any] | None = None - - @property - def should_reverse_callbacks(self) -> bool: - """True to invoke callbacks in reverse order.""" - return True - - -@dataclass -class BeforeMultiAgentInvocationEvent(BaseHookEvent): - """Event triggered before orchestrator execution starts. - - Attributes: - source: The multi-agent orchestrator instance - invocation_state: Configuration that user passes in - """ - - source: "MultiAgentBase" - invocation_state: dict[str, Any] | None = None - - -@dataclass -class AfterMultiAgentInvocationEvent(BaseHookEvent): - """Event triggered after orchestrator execution completes. - - Attributes: - source: The multi-agent orchestrator instance - invocation_state: Configuration that user passes in - """ - - source: "MultiAgentBase" - invocation_state: dict[str, Any] | None = None - - @property - def should_reverse_callbacks(self) -> bool: - """True to invoke callbacks in reverse order.""" - return True +import warnings + +from ....hooks import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) + +warnings.warn( + "strands.experimental.hooks.multiagent is deprecated. 
Use strands.hooks instead.", + DeprecationWarning, + stacklevel=2, +) + +__all__ = [ + "AfterMultiAgentInvocationEvent", + "AfterNodeCallEvent", + "BeforeMultiAgentInvocationEvent", + "BeforeNodeCallEvent", + "MultiAgentInitializedEvent", +] diff --git a/src/strands/hooks/__init__.py b/src/strands/hooks/__init__.py index 30163f207..96c7f577b 100644 --- a/src/strands/hooks/__init__.py +++ b/src/strands/hooks/__init__.py @@ -32,12 +32,18 @@ def log_end(self, event: AfterInvocationEvent) -> None: from .events import ( AfterInvocationEvent, AfterModelCallEvent, + # Multiagent hook events + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, AfterToolCallEvent, AgentInitializedEvent, BeforeInvocationEvent, BeforeModelCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, BeforeToolCallEvent, MessageAddedEvent, + MultiAgentInitializedEvent, ) from .registry import BaseHookEvent, HookCallback, HookEvent, HookProvider, HookRegistry @@ -56,4 +62,9 @@ def log_end(self, event: AfterInvocationEvent) -> None: "HookRegistry", "HookEvent", "BaseHookEvent", + "AfterMultiAgentInvocationEvent", + "AfterNodeCallEvent", + "BeforeMultiAgentInvocationEvent", + "BeforeNodeCallEvent", + "MultiAgentInitializedEvent", ] diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 8aa8a68d6..1faa8a917 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -16,7 +16,10 @@ from ..types.interrupt import _Interruptible from ..types.streaming import StopReason from ..types.tools import AgentTool, ToolResult, ToolUse -from .registry import HookEvent +from .registry import BaseHookEvent, HookEvent + +if TYPE_CHECKING: + from ..multiagent.base import MultiAgentBase @dataclass @@ -250,3 +253,104 @@ def _can_write(self, name: str) -> bool: def should_reverse_callbacks(self) -> bool: """True to invoke callbacks in reverse order.""" return True + + +# Multiagent hook events start here +@dataclass +class MultiAgentInitializedEvent(BaseHookEvent): + 
"""Event triggered when multi-agent orchestrator initialized. + + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + +@dataclass +class BeforeNodeCallEvent(BaseHookEvent, _Interruptible): + """Event triggered before individual node execution starts. + + Attributes: + source: The multi-agent orchestrator instance + node_id: ID of the node about to execute + invocation_state: Configuration that user passes in + cancel_node: A user defined message that when set, will cancel the node execution with status FAILED. + The message will be emitted under a MultiAgentNodeCancel event. If set to `True`, Strands will cancel the + node using a default cancel message. + """ + + source: "MultiAgentBase" + node_id: str + invocation_state: dict[str, Any] | None = None + cancel_node: bool | str = False + + def _can_write(self, name: str) -> bool: + return name in ["cancel_node"] + + @override + def _interrupt_id(self, name: str) -> str: + """Unique id for the interrupt. + + Args: + name: User defined name for the interrupt. + + Returns: + Interrupt id. + """ + node_id = uuid.uuid5(uuid.NAMESPACE_OID, self.node_id) + call_id = uuid.uuid5(uuid.NAMESPACE_OID, name) + return f"v1:before_node_call:{node_id}:{call_id}" + + +@dataclass +class AfterNodeCallEvent(BaseHookEvent): + """Event triggered after individual node execution completes. 
+ + Attributes: + source: The multi-agent orchestrator instance + node_id: ID of the node that just completed execution + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + node_id: str + invocation_state: dict[str, Any] | None = None + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True + + +@dataclass +class BeforeMultiAgentInvocationEvent(BaseHookEvent): + """Event triggered before orchestrator execution starts. + + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + +@dataclass +class AfterMultiAgentInvocationEvent(BaseHookEvent): + """Event triggered after orchestrator execution completes. + + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 97435ad4a..32eca00ff 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -27,14 +27,14 @@ from .._async import run_async from ..agent import Agent from ..agent.state import AgentState -from ..experimental.hooks.multiagent import ( +from ..hooks.events import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeMultiAgentInvocationEvent, BeforeNodeCallEvent, MultiAgentInitializedEvent, ) -from ..hooks import HookProvider, HookRegistry +from ..hooks.registry import HookProvider, HookRegistry from ..interrupt import Interrupt, _InterruptState from ..session import SessionManager from ..telemetry import get_tracer diff --git a/src/strands/multiagent/swarm.py 
b/src/strands/multiagent/swarm.py index 8368f5936..9a4ce5494 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -27,14 +27,14 @@ from .._async import run_async from ..agent import Agent from ..agent.state import AgentState -from ..experimental.hooks.multiagent import ( +from ..hooks.events import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeMultiAgentInvocationEvent, BeforeNodeCallEvent, MultiAgentInitializedEvent, ) -from ..hooks import HookProvider, HookRegistry +from ..hooks.registry import HookProvider, HookRegistry from ..interrupt import Interrupt, _InterruptState from ..session import SessionManager from ..telemetry import get_tracer diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index ba4356089..cc954e17d 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -9,12 +9,14 @@ BidiAgentInitializedEvent, BidiMessageAddedEvent, ) -from ..experimental.hooks.multiagent.events import ( +from ..hooks.events import ( + AfterInvocationEvent, AfterMultiAgentInvocationEvent, AfterNodeCallEvent, + AgentInitializedEvent, + MessageAddedEvent, MultiAgentInitializedEvent, ) -from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent from ..hooks.registry import HookProvider, HookRegistry from ..types.content import Message diff --git a/tests/fixtures/mock_multiagent_hook_provider.py b/tests/fixtures/mock_multiagent_hook_provider.py index 4d18297a2..a89d5aca8 100644 --- a/tests/fixtures/mock_multiagent_hook_provider.py +++ b/tests/fixtures/mock_multiagent_hook_provider.py @@ -1,16 +1,14 @@ from collections.abc import Iterator from typing import Literal -from strands.experimental.hooks.multiagent.events import ( +from strands.hooks import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeNodeCallEvent, - MultiAgentInitializedEvent, -) -from strands.hooks import ( HookEvent, HookProvider, 
HookRegistry, + MultiAgentInitializedEvent, ) diff --git a/tests/strands/experimental/hooks/multiagent/__init__.py b/tests/strands/experimental/hooks/multiagent/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/strands/experimental/hooks/multiagent/test_events.py b/tests/strands/hooks/test_events.py similarity index 97% rename from tests/strands/experimental/hooks/multiagent/test_events.py rename to tests/strands/hooks/test_events.py index 6c4d7c4e7..90ab205a9 100644 --- a/tests/strands/experimental/hooks/multiagent/test_events.py +++ b/tests/strands/hooks/test_events.py @@ -4,14 +4,14 @@ import pytest -from strands.experimental.hooks.multiagent.events import ( +from strands.hooks import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, + BaseHookEvent, BeforeMultiAgentInvocationEvent, BeforeNodeCallEvent, MultiAgentInitializedEvent, ) -from strands.hooks import BaseHookEvent @pytest.fixture diff --git a/tests/strands/experimental/hooks/multiagent/test_multi_agent_hooks.py b/tests/strands/hooks/test_multi_agent_hooks.py similarity index 98% rename from tests/strands/experimental/hooks/multiagent/test_multi_agent_hooks.py rename to tests/strands/hooks/test_multi_agent_hooks.py index 4e97a9217..3f6e0c940 100644 --- a/tests/strands/experimental/hooks/multiagent/test_multi_agent_hooks.py +++ b/tests/strands/hooks/test_multi_agent_hooks.py @@ -1,7 +1,7 @@ import pytest from strands import Agent -from strands.experimental.hooks.multiagent.events import ( +from strands.hooks import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeMultiAgentInvocationEvent, diff --git a/tests/strands/multiagent/conftest.py b/tests/strands/multiagent/conftest.py index 85e0ef7fc..e5dd1b4f9 100644 --- a/tests/strands/multiagent/conftest.py +++ b/tests/strands/multiagent/conftest.py @@ -1,7 +1,6 @@ import pytest -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent -from strands.hooks import HookProvider +from strands.hooks 
import BeforeNodeCallEvent, HookProvider @pytest.fixture diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index ab2d86e70..cd750865e 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -6,8 +6,7 @@ from strands.agent import Agent, AgentResult from strands.agent.state import AgentState -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent -from strands.hooks import AgentInitializedEvent +from strands.hooks import AgentInitializedEvent, BeforeNodeCallEvent from strands.hooks.registry import HookProvider, HookRegistry from strands.interrupt import Interrupt, _InterruptState from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index aae11b709..75ef97a25 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -6,7 +6,7 @@ from strands.agent import Agent, AgentResult from strands.agent.state import AgentState -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent +from strands.hooks import BeforeNodeCallEvent from strands.hooks.registry import HookRegistry from strands.interrupt import Interrupt, _InterruptState from strands.multiagent.base import Status diff --git a/tests_integ/hooks/multiagent/test_cancel.py b/tests_integ/hooks/multiagent/test_cancel.py index 9267330b7..ae3008861 100644 --- a/tests_integ/hooks/multiagent/test_cancel.py +++ b/tests_integ/hooks/multiagent/test_cancel.py @@ -1,8 +1,7 @@ import pytest from strands import Agent -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent -from strands.hooks import HookProvider +from strands.hooks import BeforeNodeCallEvent, HookProvider from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import Status from strands.types._events import MultiAgentNodeCancelEvent diff --git 
a/tests_integ/hooks/multiagent/test_events.py b/tests_integ/hooks/multiagent/test_events.py index e8039444f..3a10b74c1 100644 --- a/tests_integ/hooks/multiagent/test_events.py +++ b/tests_integ/hooks/multiagent/test_events.py @@ -1,14 +1,14 @@ import pytest from strands import Agent -from strands.experimental.hooks.multiagent import ( +from strands.hooks import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeMultiAgentInvocationEvent, BeforeNodeCallEvent, + HookProvider, MultiAgentInitializedEvent, ) -from strands.hooks import HookProvider from strands.multiagent import GraphBuilder, Swarm diff --git a/tests_integ/interrupts/multiagent/test_hook.py b/tests_integ/interrupts/multiagent/test_hook.py index 9350b3535..53305b4e8 100644 --- a/tests_integ/interrupts/multiagent/test_hook.py +++ b/tests_integ/interrupts/multiagent/test_hook.py @@ -4,8 +4,7 @@ import pytest from strands import Agent, tool -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent -from strands.hooks import HookProvider +from strands.hooks import BeforeNodeCallEvent, HookProvider from strands.interrupt import Interrupt from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import Status diff --git a/tests_integ/interrupts/multiagent/test_session.py b/tests_integ/interrupts/multiagent/test_session.py index bab4b428f..2ccff2c12 100644 --- a/tests_integ/interrupts/multiagent/test_session.py +++ b/tests_integ/interrupts/multiagent/test_session.py @@ -4,8 +4,7 @@ import pytest from strands import Agent, tool -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent -from strands.hooks import HookProvider +from strands.hooks import BeforeNodeCallEvent, HookProvider from strands.interrupt import Interrupt from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import Status diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py index e8e969af1..e9738d3d9 100644 --- 
a/tests_integ/test_multiagent_swarm.py +++ b/tests_integ/test_multiagent_swarm.py @@ -3,13 +3,13 @@ import pytest from strands import Agent, tool -from strands.experimental.hooks.multiagent import BeforeNodeCallEvent from strands.hooks import ( AfterInvocationEvent, AfterModelCallEvent, AfterToolCallEvent, BeforeInvocationEvent, BeforeModelCallEvent, + BeforeNodeCallEvent, BeforeToolCallEvent, MessageAddedEvent, ) From b41a99bedaca93b55cb57262aeb5c109f6b2a688 Mon Sep 17 00:00:00 2001 From: Lana Zhang Date: Wed, 21 Jan 2026 11:07:18 -0500 Subject: [PATCH 079/279] Nova Sonic 2 support for BidiAgent (#1476) --- README.md | 28 ++- .../experimental/bidi/models/nova_sonic.py | 142 +++++++++++++-- .../bidi/models/test_nova_sonic.py | 165 ++++++++++++++++++ tests_integ/bidi/test_bidirectional_agent.py | 9 +- 4 files changed, 328 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index e7d1b2a7e..8e4d9d0e8 100644 --- a/README.md +++ b/README.md @@ -204,9 +204,9 @@ It's also available on GitHub via [strands-agents/tools](https://github.com/stra Build real-time voice and audio conversations with persistent streaming connections. Unlike traditional request-response patterns, bidirectional streaming maintains long-running conversations where users can interrupt, provide continuous input, and receive real-time audio responses. Get started with your first BidiAgent by following the [Quickstart](https://strandsagents.com/latest/documentation/docs/user-guide/concepts/experimental/bidirectional-streaming/quickstart) guide. 
**Supported Model Providers:** -- Amazon Nova Sonic (`amazon.nova-sonic-v1:0`) -- Google Gemini Live (`gemini-2.5-flash-native-audio-preview-09-2025`) -- OpenAI Realtime API (`gpt-realtime`) +- Amazon Nova Sonic (v1, v2) +- Google Gemini Live +- OpenAI Realtime API **Quick Example:** @@ -219,7 +219,7 @@ from strands.experimental.bidi.tools import stop_conversation from strands_tools import calculator async def main(): - # Create bidirectional agent with audio model + # Create bidirectional agent with Nova Sonic v2 model = BidiNovaSonicModel() agent = BidiAgent(model=model, tools=[calculator, stop_conversation]) @@ -241,7 +241,9 @@ if __name__ == "__main__": **Configuration Options:** ```python -# Configure audio settings +from strands.experimental.bidi.models import BidiNovaSonicModel + +# Configure audio settings and turn detection (v2 only) model = BidiNovaSonicModel( provider_config={ "audio": { @@ -249,6 +251,9 @@ model = BidiNovaSonicModel( "output_rate": 16000, "voice": "matthew" }, + "turn_detection": { + "endpointingSensitivity": "MEDIUM" # HIGH, MEDIUM, or LOW + }, "inference": { "max_tokens": 2048, "temperature": 0.7 @@ -263,6 +268,19 @@ audio_io = BidiAudioIO( input_buffer_size=10, output_buffer_size=10 ) + +# Text input mode (type messages instead of speaking) +text_io = BidiTextIO() +await agent.run( + inputs=[text_io.input()], # Use text input + outputs=[audio_io.output(), text_io.output()] +) + +# Multi-modal: Both audio and text input +await agent.run( + inputs=[audio_io.input(), text_io.input()], # Speak OR type + outputs=[audio_io.output(), text_io.output()] +) ``` ## Documentation diff --git a/src/strands/experimental/bidi/models/nova_sonic.py b/src/strands/experimental/bidi/models/nova_sonic.py index 1c946220d..d836bde49 100644 --- a/src/strands/experimental/bidi/models/nova_sonic.py +++ b/src/strands/experimental/bidi/models/nova_sonic.py @@ -64,6 +64,10 @@ logger = logging.getLogger(__name__) +# Nova Sonic model identifiers 
+NOVA_SONIC_V1_MODEL_ID = "amazon.nova-sonic-v1:0" +NOVA_SONIC_V2_MODEL_ID = "amazon.nova-2-sonic-v1:0" + _NOVA_INFERENCE_CONFIG_KEYS = { "max_tokens": "maxTokens", "temperature": "temperature", @@ -110,7 +114,7 @@ class BidiNovaSonicModel(BidiModel): def __init__( self, - model_id: str = "amazon.nova-sonic-v1:0", + model_id: str = NOVA_SONIC_V2_MODEL_ID, provider_config: dict[str, Any] | None = None, client_config: dict[str, Any] | None = None, **kwargs: Any, @@ -118,19 +122,41 @@ def __init__( """Initialize Nova Sonic bidirectional model. Args: - model_id: Model identifier (default: amazon.nova-sonic-v1:0) - provider_config: Model behavior (audio, inference settings) + model_id: Model identifier (default: amazon.nova-2-sonic-v1:0) + provider_config: Model behavior configuration including: + - audio: Audio input/output settings (sample rate, voice, etc.) + - inference: Model inference settings (max_tokens, temperature, top_p) + - turn_detection: Turn detection configuration (v2 only feature) + - endpointingSensitivity: "HIGH" | "MEDIUM" | "LOW" (optional) client_config: AWS authentication (boto_session OR region, not both) **kwargs: Reserved for future parameters. + + Raises: + ValueError: If turn_detection is used with v1 model. + ValueError: If endpointingSensitivity is not HIGH, MEDIUM, or LOW. """ # Store model ID self.model_id = model_id + # Validate turn_detection configuration + provider_config = provider_config or {} + if "turn_detection" in provider_config and provider_config["turn_detection"]: + if model_id == NOVA_SONIC_V1_MODEL_ID: + raise ValueError( + f"turn_detection is only supported in Nova Sonic v2. " + f"Current model_id: {model_id}. Use {NOVA_SONIC_V2_MODEL_ID} instead." + ) + + # Validate endpointingSensitivity value if provided + sensitivity = provider_config["turn_detection"].get("endpointingSensitivity") + if sensitivity and sensitivity not in ["HIGH", "MEDIUM", "LOW"]: + raise ValueError(f"Invalid endpointingSensitivity: {sensitivity}. 
Must be HIGH, MEDIUM, or LOW") + # Resolve client config with defaults self._client_config = self._resolve_client_config(client_config or {}) # Resolve provider config with defaults - self.config = self._resolve_provider_config(provider_config or {}) + self.config = self._resolve_provider_config(provider_config) # Store session and region for later use self._session = self._client_config["boto_session"] @@ -182,6 +208,7 @@ def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: **config.get("audio", {}), }, "inference": config.get("inference", {}), + "turn_detection": config.get("turn_detection", {}), } return resolved @@ -269,21 +296,57 @@ def _build_initialization_events( def _log_event_type(self, nova_event: dict[str, Any]) -> None: """Log specific Nova Sonic event types for debugging.""" + # Log the full event structure for detailed debugging + event_keys = list(nova_event.keys()) + logger.debug("event_keys=<%s> | nova sonic event received", event_keys) + if "usageEvent" in nova_event: - logger.debug("usage=<%s> | nova usage event received", nova_event["usageEvent"]) + usage = nova_event["usageEvent"] + logger.debug( + "input_tokens=<%s>, output_tokens=<%s>, usage_details=<%s> | nova usage event", + usage.get("totalInputTokens", 0), + usage.get("totalOutputTokens", 0), + json.dumps(usage, indent=2), + ) elif "textOutput" in nova_event: - logger.debug("nova text output received") + text_content = nova_event["textOutput"].get("content", "") + logger.debug( + "text_length=<%d>, text_preview=<%s>, text_output_details=<%s> | nova text output", + len(text_content), + text_content[:100], + json.dumps(nova_event["textOutput"], indent=2)[:500], + ) elif "toolUse" in nova_event: tool_use = nova_event["toolUse"] logger.debug( - "tool_name=<%s>, tool_use_id=<%s> | nova tool use received", + "tool_name=<%s>, tool_use_id=<%s>, tool_use_details=<%s> | nova tool use received", tool_use["toolName"], tool_use["toolUseId"], + json.dumps(tool_use, 
indent=2)[:500], ) elif "audioOutput" in nova_event: audio_content = nova_event["audioOutput"]["content"] audio_bytes = base64.b64decode(audio_content) logger.debug("audio_bytes=<%d> | nova audio output received", len(audio_bytes)) + elif "completionStart" in nova_event: + completion_id = nova_event["completionStart"].get("completionId", "unknown") + logger.debug("completion_id=<%s> | nova completion started", completion_id) + elif "completionEnd" in nova_event: + completion_data = nova_event["completionEnd"] + logger.debug( + "completion_id=<%s>, stop_reason=<%s> | nova completion ended", + completion_data.get("completionId", "unknown"), + completion_data.get("stopReason", "unknown"), + ) + elif "stopReason" in nova_event: + logger.debug("stop_reason=<%s> | nova stop reason event", nova_event["stopReason"]) + else: + # Log any other event types + audio_metadata = self._get_audio_metadata_for_logging({"event": nova_event}) + if audio_metadata: + logger.debug("audio_byte_count=<%d> | nova sonic event with audio", audio_metadata["audio_byte_count"]) + else: + logger.debug("event_payload=<%s> | nova sonic event details", json.dumps(nova_event, indent=2)[:500]) async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: """Receive Nova Sonic events and convert to provider-agnostic format. 
@@ -312,14 +375,25 @@ async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: raise BidiModelTimeoutError(error.message) from error if not event_data: + logger.debug("received empty event data, continuing") continue - nova_event = json.loads(event_data.value.bytes_.decode("utf-8"))["event"] + # Decode and parse the event + raw_bytes = event_data.value.bytes_.decode("utf-8") + logger.debug("raw_event_size=<%d> | received nova sonic event", len(raw_bytes)) + + nova_event = json.loads(raw_bytes)["event"] self._log_event_type(nova_event) model_event = self._convert_nova_event(nova_event) if model_event: + event_type = ( + model_event.get("type", "unknown") if isinstance(model_event, dict) else type(model_event).__name__ + ) + logger.debug("converted_event_type=<%s> | yielding converted event", event_type) yield model_event + else: + logger.debug("event_not_converted | nova event did not produce output event") async def send(self, content: BidiInputEvent | ToolResultEvent) -> None: """Unified send method for all content types. Sends the given content to Nova Sonic. 
@@ -336,14 +410,24 @@ async def send(self, content: BidiInputEvent | ToolResultEvent) -> None: raise RuntimeError("model not started | call start before sending") if isinstance(content, BidiTextInputEvent): + text_preview = content.text[:100] if len(content.text) > 100 else content.text + logger.debug("text_length=<%d>, text_preview=<%s> | sending text content", len(content.text), text_preview) await self._send_text_content(content.text) elif isinstance(content, BidiAudioInputEvent): + audio_size = len(base64.b64decode(content.audio)) if content.audio else 0 + logger.debug("audio_bytes=<%d>, format=<%s> | sending audio content", audio_size, content.format) await self._send_audio_content(content) elif isinstance(content, ToolResultEvent): tool_result = content.get("tool_result") if tool_result: + logger.debug( + "tool_use_id=<%s>, content_blocks=<%d> | sending tool result", + tool_result.get("toolUseId", "unknown"), + len(tool_result.get("content", [])), + ) await self._send_tool_result(tool_result) else: + logger.error("content_type=<%s> | unsupported content type", type(content)) raise ValueError(f"content_type={type(content)} | content not supported") async def _start_audio_connection(self) -> None: @@ -583,7 +667,15 @@ def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | N def _get_connection_start_event(self) -> str: """Generate Nova Sonic connection start event.""" inference_config = {_NOVA_INFERENCE_CONFIG_KEYS[key]: value for key, value in self.config["inference"].items()} - return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": inference_config}}}) + + session_start_event: dict[str, Any] = {"event": {"sessionStart": {"inferenceConfiguration": inference_config}}} + + # Add turn detection configuration if provided (v2 feature) + turn_detection_config = self.config.get("turn_detection", {}) + if turn_detection_config: + session_start_event["event"]["sessionStart"]["turnDetectionConfiguration"] = turn_detection_config 
+ + return json.dumps(session_start_event) def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: """Generate Nova Sonic prompt start event with tool configuration.""" @@ -749,6 +841,37 @@ def _get_connection_end_event(self) -> str: """Generate connection end event.""" return json.dumps({"event": {"connectionEnd": {}}}) + def _get_audio_metadata_for_logging(self, event_dict: dict[str, Any]) -> dict[str, Any]: + """Extract audio metadata from event dict for logging. + + Instead of logging large base64-encoded audio data, this extracts metadata + like byte count to verify audio presence without bloating logs. + + Args: + event_dict: The event dictionary to process. + + Returns: + A dict with audio metadata (byte_count) if audio is present, empty dict otherwise. + """ + metadata: dict[str, Any] = {} + + if "event" in event_dict: + event_data = event_dict["event"] + + # Handle contentStart events with audio + if "contentStart" in event_data and "content" in event_data["contentStart"]: + content = event_data["contentStart"]["content"] + if "audio" in content and "bytes" in content["audio"]: + metadata["audio_byte_count"] = len(content["audio"]["bytes"]) + + # Handle content events with audio + if "content" in event_data and "content" in event_data["content"]: + content = event_data["content"]["content"] + if "audio" in content and "bytes" in content["audio"]: + metadata["audio_byte_count"] = len(content["audio"]["bytes"]) + + return metadata + async def _send_nova_events(self, events: list[str]) -> None: """Send event JSON string to Nova Sonic stream. 
@@ -764,4 +887,3 @@ async def _send_nova_events(self, events: list[str]) -> None: value=BidirectionalInputPayloadPart(bytes_=bytes_data) ) await self._stream.input_stream.send(chunk) - logger.debug("nova sonic event sent successfully") diff --git a/tests/strands/experimental/bidi/models/test_nova_sonic.py b/tests/strands/experimental/bidi/models/test_nova_sonic.py index 7435d4ad2..14630875b 100644 --- a/tests/strands/experimental/bidi/models/test_nova_sonic.py +++ b/tests/strands/experimental/bidi/models/test_nova_sonic.py @@ -23,6 +23,8 @@ from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.models.nova_sonic import ( BidiNovaSonicModel, + NOVA_SONIC_V1_MODEL_ID, + NOVA_SONIC_V2_MODEL_ID, ) from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, @@ -579,6 +581,169 @@ async def test_default_audio_rates_in_events(model_id, boto_session): assert result.format == "pcm" +# Nova Sonic v2 Support Tests + + +def test_nova_sonic_model_constants(): + """Test that Nova Sonic model ID constants are correctly defined.""" + assert NOVA_SONIC_V1_MODEL_ID == "amazon.nova-sonic-v1:0" + assert NOVA_SONIC_V2_MODEL_ID == "amazon.nova-2-sonic-v1:0" + + +@pytest.mark.asyncio +async def test_nova_sonic_v1_instantiation(boto_session, mock_client): + """Test direct instantiation with Nova Sonic v1 model ID.""" + _ = mock_client # Ensure mock is active + + # Test default creation + model = BidiNovaSonicModel(model_id=NOVA_SONIC_V1_MODEL_ID, client_config={"boto_session": boto_session}) + assert model.model_id == NOVA_SONIC_V1_MODEL_ID + assert model.region == "us-east-1" + + # Test with custom config + provider_config = {"audio": {"voice": "joanna", "output_rate": 24000}} + client_config = {"boto_session": boto_session} + model_custom = BidiNovaSonicModel( + model_id=NOVA_SONIC_V1_MODEL_ID, provider_config=provider_config, client_config=client_config + ) + + assert model_custom.model_id == NOVA_SONIC_V1_MODEL_ID + assert 
model_custom.config["audio"]["voice"] == "joanna" + assert model_custom.config["audio"]["output_rate"] == 24000 + + +@pytest.mark.asyncio +async def test_nova_sonic_v2_instantiation(boto_session, mock_client): + """Test direct instantiation with Nova Sonic v2 model ID.""" + _ = mock_client # Ensure mock is active + + # Test default creation + model = BidiNovaSonicModel(model_id=NOVA_SONIC_V2_MODEL_ID, client_config={"boto_session": boto_session}) + assert model.model_id == NOVA_SONIC_V2_MODEL_ID + assert model.region == "us-east-1" + + # Test with custom config + provider_config = {"audio": {"voice": "ruth", "input_rate": 48000}, "inference": {"temperature": 0.8}} + client_config = {"boto_session": boto_session} + model_custom = BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, provider_config=provider_config, client_config=client_config + ) + + assert model_custom.model_id == NOVA_SONIC_V2_MODEL_ID + assert model_custom.config["audio"]["voice"] == "ruth" + assert model_custom.config["audio"]["input_rate"] == 48000 + assert model_custom.config["inference"]["temperature"] == 0.8 + + +@pytest.mark.asyncio +async def test_nova_sonic_v1_v2_compatibility(boto_session, mock_client): + """Test that v1 and v2 models have the same config structure and behavior.""" + _ = mock_client # Ensure mock is active + + # Create both models with same config + provider_config = {"audio": {"voice": "matthew"}} + client_config = {"boto_session": boto_session} + + model_v1 = BidiNovaSonicModel( + model_id=NOVA_SONIC_V1_MODEL_ID, provider_config=provider_config, client_config=client_config + ) + model_v2 = BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, provider_config=provider_config, client_config=client_config + ) + + # Both should have the same config structure + assert model_v1.config["audio"]["voice"] == model_v2.config["audio"]["voice"] + assert model_v1.region == model_v2.region + + # Only model_id should differ + assert model_v1.model_id != model_v2.model_id + assert 
model_v1.model_id == NOVA_SONIC_V1_MODEL_ID + assert model_v2.model_id == NOVA_SONIC_V2_MODEL_ID + + +@pytest.mark.asyncio +async def test_backward_compatibility(boto_session, mock_client): + """Test that existing code continues to work (backward compatibility).""" + _ = mock_client # Ensure mock is active + + # Test that default behavior now uses v2 (updated default) + model_default = BidiNovaSonicModel(client_config={"boto_session": boto_session}) + assert model_default.model_id == NOVA_SONIC_V2_MODEL_ID + + # Test that existing explicit v1 usage still works + model_explicit_v1 = BidiNovaSonicModel( + model_id=NOVA_SONIC_V1_MODEL_ID, client_config={"boto_session": boto_session} + ) + assert model_explicit_v1.model_id == NOVA_SONIC_V1_MODEL_ID + + # Test that explicit v2 usage works + model_explicit_v2 = BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, client_config={"boto_session": boto_session} + ) + assert model_explicit_v2.model_id == NOVA_SONIC_V2_MODEL_ID + + +@pytest.mark.asyncio +async def test_turn_detection_v1_validation(boto_session, mock_client): + """Test that turn_detection raises error when used with v1 model.""" + _ = mock_client # Ensure mock is active + + # Test that turn_detection with v1 raises ValueError + with pytest.raises(ValueError, match="turn_detection is only supported in Nova Sonic v2"): + BidiNovaSonicModel( + model_id=NOVA_SONIC_V1_MODEL_ID, + provider_config={"turn_detection": {"endpointingSensitivity": "MEDIUM"}}, + client_config={"boto_session": boto_session}, + ) + + # Test that turn_detection with v2 works fine + model_v2 = BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, + provider_config={"turn_detection": {"endpointingSensitivity": "MEDIUM"}}, + client_config={"boto_session": boto_session}, + ) + assert model_v2.config["turn_detection"]["endpointingSensitivity"] == "MEDIUM" + + # Test that empty turn_detection dict doesn't raise error for v1 + model_v1_empty = BidiNovaSonicModel( + model_id=NOVA_SONIC_V1_MODEL_ID, 
+ provider_config={"turn_detection": {}}, + client_config={"boto_session": boto_session}, + ) + assert model_v1_empty.model_id == NOVA_SONIC_V1_MODEL_ID + + +@pytest.mark.asyncio +async def test_turn_detection_sensitivity_validation(boto_session, mock_client): + """Test that endpointingSensitivity is validated at initialization.""" + _ = mock_client # Ensure mock is active + + # Test invalid sensitivity value raises ValueError at init + with pytest.raises(ValueError, match="Invalid endpointingSensitivity.*Must be HIGH, MEDIUM, or LOW"): + BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, + provider_config={"turn_detection": {"endpointingSensitivity": "INVALID"}}, + client_config={"boto_session": boto_session}, + ) + + # Test valid sensitivity values work + for sensitivity in ["HIGH", "MEDIUM", "LOW"]: + model = BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, + provider_config={"turn_detection": {"endpointingSensitivity": sensitivity}}, + client_config={"boto_session": boto_session}, + ) + assert model.config["turn_detection"]["endpointingSensitivity"] == sensitivity + + # Test that turn_detection without sensitivity works (sensitivity is optional) + model_no_sensitivity = BidiNovaSonicModel( + model_id=NOVA_SONIC_V2_MODEL_ID, + provider_config={"turn_detection": {}}, + client_config={"boto_session": boto_session}, + ) + assert "endpointingSensitivity" not in model_no_sensitivity.config["turn_detection"] + + # Error Handling Tests @pytest.mark.asyncio async def test_bidi_nova_sonic_model_receive_timeout(nova_model, mock_stream): diff --git a/tests_integ/bidi/test_bidirectional_agent.py b/tests_integ/bidi/test_bidirectional_agent.py index 61cf78723..243db46ac 100644 --- a/tests_integ/bidi/test_bidirectional_agent.py +++ b/tests_integ/bidi/test_bidirectional_agent.py @@ -55,11 +55,18 @@ def calculator(operation: str, x: float, y: float) -> float: PROVIDER_CONFIGS = { "nova_sonic": { "model_class": BidiNovaSonicModel, - "model_kwargs": {"region": 
"us-east-1"}, + "model_kwargs": {"region": "us-east-1"}, # Uses v2 by default "silence_duration": 2.5, # Nova Sonic needs 2+ seconds of silence "env_vars": ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], "skip_reason": "AWS credentials not available", }, + "nova_sonic_v1": { + "model_class": BidiNovaSonicModel, + "model_kwargs": {"model_id": "amazon.nova-sonic-v1:0", "region": "us-east-1"}, + "silence_duration": 2.5, # Nova Sonic v1 needs 2+ seconds of silence + "env_vars": ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], + "skip_reason": "AWS credentials not available", + }, "openai": { "model_class": BidiOpenAIRealtimeModel, "model_kwargs": { From f87925b9383a8ead59f0138e55e8412478c70928 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 21 Jan 2026 18:09:29 +0200 Subject: [PATCH 080/279] fix(tests): reduce flakiness in guardrail redact output test (#1505) --- pyproject.toml | 1 + tests_integ/conftest.py | 116 +++++++++++++++++++++++++ tests_integ/test_bedrock_guardrails.py | 13 ++- 3 files changed, 126 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b49c74d1b..a16132881 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ dev = [ "pytest-asyncio>=1.0.0,<1.4.0", "pytest-xdist>=3.0.0,<4.0.0", "ruff>=0.13.0,<0.15.0", + "tenacity>=9.0.0,<10.0.0", ] [project.urls] diff --git a/tests_integ/conftest.py b/tests_integ/conftest.py index 26453e1f7..9de00089b 100644 --- a/tests_integ/conftest.py +++ b/tests_integ/conftest.py @@ -1,13 +1,129 @@ +import functools import json import logging import os +from collections.abc import Callable, Sequence import boto3 import pytest +from tenacity import RetryCallState, RetryError, Retrying, stop_after_attempt, wait_exponential logger = logging.getLogger(__name__) +# Type alias for retry conditions +RetryCondition = type[BaseException] | Callable[[BaseException], bool] | str + + +def _should_retry_exception(exc: BaseException, conditions: Sequence[RetryCondition]) -> bool: + 
"""Check if exception matches any of the given retry conditions. + + Args: + exc: The exception to check + conditions: Sequence of conditions, each can be: + - Exception type: retry if isinstance(exc, condition) + - Callable: retry if condition(exc) returns True + - str: retry if string is in str(exc) + """ + for condition in conditions: + if isinstance(condition, type) and issubclass(condition, BaseException): + if isinstance(exc, condition): + return True + elif callable(condition): + if condition(exc): + return True + elif isinstance(condition, str): + if condition in str(exc): + return True + return False + + +_RETRY_ON_ANY: Sequence[RetryCondition] = (lambda _: True,) + + +def retry_on_flaky( + reason: str, + *, + max_attempts: int = 3, + wait_multiplier: float = 1, + wait_max: float = 10, + retry_on: Sequence[RetryCondition] = _RETRY_ON_ANY, +) -> Callable: + """Decorator to retry flaky integration tests that fail due to external factors. + + WHEN TO USE: + - External service instability (API rate limits, transient network errors) + - Non-deterministic LLM responses that occasionally fail assertions + - Resource contention in shared test environments + - Known intermittent issues with third-party dependencies + + WHEN NOT TO USE: + - Actual bugs in the code under test (fix the bug instead) + - Deterministic failures (these indicate real problems) + - Unit tests (flakiness in unit tests usually indicates a design issue) + - To mask consistently failing tests (investigate root cause first) + + Prefer using specific retry_on conditions over retrying on any exception + to avoid masking real bugs. + + Args: + reason: Required explanation of why this test is flaky and needs retries. + This should describe the source of non-determinism (e.g., "LLM responses + may vary" or "External API has intermittent rate limits"). 
+ max_attempts: Maximum number of retry attempts (default: 3) + wait_multiplier: Multiplier for exponential backoff in seconds (default: 1) + wait_max: Maximum wait time between retries in seconds (default: 10) + retry_on: Conditions for when to retry. Defaults to retrying on any exception. + Each condition can be: + - Exception type: e.g., ValueError, TimeoutError + - Callable: e.g., lambda e: "timeout" in str(e).lower() + - str: substring to match in exception message + + Usage: + # Retry on any failure + @retry_on_flaky("LLM responses are non-deterministic") + def test_something(): + ... + + # Retry only on specific exception types + @retry_on_flaky("Network calls may fail transiently", retry_on=[TimeoutError, ConnectionError]) + def test_network_call(): + ... + + # Retry on string patterns in exception message + @retry_on_flaky("Service has intermittent availability", retry_on=["Service unavailable", "Status 503"]) + def test_service_call(): + ... + """ + + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*args, **kwargs): + def should_retry(retry_state: RetryCallState) -> bool: + if retry_state.outcome is None or not retry_state.outcome.failed: + return False + exc = retry_state.outcome.exception() + if exc is None: + return False + return _should_retry_exception(exc, retry_on) + + try: + for attempt in Retrying( + stop=stop_after_attempt(max_attempts), + wait=wait_exponential(multiplier=wait_multiplier, max=wait_max), + retry=should_retry, + reraise=True, + ): + with attempt: + return func(*args, **kwargs) + except RetryError: + raise + + return wrapper + + return decorator + + def pytest_sessionstart(session): _load_api_keys_from_secrets_manager() diff --git a/tests_integ/test_bedrock_guardrails.py b/tests_integ/test_bedrock_guardrails.py index 058597026..56edc3fc4 100644 --- a/tests_integ/test_bedrock_guardrails.py +++ b/tests_integ/test_bedrock_guardrails.py @@ -8,6 +8,7 @@ from strands import Agent, tool from 
strands.models.bedrock import BedrockModel from strands.session.file_session_manager import FileSessionManager +from tests_integ.conftest import retry_on_flaky BLOCKED_INPUT = "BLOCKED_INPUT" BLOCKED_OUTPUT = "BLOCKED_OUTPUT" @@ -170,9 +171,11 @@ def test_guardrail_output_intervention(boto_session, bedrock_guardrail, processi ) +@retry_on_flaky("LLM may mention CACTUS unprompted, triggering guardrail on response2") @pytest.mark.parametrize("guardrail_trace", ["enabled", "enabled_full"]) @pytest.mark.parametrize("processing_mode", ["sync", "async"]) def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processing_mode, guardrail_trace): + """Test guardrail output intervention with redaction.""" REDACT_MESSAGE = "Redacted." bedrock_model = BedrockModel( guardrail_id=bedrock_guardrail, @@ -182,23 +185,25 @@ def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processi guardrail_redact_output=True, guardrail_redact_output_message=REDACT_MESSAGE, region_name="us-east-1", + temperature=0, # Use deterministic responses to reduce flakiness ) agent = Agent( model=bedrock_model, - system_prompt="When asked to say the word, say CACTUS.", + system_prompt="When asked to say the word, say CACTUS. Otherwise, respond normally.", callback_handler=None, load_tools_from_directory=False, ) response1 = agent("Say the word.") - response2 = agent("Hello!") + # Use a completely unrelated prompt to reduce likelihood of model volunteering CACTUS + response2 = agent("What is 2+2? Reply with only the number.") assert response1.stop_reason == "guardrail_intervened" """ - In async streaming: The buffering is non-blocking. - Tokens are streamed while Guardrails processes the buffered content in the background. + In async streaming: The buffering is non-blocking. + Tokens are streamed while Guardrails processes the buffered content in the background. This means the response may be returned before Guardrails has finished processing. 
As a result, we cannot guarantee that the REDACT_MESSAGE is in the response. """ From 78a1c28b46241f9be5f4ebf4cfc5986df1cbb2f5 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 21 Jan 2026 23:00:55 +0200 Subject: [PATCH 081/279] test: fix flaky openai structured output test by adding Field guidance (#1534) --- tests_integ/models/test_model_openai.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py index 503fca898..99ac49148 100644 --- a/tests_integ/models/test_model_openai.py +++ b/tests_integ/models/test_model_openai.py @@ -45,10 +45,10 @@ def agent(model, tools): @pytest.fixture def weather(): class Weather(pydantic.BaseModel): - """Extracts the time and weather from the user's message with the exact strings.""" + """Extract time and weather values.""" - time: str - weather: str + time: str = pydantic.Field(description="The time value only, e.g. '14:30' not 'The time is 14:30'") + weather: str = pydantic.Field(description="The weather condition only, e.g. 
'rainy' not 'the weather is rainy'") return Weather(time="12:00", weather="sunny") From 70b1d10fe8553ca59ed2d422092bf483de179dbd Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 22 Jan 2026 10:14:06 -0500 Subject: [PATCH 082/279] interrupts - multiagent - do not emit AfterNodeCallEvent on interrupt (#1539) --- src/strands/multiagent/graph.py | 3 ++- src/strands/multiagent/swarm.py | 7 ++++--- tests/strands/multiagent/conftest.py | 9 ++++++++- tests/strands/multiagent/test_graph.py | 8 ++++++++ tests/strands/multiagent/test_swarm.py | 8 ++++++++ 5 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 32eca00ff..bad7eede9 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -1005,7 +1005,8 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) raise finally: - await self.hooks.invoke_callbacks_async(AfterNodeCallEvent(self, node.node_id, invocation_state)) + if node.execution_status != Status.INTERRUPTED: + await self.hooks.invoke_callbacks_async(AfterNodeCallEvent(self, node.node_id, invocation_state)) def _accumulate_metrics(self, node_result: NodeResult) -> None: """Accumulate metrics from a node result.""" diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 9a4ce5494..10e0da515 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -782,9 +782,10 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato break finally: - await self.hooks.invoke_callbacks_async( - AfterNodeCallEvent(self, current_node.node_id, invocation_state) - ) + if self.state.completion_status != Status.INTERRUPTED: + await self.hooks.invoke_callbacks_async( + AfterNodeCallEvent(self, current_node.node_id, invocation_state) + ) logger.debug("node=<%s> | node execution completed", current_node.node_id) diff --git a/tests/strands/multiagent/conftest.py 
b/tests/strands/multiagent/conftest.py index e5dd1b4f9..190dc4a91 100644 --- a/tests/strands/multiagent/conftest.py +++ b/tests/strands/multiagent/conftest.py @@ -1,15 +1,22 @@ import pytest -from strands.hooks import BeforeNodeCallEvent, HookProvider +from strands.hooks import AfterNodeCallEvent, BeforeNodeCallEvent, HookProvider @pytest.fixture def interrupt_hook(): class Hook(HookProvider): + def __init__(self): + self.after_count = 0 + def register_hooks(self, registry): registry.add_callback(BeforeNodeCallEvent, self.interrupt) + registry.add_callback(AfterNodeCallEvent, self.cleanup) def interrupt(self, event): return event.interrupt("test_name", reason="test_reason") + def cleanup(self, event): + self.after_count += 1 + return Hook() diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index cd750865e..75482939d 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -2126,6 +2126,10 @@ def test_graph_interrupt_on_before_node_call_event(interrupt_hook): ] assert tru_interrupts == exp_interrupts + tru_after_count = interrupt_hook.after_count + exp_after_count = 0 + assert tru_after_count == exp_after_count + interrupt = multiagent_result.interrupts[0] responses = [ { @@ -2152,4 +2156,8 @@ def test_graph_interrupt_on_before_node_call_event(interrupt_hook): exp_message = "Task completed" assert tru_message == exp_message + tru_after_count = interrupt_hook.after_count + exp_after_count = 1 + assert tru_after_count == exp_after_count + assert multiagent_result.execution_time >= first_execution_time diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 75ef97a25..491adc7c3 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -1259,6 +1259,10 @@ def test_swarm_interrupt_on_before_node_call_event(interrupt_hook): ] assert tru_interrupts == exp_interrupts + tru_after_count = 
interrupt_hook.after_count + exp_after_count = 0 + assert tru_after_count == exp_after_count + interrupt = multiagent_result.interrupts[0] responses = [ { @@ -1281,6 +1285,10 @@ def test_swarm_interrupt_on_before_node_call_event(interrupt_hook): exp_message = "Task completed" assert tru_message == exp_message + tru_after_count = interrupt_hook.after_count + exp_after_count = 1 + assert tru_after_count == exp_after_count + assert multiagent_result.execution_time >= first_execution_time From 66d3db25a385a327512d46e385fdba026e58cc5e Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Thu, 22 Jan 2026 19:06:18 +0200 Subject: [PATCH 083/279] ci: add workflow for lambda layer publish (#870) --- .github/workflows/LAMDBA_LAYERS_SOP.md | 31 ++++ .github/workflows/publish-lambda-layer.yml | 167 +++++++++++++++++++++ 2 files changed, 198 insertions(+) create mode 100644 .github/workflows/LAMDBA_LAYERS_SOP.md create mode 100644 .github/workflows/publish-lambda-layer.yml diff --git a/.github/workflows/LAMDBA_LAYERS_SOP.md b/.github/workflows/LAMDBA_LAYERS_SOP.md new file mode 100644 index 000000000..1cf58a614 --- /dev/null +++ b/.github/workflows/LAMDBA_LAYERS_SOP.md @@ -0,0 +1,31 @@ +# Lambda Layers Standard Operating Procedures (SOP) + +## Overview + +This document defines the standard operating procedures for managing Strands Agents Lambda layers across all AWS regions, Python versions, and architectures. + +**Total: 136 individual Lambda layers** (17 regions × 2 architectures × 4 Python versions). All variants must maintain the same layer version number for each PyPI package version, with only one row per PyPI version appearing in documentation. + +## Deployment Process + +### 1. Initial Deployment +1. Run workflow with ALL options selected (default) +2. Specify the PyPI package version and the layer version +3. Type "Create Lambda Layer {package_version}-layer{layer_version}" to confirm +4. All 136 individual layers deploy in parallel (4 Python × 2 arch × 17 regions) +5. 
Each layer gets its own unique name: `strands-agents-py{PYTHON_VERSION}-{ARCH}` + +### 2. Version Buffering for New Variants +When adding new variants (new Python version, architecture, or region): + +1. **Determine target layer version**: Check existing variants to find the highest layer version +2. **Buffer deployment**: Deploy new variants multiple times until layer version matches existing variants +3. **Example**: If existing variants are at layer version 5, deploy new variant 5 times to reach version 5 + +### 3. Handling Transient Failures +When some regions fail during deployment: + +1. **Identify failed regions**: Check which combinations didn't complete successfully +2. **Targeted redeployment**: Use specific region/arch/Python inputs to redeploy failed combinations +3. **Version alignment**: Continue deploying until all variants reach the same layer version +4. **Verification**: Confirm all combinations have identical layer versions before updating docs \ No newline at end of file diff --git a/.github/workflows/publish-lambda-layer.yml b/.github/workflows/publish-lambda-layer.yml new file mode 100644 index 000000000..4211d715f --- /dev/null +++ b/.github/workflows/publish-lambda-layer.yml @@ -0,0 +1,167 @@ +name: Publish PyPI Package to Lambda Layer + +on: + workflow_dispatch: + inputs: + package_version: + description: 'Package version to download' + required: true + type: string + layer_version: + description: 'Layer version' + required: true + type: string + python_version: + description: 'Python version' + required: true + default: 'ALL' + type: choice + options: ['ALL', '3.10', '3.11', '3.12', '3.13'] + architecture: + description: 'Architecture' + required: true + default: 'ALL' + type: choice + options: ['ALL', 'x86_64', 'aarch64'] + region: + description: 'AWS region' + required: true + default: 'ALL' + type: choice + # Only non opt-in regions included for now + options: ['ALL', 'us-east-1', 'us-east-2', 'us-west-1', 'us-west-2', 'ap-south-1', 
'ap-northeast-1', 'ap-northeast-2', 'ap-northeast-3', 'ap-southeast-1', 'ap-southeast-2', 'ca-central-1', 'eu-central-1', 'eu-west-1', 'eu-west-2', 'eu-west-3', 'eu-north-1', 'sa-east-1'] + confirm: + description: 'Type "Create Lambda Layer {PyPI version}-layer{layer version}" to confirm publishing the layer' + required: true + type: string + +jobs: + validate: + runs-on: ubuntu-latest + steps: + - name: Validate confirmation + run: | + CONFIRM="${{ inputs.confirm }}" + EXPECTED="Create Lambda Layer ${{ inputs.package_version }}-layer${{ inputs.layer_version }}" + if [ "$CONFIRM" != "$EXPECTED" ]; then + echo "Confirmation failed. You must type exactly '$EXPECTED' to proceed." + exit 1 + fi + echo "Confirmation validated" + + package-and-upload: + needs: validate + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ${{ inputs.python_version == 'ALL' && fromJson('["3.10", "3.11", "3.12", "3.13"]') || fromJson(format('["{0}"]', inputs.python_version)) }} + architecture: ${{ inputs.architecture == 'ALL' && fromJson('["x86_64", "aarch64"]') || fromJson(format('["{0}"]', inputs.architecture)) }} + region: ${{ inputs.region == 'ALL' && fromJson('["us-east-1", "us-east-2", "us-west-1", "us-west-2", "ap-south-1", "ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-southeast-1", "ap-southeast-2", "ca-central-1", "eu-central-1", "eu-west-1", "eu-west-2", "eu-west-3", "eu-north-1", "sa-east-1"]') || fromJson(format('["{0}"]', inputs.region)) }} + + permissions: + id-token: write + + steps: + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.STRANDS_LAMBDA_LAYER_PUBLISHER_ROLE }} + aws-region: ${{ matrix.region }} + + - name: Create layer directory structure + run: | + mkdir -p layer/python + + - name: Download and install package + run: | + pip install strands-agents==${{ 
inputs.package_version }} \ + --python-version ${{ matrix.python-version }} \ + --platform manylinux2014_${{ matrix.architecture }} \ + -t layer/python/ \ + --only-binary=:all: + + - name: Create layer zip + run: | + cd layer + zip -r ../lambda-layer.zip . + + - name: Upload to S3 + run: | + PYTHON_VERSION="${{ matrix.python-version }}" + ARCH="${{ matrix.architecture }}" + REGION="${{ matrix.region }}" + LAYER_NAME="strands-agents-py${PYTHON_VERSION//./_}-${ARCH}" + ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text) + BUCKET_NAME="strands-layer-${ACCOUNT_ID}-${{ secrets.STRANDS_LAMBDA_LAYER_BUCKET_SALT }}-${REGION}" + LAYER_KEY="$LAYER_NAME/${{ inputs.package_version }}/layer${{ inputs.layer_version }}/lambda-layer.zip" + + aws s3 cp lambda-layer.zip "s3://$BUCKET_NAME/$LAYER_KEY" --region "$REGION" + echo "Uploaded layer to s3://$BUCKET_NAME/$LAYER_KEY" + + publish-layer: + needs: package-and-upload + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ${{ inputs.python_version == 'ALL' && fromJson('["3.10", "3.11", "3.12", "3.13"]') || fromJson(format('["{0}"]', inputs.python_version)) }} + architecture: ${{ inputs.architecture == 'ALL' && fromJson('["x86_64", "aarch64"]') || fromJson(format('["{0}"]', inputs.architecture)) }} + region: ${{ inputs.region == 'ALL' && fromJson('["us-east-1", "us-east-2", "us-west-1", "us-west-2", "ap-south-1", "ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-southeast-1", "ap-southeast-2", "ca-central-1", "eu-central-1", "eu-west-1", "eu-west-2", "eu-west-3", "eu-north-1", "sa-east-1"]') || fromJson(format('["{0}"]', inputs.region)) }} + + permissions: + id-token: write + + steps: + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: ${{ secrets.STRANDS_LAMBDA_LAYER_PUBLISHER_ROLE }} + aws-region: ${{ matrix.region }} + + - name: Publish layer + run: | + PYTHON_VERSION="${{ matrix.python-version }}" + ARCH="${{ matrix.architecture }}" 
+ REGION="${{ matrix.region }}" + LAYER_NAME="strands-agents-py${PYTHON_VERSION//./_}-${ARCH}" + ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text) + REGION_BUCKET="strands-layer-${ACCOUNT_ID}-${{ secrets.STRANDS_LAMBDA_LAYER_BUCKET_SALT }}-${REGION}" + LAYER_KEY="$LAYER_NAME/${{ inputs.package_version }}/layer${{ inputs.layer_version }}/lambda-layer.zip" + + DESCRIPTION="PyPI package: strands-agents v${{ inputs.package_version }} (Python $PYTHON_VERSION, $ARCH)" + + # Set compatible architecture based on matrix architecture + if [ "$ARCH" = "x86_64" ]; then + COMPATIBLE_ARCH="x86_64" + else + COMPATIBLE_ARCH="arm64" + fi + + LAYER_OUTPUT=$(aws lambda publish-layer-version \ + --layer-name $LAYER_NAME \ + --description "$DESCRIPTION" \ + --content S3Bucket=$REGION_BUCKET,S3Key=$LAYER_KEY \ + --compatible-runtimes python${{ matrix.python-version }} \ + --compatible-architectures $COMPATIBLE_ARCH \ + --region "$REGION" \ + --license-info Apache-2.0 \ + --output json) + + LAYER_ARN=$(echo "$LAYER_OUTPUT" | jq -r '.LayerArn') + LAYER_VERSION=$(echo "$LAYER_OUTPUT" | jq -r '.Version') + + echo "Published layer version $LAYER_VERSION with ARN: $LAYER_ARN in region $REGION" + + aws lambda add-layer-version-permission \ + --layer-name $LAYER_NAME \ + --version-number $LAYER_VERSION \ + --statement-id public \ + --action lambda:GetLayerVersion \ + --principal '*' \ + --region "$REGION" + + echo "Successfully published layer version $LAYER_VERSION in region $REGION" \ No newline at end of file From 612b07eee5140165a1d72fb0247d3c94d965252a Mon Sep 17 00:00:00 2001 From: Clare Liguori Date: Thu, 22 Jan 2026 09:27:21 -0800 Subject: [PATCH 084/279] fix: Populate tool_args correctly for steering (#1531) --- .../context_providers/ledger_provider.py | 2 +- .../context_providers/test_ledger_provider.py | 4 +- .../steering/core/test_handler.py | 2 +- tests_integ/steering/test_model_steering.py | 12 ++++- tests_integ/steering/test_tool_steering.py | 44 
+++++++++++++++++++ 5 files changed, 59 insertions(+), 5 deletions(-) diff --git a/src/strands/experimental/steering/context_providers/ledger_provider.py b/src/strands/experimental/steering/context_providers/ledger_provider.py index da8504bd0..0e7bde529 100644 --- a/src/strands/experimental/steering/context_providers/ledger_provider.py +++ b/src/strands/experimental/steering/context_providers/ledger_provider.py @@ -47,7 +47,7 @@ def __call__(self, event: BeforeToolCallEvent, steering_context: SteeringContext tool_call_entry = { "timestamp": datetime.now().isoformat(), "tool_name": event.tool_use.get("name"), - "tool_args": event.tool_use.get("arguments", {}), + "tool_args": event.tool_use.get("input", {}), "status": "pending", } ledger["tool_calls"].append(tool_call_entry) diff --git a/tests/strands/experimental/steering/context_providers/test_ledger_provider.py b/tests/strands/experimental/steering/context_providers/test_ledger_provider.py index 4356b3ea8..1d280f7c1 100644 --- a/tests/strands/experimental/steering/context_providers/test_ledger_provider.py +++ b/tests/strands/experimental/steering/context_providers/test_ledger_provider.py @@ -30,7 +30,7 @@ def test_ledger_before_tool_call_new_ledger(mock_datetime): callback = LedgerBeforeToolCall() steering_context = SteeringContext() - tool_use = {"name": "test_tool", "arguments": {"param": "value"}} + tool_use = {"name": "test_tool", "input": {"param": "value"}} event = Mock(spec=BeforeToolCallEvent) event.tool_use = tool_use @@ -65,7 +65,7 @@ def test_ledger_before_tool_call_existing_ledger(mock_datetime): } steering_context.data.set("ledger", existing_ledger) - tool_use = {"name": "new_tool", "arguments": {"param": "value"}} + tool_use = {"name": "new_tool", "input": {"param": "value"}} event = Mock(spec=BeforeToolCallEvent) event.tool_use = tool_use diff --git a/tests/strands/experimental/steering/core/test_handler.py b/tests/strands/experimental/steering/core/test_handler.py index a16208e5b..cbe2b3783 100644 
--- a/tests/strands/experimental/steering/core/test_handler.py +++ b/tests/strands/experimental/steering/core/test_handler.py @@ -241,7 +241,7 @@ def test_context_callbacks_receive_steering_context(): # Create a mock event and call the callback event = Mock(spec=BeforeToolCallEvent) - event.tool_use = {"name": "test_tool", "arguments": {}} + event.tool_use = {"name": "test_tool", "input": {}} # The callback should execute without error and update the steering context before_callback(event) diff --git a/tests_integ/steering/test_model_steering.py b/tests_integ/steering/test_model_steering.py index e867ea033..dccb0fa3a 100644 --- a/tests_integ/steering/test_model_steering.py +++ b/tests_integ/steering/test_model_steering.py @@ -1,6 +1,7 @@ """Integration tests for model steering (steer_after_model).""" from strands import Agent, tool +from strands.experimental.steering.context_providers.ledger_provider import LedgerProvider from strands.experimental.steering.core.action import Guide, ModelSteeringAction, Proceed from strands.experimental.steering.core.handler import SteeringHandler from strands.types.content import Message @@ -154,7 +155,7 @@ class ForceToolUsageHandler(SteeringHandler): """Handler that forces a specific tool to be used before allowing termination.""" def __init__(self, required_tool: str): - super().__init__() + super().__init__(context_providers=[LedgerProvider()]) self.required_tool = required_tool self.tool_was_used = False self.guidance_given = False @@ -171,6 +172,15 @@ async def steer_after_model( for block in content_blocks: if "toolUse" in block and block["toolUse"].get("name") == self.required_tool: self.tool_was_used = True + + # Verify tool is in the ledger + ledger = self.steering_context.data.get("ledger") + if ledger: + tool_calls = ledger.get("tool_calls", []) + assert any(tc.get("tool_name") == self.required_tool for tc in tool_calls), ( + f"{self.required_tool} should be in ledger when tool_was_used=True" + ) + return 
Proceed(reason="Required tool was used") # If tool wasn't used and we haven't guided yet, force its usage diff --git a/tests_integ/steering/test_tool_steering.py b/tests_integ/steering/test_tool_steering.py index eced94ba0..75073c648 100644 --- a/tests_integ/steering/test_tool_steering.py +++ b/tests_integ/steering/test_tool_steering.py @@ -3,7 +3,9 @@ import pytest from strands import Agent, tool +from strands.experimental.steering.context_providers.ledger_provider import LedgerProvider from strands.experimental.steering.core.action import Guide, Interrupt, Proceed +from strands.experimental.steering.core.handler import SteeringHandler from strands.experimental.steering.handlers.llm.llm_handler import LLMSteeringHandler @@ -98,3 +100,45 @@ def test_agent_with_tool_steering_e2e(): notification_metrics = tool_metrics["send_notification"] assert notification_metrics.call_count >= 1, "send_notification should have been called" assert notification_metrics.success_count >= 1, "send_notification should have succeeded" + + +def test_ledger_captures_tool_calls(): + """Test that ledger correctly captures tool call information.""" + + class LedgerCheckingHandler(SteeringHandler): + def __init__(self): + super().__init__(context_providers=[LedgerProvider()]) + + async def steer_before_tool(self, *, agent, tool_use, **kwargs): + ledger = self.steering_context.data.get("ledger") + assert ledger is not None, "Ledger should exist" + assert "tool_calls" in ledger, "Ledger should have tool_calls" + + # Find the current tool call in the ledger + tool_calls = ledger["tool_calls"] + current_call = next((tc for tc in tool_calls if tc["tool_name"] == tool_use["name"]), None) + assert current_call is not None, f"{tool_use['name']} should be in ledger" + assert current_call["tool_args"] == tool_use["input"], "tool_args should match input" + assert current_call["status"] == "pending", "Status should be pending before execution" + + return Proceed(reason="Ledger verified") + + handler = 
LedgerCheckingHandler() + agent = Agent(tools=[send_notification], hooks=[handler]) + + agent("Send a notification to alice saying test message") + + # Verify the ledger has the completed tool call + ledger = handler.steering_context.data.get("ledger") + assert ledger is not None + assert len(ledger["tool_calls"]) >= 1, "At least one tool call should be recorded" + + # Check the tool call details + tool_call = ledger["tool_calls"][-1] + assert tool_call["tool_name"] == "send_notification" + assert "tool_args" in tool_call + assert tool_call["tool_args"]["recipient"] == "alice" + assert tool_call["tool_args"]["message"] == "test message" + assert tool_call["status"] == "success" + assert "completion_timestamp" in tool_call + assert tool_call["error"] is None From fa864440847a79f3c9c42b938a8d624fa8e47bef Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 23 Jan 2026 09:53:23 -0500 Subject: [PATCH 085/279] interrupts - graph - agent based (#1533) Co-authored-by: Mohammad Salehan --- src/strands/multiagent/graph.py | 46 +++++-- tests/strands/multiagent/test_graph.py | 98 +++++++++++++ .../interrupts/multiagent/test_agent.py | 129 +++++++++++++++++- .../interrupts/multiagent/test_session.py | 49 ++----- 4 files changed, 266 insertions(+), 56 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index bad7eede9..d296753c0 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -622,6 +622,13 @@ def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> M self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts}) self._interrupt_state.activate() + if isinstance(node.executor, Agent): + self._interrupt_state.context[node.node_id] = { + "activated": node.executor._interrupt_state.activated, + "interrupt_state": node.executor._interrupt_state.to_dict(), + "state": node.executor.state.get(), + "messages": node.executor.messages, + } return 
MultiAgentNodeInterruptEvent(node.node_id, interrupts) @@ -920,16 +927,6 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) if agent_response is None: raise ValueError(f"Node '{node.node_id}' did not produce a result event") - if agent_response.stop_reason == "interrupt": - node.executor.messages.pop() # remove interrupted tool use message - node.executor._interrupt_state.deactivate() - - raise NotImplementedError( - f"node_id=<{node.node_id}>, " - "issue= " - "| user raised interrupt from an agent node" - ) - # Extract metrics with defaults response_metrics = getattr(agent_response, "metrics", None) usage = getattr( @@ -940,18 +937,24 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) node_result = NodeResult( result=agent_response, execution_time=round((time.time() - start_time) * 1000), - status=Status.COMPLETED, + status=Status.INTERRUPTED if agent_response.stop_reason == "interrupt" else Status.COMPLETED, accumulated_usage=usage, accumulated_metrics=metrics, execution_count=1, + interrupts=agent_response.interrupts or [], ) else: raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported") - # Mark as completed - node.execution_status = Status.COMPLETED node.result = node_result node.execution_time = node_result.execution_time + + if node_result.status == Status.INTERRUPTED: + yield self._activate_interrupt(node, node_result.interrupts) + return + + # Mark as completed + node.execution_status = Status.COMPLETED self.state.completed_nodes.add(node) self.state.results[node.node_id] = node_result self.state.execution_order.append(node) @@ -1019,6 +1022,8 @@ def _accumulate_metrics(self, node_result: NodeResult) -> None: def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: """Build input text for a node based on dependency outputs. + If resuming from an interrupt, return user responses. 
+ Example formatted output: ``` Original Task: Analyze the quarterly sales data and create a summary report @@ -1033,6 +1038,21 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: - Agent: Data validation complete. All records verified, no anomalies detected. ``` """ + if self._interrupt_state.activated: + context = self._interrupt_state.context + if node.node_id in context and context[node.node_id]["activated"]: + agent_context = context[node.node_id] + agent = cast(Agent, node.executor) + agent.messages = agent_context["messages"] + agent.state = AgentState(agent_context["state"]) + agent._interrupt_state = _InterruptState.from_dict(agent_context["interrupt_state"]) + + responses = context["responses"] + interrupts = agent._interrupt_state.interrupts + return [ + response for response in responses if response["interruptResponse"]["interruptId"] in interrupts + ] + # Get satisfied dependencies dependency_results = {} for edge in self.edges: diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 75482939d..c511328d4 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -23,6 +23,9 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen agent.id = agent_id or f"{name}_id" agent._session_manager = None agent.hooks = HookRegistry() + agent.state = AgentState() + agent.messages = [] + agent._interrupt_state = _InterruptState() if metrics is None: metrics = Mock( @@ -2161,3 +2164,98 @@ def test_graph_interrupt_on_before_node_call_event(interrupt_hook): assert tru_after_count == exp_after_count assert multiagent_result.execution_time >= first_execution_time + + +def test_graph_interrupt_on_agent(agenerator): + exp_interrupts = [ + Interrupt( + id="test_id", + name="test_name", + reason="test_reason", + ) + ] + + agent = create_mock_agent("test_agent", "Task completed") + agent.stream_async = Mock() + agent.stream_async.return_value = 
agenerator( + [ + { + "result": AgentResult( + message={}, + stop_reason="interrupt", + state={}, + metrics=None, + interrupts=exp_interrupts, + ), + }, + ], + ) + + builder = GraphBuilder() + builder.add_node(agent, "test_agent") + graph = builder.build() + + multiagent_result = graph("Test task") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_node_ids = [node.node_id for node in graph.state.interrupted_nodes] + exp_node_ids = ["test_agent"] + assert tru_node_ids == exp_node_ids + + tru_interrupts = multiagent_result.interrupts + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + agent.stream_async = Mock() + agent.stream_async.return_value = agenerator( + [ + { + "result": AgentResult( + message={}, + stop_reason="end_turn", + state={}, + metrics=None, + ), + }, + ], + ) + graph._interrupt_state.context["test_agent"] = { + "activated": True, + "interrupt_state": { + "activated": True, + "context": {}, + "interrupts": {interrupt.id: interrupt.to_dict()}, + }, + "messages": [], + "state": {}, + } + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "test_response", + }, + }, + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 1 + + agent.stream_async.assert_called_once_with(responses, invocation_state={}) diff --git a/tests_integ/interrupts/multiagent/test_agent.py b/tests_integ/interrupts/multiagent/test_agent.py index 36fcfef27..1a6ad87c6 100644 --- 
a/tests_integ/interrupts/multiagent/test_agent.py +++ b/tests_integ/interrupts/multiagent/test_agent.py @@ -5,28 +5,83 @@ from strands import Agent, tool from strands.interrupt import Interrupt -from strands.multiagent import Swarm +from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import Status from strands.types.tools import ToolContext +@pytest.fixture +def day_tool(): + @tool(name="day_tool", context=True) + def func(tool_context: ToolContext) -> str: + response = tool_context.interrupt("day_interrupt", reason="need day") + return response + + return func + + +@pytest.fixture +def time_tool(): + @tool(name="time_tool") + def func(): + return "12:01" + + return func + + @pytest.fixture def weather_tool(): @tool(name="weather_tool", context=True) def func(tool_context: ToolContext) -> str: - response = tool_context.interrupt("test_interrupt", reason="need weather") + response = tool_context.interrupt("weather_interrupt", reason="need weather") return response return func @pytest.fixture -def swarm(weather_tool): - weather_agent = Agent(name="weather", tools=[weather_tool]) +def info_agent(): + return Agent(name="info") + + +@pytest.fixture +def day_agent(day_tool): + return Agent(name="day", tools=[day_tool]) + + +@pytest.fixture +def time_agent(time_tool): + return Agent(name="time", tools=[time_tool]) + + +@pytest.fixture +def weather_agent(weather_tool): + return Agent(name="weather", tools=[weather_tool]) + +@pytest.fixture +def swarm(weather_agent): return Swarm([weather_agent]) +@pytest.fixture +def graph(info_agent, day_agent, time_agent, weather_agent): + builder = GraphBuilder() + + builder.add_node(info_agent, "info") + builder.add_node(day_agent, "day") + builder.add_node(time_agent, "time") + builder.add_node(weather_agent, "weather") + + builder.add_edge("info", "day") + builder.add_edge("info", "time") + builder.add_edge("info", "weather") + + builder.set_entry_point("info") + + return builder.build() + + def 
test_swarm_interrupt_agent(swarm): multiagent_result = swarm("What is the weather?") @@ -38,7 +93,7 @@ def test_swarm_interrupt_agent(swarm): exp_interrupts = [ Interrupt( id=ANY, - name="test_interrupt", + name="weather_interrupt", reason="need weather", ), ] @@ -65,3 +120,67 @@ def test_swarm_interrupt_agent(swarm): weather_message = json.dumps(weather_result.result.message).lower() assert "sunny" in weather_message + + +def test_graph_interrupt_agent(graph): + multiagent_result = graph("What is the day, time, and weather?") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_node_ids = sorted([node.node_id for node in graph.state.interrupted_nodes]) + exp_node_ids = ["day", "weather"] + assert tru_node_ids == exp_node_ids + + tru_interrupts = sorted(multiagent_result.interrupts, key=lambda interrupt: interrupt.name) + exp_interrupts = [ + Interrupt( + id=ANY, + name="day_interrupt", + reason="need day", + ), + Interrupt( + id=ANY, + name="weather_interrupt", + reason="need weather", + ), + ] + assert tru_interrupts == exp_interrupts + + responses = [ + { + "interruptResponse": { + "interruptId": tru_interrupts[0].id, + "response": "monday", + }, + }, + { + "interruptResponse": { + "interruptId": tru_interrupts[1].id, + "response": "sunny", + }, + }, + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 4 + + day_message = json.dumps(multiagent_result.results["day"].result.message).lower() + time_message = 
json.dumps(multiagent_result.results["time"].result.message).lower() + weather_message = json.dumps(multiagent_result.results["weather"].result.message).lower() + assert "monday" in day_message + assert "12:01" in time_message + assert "sunny" in weather_message diff --git a/tests_integ/interrupts/multiagent/test_session.py b/tests_integ/interrupts/multiagent/test_session.py index 2ccff2c12..96b9844bf 100644 --- a/tests_integ/interrupts/multiagent/test_session.py +++ b/tests_integ/interrupts/multiagent/test_session.py @@ -4,7 +4,6 @@ import pytest from strands import Agent, tool -from strands.hooks import BeforeNodeCallEvent, HookProvider from strands.interrupt import Interrupt from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import Status @@ -12,21 +11,6 @@ from strands.types.tools import ToolContext -@pytest.fixture -def interrupt_hook(): - class Hook(HookProvider): - def register_hooks(self, registry): - registry.add_callback(BeforeNodeCallEvent, self.interrupt) - - def interrupt(self, event): - if event.node_id == "time": - response = event.interrupt("test_interrupt", reason="need approval") - if response != "APPROVE": - event.cancel_node = "node rejected" - - return Hook() - - @pytest.fixture def weather_tool(): @tool(name="weather_tool", context=True) @@ -37,15 +21,6 @@ def func(tool_context: ToolContext) -> str: return func -@pytest.fixture -def time_tool(): - @tool(name="time_tool") - def func(): - return "12:01" - - return func - - def test_swarm_interrupt_session(weather_tool, tmpdir): weather_agent = Agent(name="weather", tools=[weather_tool]) summarizer_agent = Agent(name="summarizer") @@ -96,20 +71,19 @@ def test_swarm_interrupt_session(weather_tool, tmpdir): assert "sunny" in summarizer_message -def test_graph_interrupt_session(interrupt_hook, time_tool, tmpdir): - time_agent = Agent(name="time", tools=[time_tool]) +def test_graph_interrupt_session(weather_tool, tmpdir): + weather_agent = Agent(name="weather", 
tools=[weather_tool]) summarizer_agent = Agent(name="summarizer") session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) builder = GraphBuilder() - builder.add_node(time_agent, "time") + builder.add_node(weather_agent, "weather") builder.add_node(summarizer_agent, "summarizer") - builder.add_edge("time", "summarizer") - builder.set_hook_providers([interrupt_hook]) + builder.add_edge("weather", "summarizer") builder.set_session_manager(session_manager) graph = builder.build() - multiagent_result = graph("Can you check the time and then summarize the results?") + multiagent_result = graph("Can you check the weather and then summarize the results?") tru_result_status = multiagent_result.status exp_result_status = Status.INTERRUPTED @@ -124,22 +98,21 @@ def test_graph_interrupt_session(interrupt_hook, time_tool, tmpdir): Interrupt( id=ANY, name="test_interrupt", - reason="need approval", + reason="need weather", ), ] assert tru_interrupts == exp_interrupts interrupt = multiagent_result.interrupts[0] - time_agent = Agent(name="time", tools=[time_tool]) + weather_agent = Agent(name="weather", tools=[weather_tool]) summarizer_agent = Agent(name="summarizer") session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) builder = GraphBuilder() - builder.add_node(time_agent, "time") + builder.add_node(weather_agent, "weather") builder.add_node(summarizer_agent, "summarizer") - builder.add_edge("time", "summarizer") - builder.set_hook_providers([interrupt_hook]) + builder.add_edge("weather", "summarizer") builder.set_session_manager(session_manager) graph = builder.build() @@ -147,7 +120,7 @@ def test_graph_interrupt_session(interrupt_hook, time_tool, tmpdir): { "interruptResponse": { "interruptId": interrupt.id, - "response": "APPROVE", + "response": "sunny", }, }, ] @@ -163,4 +136,4 @@ def test_graph_interrupt_session(interrupt_hook, time_tool, tmpdir): assert len(multiagent_result.results) == 2 
summarizer_message = json.dumps(multiagent_result.results["summarizer"].result.message).lower() - assert "12:01" in summarizer_message + assert "sunny" in summarizer_message From fdd9482e128ada3e98117f3b03efbca8c2f3cb1f Mon Sep 17 00:00:00 2001 From: poshinchen Date: Fri, 23 Jan 2026 15:43:29 -0500 Subject: [PATCH 086/279] chore: refactor use_span to be closed automatically (#1293) --- src/strands/event_loop/event_loop.py | 195 +++++++++----------- src/strands/telemetry/tracer.py | 31 ++-- tests/strands/event_loop/test_event_loop.py | 5 - tests/strands/telemetry/test_tracer.py | 11 -- 4 files changed, 111 insertions(+), 131 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index f5d00a201..41122efc5 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -139,108 +139,98 @@ async def event_loop_cycle( ) invocation_state["event_loop_cycle_span"] = cycle_span - # Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls. - if agent._interrupt_state.activated: - stop_reason: StopReason = "tool_use" - message = agent._interrupt_state.context["tool_use_message"] - # Skip model invocation if the latest message contains ToolUse - elif _has_tool_use_in_latest_message(agent.messages): - stop_reason = "tool_use" - message = agent.messages[-1] - else: - model_events = _handle_model_execution( - agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context - ) - async for model_event in model_events: - if not isinstance(model_event, ModelStopReason): - yield model_event - - stop_reason, message, *_ = model_event["stop"] - yield ModelMessageEvent(message=message) + with trace_api.use_span(cycle_span, end_on_exit=True): + # Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls. 
+ if agent._interrupt_state.activated: + stop_reason: StopReason = "tool_use" + message = agent._interrupt_state.context["tool_use_message"] + # Skip model invocation if the latest message contains ToolUse + elif _has_tool_use_in_latest_message(agent.messages): + stop_reason = "tool_use" + message = agent.messages[-1] + else: + model_events = _handle_model_execution( + agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context + ) + async for model_event in model_events: + if not isinstance(model_event, ModelStopReason): + yield model_event + + stop_reason, message, *_ = model_event["stop"] + yield ModelMessageEvent(message=message) + + try: + if stop_reason == "max_tokens": + """ + Handle max_tokens limit reached by the model. + + When the model reaches its maximum token limit, this represents a potentially unrecoverable + state where the model's response was truncated. By default, Strands fails hard with an + MaxTokensReachedException to maintain consistency with other failure types. + """ + raise MaxTokensReachedException( + message=( + "Agent has reached an unrecoverable state due to max_tokens limit. " + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ) + ) - try: - if stop_reason == "max_tokens": - """ - Handle max_tokens limit reached by the model. - - When the model reaches its maximum token limit, this represents a potentially unrecoverable - state where the model's response was truncated. By default, Strands fails hard with an - MaxTokensReachedException to maintain consistency with other failure types. - """ - raise MaxTokensReachedException( - message=( - "Agent has reached an unrecoverable state due to max_tokens limit. 
" - "For more information see: " - "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + if stop_reason == "tool_use": + # Handle tool execution + tool_events = _handle_tool_execution( + stop_reason, + message, + agent=agent, + cycle_trace=cycle_trace, + cycle_span=cycle_span, + cycle_start_time=cycle_start_time, + invocation_state=invocation_state, + tracer=tracer, + structured_output_context=structured_output_context, + ) + async for tool_event in tool_events: + yield tool_event + + return + + # End the cycle and return results + agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) + # Set attributes before span auto-closes + tracer.end_event_loop_cycle_span(cycle_span, message) + except EventLoopException: + # Don't yield or log the exception - we already did it when we + # raised the exception and we don't need that duplication. + raise + except (ContextWindowOverflowException, MaxTokensReachedException) as e: + # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException + raise e + except Exception as e: + # Handle any other exceptions + yield ForceStopEvent(reason=e) + logger.exception("cycle failed") + raise EventLoopException(e, invocation_state["request_state"]) from e + + # Force structured output tool call if LLM didn't use it automatically + if structured_output_context.is_enabled and stop_reason == "end_turn": + if structured_output_context.force_attempted: + raise StructuredOutputException( + "The model failed to invoke the structured output tool even after it was forced." 
) + structured_output_context.set_forced_mode() + logger.debug("Forcing structured output tool") + await agent._append_messages( + {"role": "user", "content": [{"text": "You must format the previous response as structured output."}]} ) - if stop_reason == "tool_use": - # Handle tool execution - tool_events = _handle_tool_execution( - stop_reason, - message, - agent=agent, - cycle_trace=cycle_trace, - cycle_span=cycle_span, - cycle_start_time=cycle_start_time, - invocation_state=invocation_state, - tracer=tracer, - structured_output_context=structured_output_context, + events = recurse_event_loop( + agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context ) - async for tool_event in tool_events: - yield tool_event - + async for typed_event in events: + yield typed_event return - # End the cycle and return results - agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) - if cycle_span: - tracer.end_event_loop_cycle_span( - span=cycle_span, - message=message, - ) - except EventLoopException as e: - if cycle_span: - tracer.end_span_with_error(cycle_span, str(e), e) - - # Don't yield or log the exception - we already did it when we - # raised the exception and we don't need that duplication. 
- raise - except (ContextWindowOverflowException, MaxTokensReachedException) as e: - # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException - if cycle_span: - tracer.end_span_with_error(cycle_span, str(e), e) - raise e - except Exception as e: - if cycle_span: - tracer.end_span_with_error(cycle_span, str(e), e) - - # Handle any other exceptions - yield ForceStopEvent(reason=e) - logger.exception("cycle failed") - raise EventLoopException(e, invocation_state["request_state"]) from e - - # Force structured output tool call if LLM didn't use it automatically - if structured_output_context.is_enabled and stop_reason == "end_turn": - if structured_output_context.force_attempted: - raise StructuredOutputException( - "The model failed to invoke the structured output tool even after it was forced." - ) - structured_output_context.set_forced_mode() - logger.debug("Forcing structured output tool") - await agent._append_messages( - {"role": "user", "content": [{"text": "You must format the previous response as structured output."}]} - ) - - events = recurse_event_loop( - agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context - ) - async for typed_event in events: - yield typed_event - return - - yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) + yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) async def recurse_event_loop( @@ -324,7 +314,7 @@ async def _handle_model_execution( model_id=model_id, custom_trace_attributes=agent.trace_attributes, ) - with trace_api.use_span(model_invoke_span): + with trace_api.use_span(model_invoke_span, end_on_exit=True): await agent.hooks.invoke_callbacks_async( BeforeModelCallEvent( agent=agent, @@ -372,14 +362,12 @@ async def _handle_model_execution( if stop_reason == "max_tokens": message = 
recover_message_on_max_tokens_reached(message) - if model_invoke_span: - tracer.end_model_invoke_span(model_invoke_span, message, usage, metrics, stop_reason) + # Set attributes before span auto-closes + tracer.end_model_invoke_span(model_invoke_span, message, usage, metrics, stop_reason) break # Success! Break out of retry loop except Exception as e: - if model_invoke_span: - tracer.end_span_with_error(model_invoke_span, str(e), e) - + # Exception is automatically recorded by use_span with end_on_exit=True after_model_call_event = AfterModelCallEvent( agent=agent, exception=e, @@ -422,9 +410,6 @@ async def _handle_model_execution( agent.event_loop_metrics.update_metrics(metrics) except Exception as e: - if cycle_span: - tracer.end_span_with_error(cycle_span, str(e), e) - yield ForceStopEvent(reason=e) logger.exception("cycle failed") raise EventLoopException(e, invocation_state["request_state"]) from e @@ -508,6 +493,7 @@ async def _handle_tool_execution( interrupts, structured_output=structured_output_result, ) + # Set attributes before span auto-closes (span is managed by use_span in event_loop_cycle) if cycle_span: tracer.end_event_loop_cycle_span(span=cycle_span, message=message) @@ -525,6 +511,7 @@ async def _handle_tool_execution( yield ToolResultMessageEvent(message=tool_result_message) + # Set attributes before span auto-closes (span is managed by use_span in event_loop_cycle) if cycle_span: tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index d73ea3c39..6ab33301a 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -316,18 +316,22 @@ def end_model_invoke_span( usage: Usage, metrics: Metrics, stop_reason: StopReason, - error: Exception | None = None, ) -> None: """End a model invocation span with results and metrics. + Note: The span is automatically closed and exceptions recorded. 
This method just sets the necessary attributes. + Status in the span is automatically set to UNSET (OK) on success or ERROR on exception. + Args: - span: The span to end. + span: The span to set attributes on. message: The message response from the model. usage: Token usage information from the model call. metrics: Metrics from the model call. - stop_reason (StopReason): The reason the model stopped generating. - error: Optional exception if the model call failed. + stop_reason: The reason the model stopped generating. """ + # Set end time attribute + span.set_attribute("gen_ai.event.end_time", datetime.now(timezone.utc).isoformat()) + attributes: dict[str, AttributeValue] = { "gen_ai.usage.prompt_tokens": usage["inputTokens"], "gen_ai.usage.input_tokens": usage["inputTokens"], @@ -362,7 +366,7 @@ def end_model_invoke_span( event_attributes={"finish_reason": str(stop_reason), "message": serialize(message["content"])}, ) - self._end_span(span, attributes, error) + self._set_attributes(span, attributes) def start_tool_call_span( self, @@ -492,7 +496,7 @@ def start_event_loop_cycle_span( parent_span: Span | None = None, custom_trace_attributes: Mapping[str, AttributeValue] | None = None, **kwargs: Any, - ) -> Span | None: + ) -> Span: """Start a new span for an event loop cycle. Args: @@ -532,17 +536,23 @@ def end_event_loop_cycle_span( span: Span, message: Message, tool_result_message: Message | None = None, - error: Exception | None = None, ) -> None: """End an event loop cycle span with results. + Note: The span is automatically closed and exceptions recorded. This method just sets the necessary attributes. + Status in the span is automatically set to UNSET (OK) on success or ERROR on exception. + Args: - span: The span to end. + span: The span to set attributes on. message: The message response from this cycle. tool_result_message: Optional tool result message if a tool was called. - error: Optional exception if the cycle failed. 
""" - attributes: dict[str, AttributeValue] = {} + if not span: + return + + # Set end time attribute + span.set_attribute("gen_ai.event.end_time", datetime.now(timezone.utc).isoformat()) + event_attributes: dict[str, AttributeValue] = {"message": serialize(message["content"])} if tool_result_message: @@ -565,7 +575,6 @@ def end_event_loop_cycle_span( ) else: self._add_event(span, "gen_ai.choice", event_attributes=event_attributes) - self._end_span(span, attributes, error) def start_agent_span( self, diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index d4afd579b..a76a5b6b5 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -576,9 +576,6 @@ async def test_event_loop_tracing_with_model_error( ) await alist(stream) - # Verify error handling span methods were called - mock_tracer.end_span_with_error.assert_called_once_with(model_span, "Input too long", model.stream.side_effect) - @pytest.mark.asyncio async def test_event_loop_cycle_max_tokens_exception( @@ -705,8 +702,6 @@ async def test_event_loop_tracing_with_throttling_exception( ) await alist(stream) - # Verify error span was created for the throttling exception - assert mock_tracer.end_span_with_error.call_count == 1 # Verify span was created for the successful retry assert mock_tracer.start_model_invoke_span.call_count == 2 assert mock_tracer.end_model_invoke_span.call_count == 1 diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index cb98b8130..6ea605083 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -246,8 +246,6 @@ def test_end_model_invoke_span(mock_span): "gen_ai.choice", attributes={"message": json.dumps(message["content"]), "finish_reason": "end_turn"}, ) - mock_span.set_status.assert_called_once_with(StatusCode.OK) - mock_span.end.assert_called_once() def 
test_end_model_invoke_span_latest_conventions(mock_span, monkeypatch): @@ -284,9 +282,6 @@ def test_end_model_invoke_span_latest_conventions(mock_span, monkeypatch): }, ) - mock_span.set_status.assert_called_once_with(StatusCode.OK) - mock_span.end.assert_called_once() - def test_start_tool_call_span(mock_tracer): """Test starting a tool call span.""" @@ -650,8 +645,6 @@ def test_end_event_loop_cycle_span(mock_span): "tool.result": json.dumps(tool_result_message["content"]), }, ) - mock_span.set_status.assert_called_once_with(StatusCode.OK) - mock_span.end.assert_called_once() def test_end_event_loop_cycle_span_latest_conventions(mock_span, monkeypatch): @@ -687,8 +680,6 @@ def test_end_event_loop_cycle_span_latest_conventions(mock_span, monkeypatch): ) }, ) - mock_span.set_status.assert_called_once_with(StatusCode.OK) - mock_span.end.assert_called_once() def test_start_agent_span(mock_tracer): @@ -890,8 +881,6 @@ def test_end_model_invoke_span_with_cache_metrics(mock_span): mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 3) mock_span.set_attribute.assert_any_call("gen_ai.server.request.duration", 10) mock_span.set_attribute.assert_any_call("gen_ai.server.time_to_first_token", 5) - mock_span.set_status.assert_called_once_with(StatusCode.OK) - mock_span.end.assert_called_once() def test_end_agent_span_with_cache_metrics(mock_span): From 1cedaed15c6bb8346b0687c459e1ed4ba3046d35 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 23 Jan 2026 23:19:11 +0200 Subject: [PATCH 087/279] ci: limit permission scope on lambda layer github action (#1555) --- .github/workflows/publish-lambda-layer.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/publish-lambda-layer.yml b/.github/workflows/publish-lambda-layer.yml index 4211d715f..3ad9e9abf 100644 --- a/.github/workflows/publish-lambda-layer.yml +++ b/.github/workflows/publish-lambda-layer.yml @@ -38,6 +38,7 @@ on: jobs: validate: runs-on: ubuntu-latest + 
permissions: {} steps: - name: Validate confirmation run: | From 98fcc2c7546163305810f152a7b2d509c091bc6d Mon Sep 17 00:00:00 2001 From: Jonathan Segev Date: Mon, 26 Jan 2026 11:13:05 -0500 Subject: [PATCH 088/279] chore: Enable Auto-close labels on Pull requests as well. (#1552) --- .github/workflows/auto-close.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/auto-close.yml b/.github/workflows/auto-close.yml index dc9b577a0..be31606d9 100644 --- a/.github/workflows/auto-close.yml +++ b/.github/workflows/auto-close.yml @@ -24,13 +24,13 @@ jobs: include: - label: 'autoclose in 3 days' days: 3 - issue_types: 'issues' #issues/pulls/both + issue_types: 'both' #issues/pulls/both replacement_label: '' closure_message: 'This issue has been automatically closed as it was marked for auto-closure by the team and no additional responses was received within 3 days.' dry_run: 'false' - label: 'autoclose in 7 days' days: 7 - issue_types: 'issues' # issues/pulls/both + issue_types: 'both' # issues/pulls/both replacement_label: '' closure_message: 'This issue has been automatically closed as it was marked for auto-closure by the team and no additional responses was received within 7 days.' 
dry_run: 'false' From ee319471690a4faa1ec165029090bc4ec7216f32 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Tue, 27 Jan 2026 10:28:30 -0500 Subject: [PATCH 089/279] Use devtools actions (#1554) --- .github/actions/README.md | 285 ------ .../actions/strands-agent-runner/action.yml | 179 ---- .../actions/strands-write-executor/action.yml | 147 --- .github/agent-sops/task-implementer.sop.md | 493 ---------- .github/agent-sops/task-refiner.sop.md | 298 ------- .github/agent-sops/task-release-notes.sop.md | 772 ---------------- .github/scripts/javascript/process-input.cjs | 141 --- .github/scripts/python/agent_runner.py | 163 ---- .github/scripts/python/github_tools.py | 843 ------------------ .github/scripts/python/handoff_to_user.py | 34 - .github/scripts/python/notebook.py | 337 ------- .github/scripts/python/requirements.txt | 8 - .../python/str_replace_based_edit_tool.py | 230 ----- .github/scripts/python/write_executor.py | 152 ---- .github/workflows/integration-test.yml | 38 +- .github/workflows/strands-command.yml | 141 +-- 16 files changed, 36 insertions(+), 4225 deletions(-) delete mode 100644 .github/actions/README.md delete mode 100644 .github/actions/strands-agent-runner/action.yml delete mode 100644 .github/actions/strands-write-executor/action.yml delete mode 100644 .github/agent-sops/task-implementer.sop.md delete mode 100644 .github/agent-sops/task-refiner.sop.md delete mode 100644 .github/agent-sops/task-release-notes.sop.md delete mode 100644 .github/scripts/javascript/process-input.cjs delete mode 100644 .github/scripts/python/agent_runner.py delete mode 100644 .github/scripts/python/github_tools.py delete mode 100644 .github/scripts/python/handoff_to_user.py delete mode 100644 .github/scripts/python/notebook.py delete mode 100644 .github/scripts/python/requirements.txt delete mode 100644 .github/scripts/python/str_replace_based_edit_tool.py delete mode 100755 .github/scripts/python/write_executor.py diff --git a/.github/actions/README.md 
b/.github/actions/README.md deleted file mode 100644 index 6559462cb..000000000 --- a/.github/actions/README.md +++ /dev/null @@ -1,285 +0,0 @@ -# Strands Command GitHub Actions - -A comprehensive AI agent execution system for GitHub repositories that processes `/strands` commands in issues and pull requests. - -## Overview - -The Strands Command system enables AI-powered automation in GitHub repositories through: - -- **Issue Comment Processing**: Responds to `/strands` commands in issues and PRs -- **Controlled AI Execution**: Runs AI agents with read-only and write-separated permissions -- **AWS Integration**: Secure OIDC-based authentication with Bedrock AI models -- **Security-First Design**: Manual approval gates and permission isolation - -### Architecture - -```mermaid -graph LR - A["strands Command"] --> B[Authorization] - B --> C[Read-Only Agent] - C --> D[Write Operations] - D --> E[Cleanup] - - B -.-> B1[Permission Check] - C -.-> C1[AWS + AI Execution] - D -.-> D1[Repository Updates] -``` - -## Quick Start - -1. **Set up AWS IAM Role** (see [IAM Role Policy](#iam-role-policy)) -2. **Configure GitHub Secrets**: - - `AWS_ROLE_ARN`: Your IAM role ARN - - `STRANDS_SESSION_BUCKET`: S3 bucket for session storage -3. **Copy required files** to your repository: - - `.github/workflows/strands-command.yml` - - `.github/actions/` directory - - `.github/scripts/` directory - - `.github/agent-sops/` directory -4. **Comment `/strands [your task]`** on any issue or PR - - **On Issues**: - - Use `/strands ` to have an agent help you refine an issue within the context of the current github repo - - Use `/strands implement ` to create a new PR based on the description of an issue - - **On PRs**: `/strands ` will instruct an Agent to review PR comments and make updates to the issue - -## Actions - -### strands-agent-runner - -Executes AI agents with AWS integration and controlled permissions. 
- -**Inputs:** -- `ref` (required): Git reference to checkout -- `system_prompt` (required): System instructions for the agent -- `session_id` (required): Session identifier for persistence -- `task_prompt` (required): Task description for the agent -- `aws_role_arn` (required): AWS IAM role ARN for authentication -- `sessions_bucket` (required): S3 bucket for session storage -- `write_permission` (required): Permission level flag for Read-only Sandbox mode (`true`/`false`) - -**Features:** -- Strands Agent running with Agent SOPs specifically designed to instruct an Agent on how to develop in Github -- Python 3.13 and Node.js 20 environment setup (Node.js setup and npm install are optional and can be removed - only included for this repo's development) -- Read-only Sandbox support: Agent write actions can be deferred to the `strands-write-executor` action if you want your agent to execute with read-only github permissions - -### strands-write-executor - -Executes write operations from agent-generated artifacts if `strands-agent-runner` was run with `write_permissions: false`. - -**Inputs:** -- `ref` (required): Target branch for changes -- `issue_id` (optional): Associated issue number - -**Features:** -- Reads Agent modified repository state from artifacts, and pushes changes to pr branch -- Reads deferred write operations from artifact and executes them - -## Workflows - -### strands-command.yml - -Main workflow that orchestrates the complete Strands command execution: - -1. **Authorization Check**: Validates user permissions and applies approval gates -2. **Setup and Processing**: Parses input and prepares execution context -3. **Read-Only Execution**: Runs Agent in Read-only sandbox -4. **Write Operations**: Executes repository modifications in job isolated from agent -5. 
**Cleanup**: Removes temporary labels and artifacts - -**Triggers:** -- Issue comments starting with `/strands` -- Manual workflow dispatch with parameters - -## Agent SOPs - -### Task Implementer (`task-implementer.sop.md`) - -Implements features using test-driven development principles. - -**Workflow**: Setup → Explore → Plan → Code → Commit → Pull Request - -**Capabilities:** -- Feature implementation with TDD approach -- Comprehensive testing and documentation -- Pull request creation and iteration -- Code pattern following and best practices - -### Task Refiner (`task-refiner.sop.md`) - -Refines and clarifies task requirements before implementation. - -**Workflow**: Read Issue → Analyze → Research → Clarify → Iterate - -**Capabilities:** -- Requirement analysis and gap identification -- Clarifying question generation -- Implementation planning and preparation -- Ambiguity resolution through user interaction - -## IAM Role Policy - -### Required IAM Role - -Create an IAM role with the following trust policy for GitHub OIDC: - -```json -{ - "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Principal": { - "Federated": "arn:aws:iam::YOUR_ACCOUNT_ID:oidc-provider/token.actions.githubusercontent.com" - }, - "Action": "sts:AssumeRoleWithWebIdentity", - "Condition": { - "StringEquals": { - "token.actions.githubusercontent.com:aud": "sts.amazonaws.com" - }, - "StringLike": { - "token.actions.githubusercontent.com:sub": "repo:YOUR_ORG/YOUR_REPO:*" - } - } - } - ] -} -``` - -### IAM Role Policy - -Your IAM role must have these permissions in order to execute: - -```json -{ - "Version": "2012-10-17", - "Statement": [ - { - "Sid": "Bedrock Access", - "Effect": "Allow", - "Action": [ - "bedrock:InvokeModelWithResponseStream", - "bedrock:InvokeModel" - ], - "Resource": "*" - }, - { - "Effect": "Allow", - "Action": [ - "s3:PutObject", - "s3:GetObject", - "s3:DeleteObject" - ], - "Resource": [ - "arn:aws:s3:::YOUR_STRANDS_SESSION_BUCKET/*" - ] - }, - { - 
"Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": [ - "arn:aws:s3:::YOUR_STRANDS_SESSION_BUCKET" - ] - } - ] -} -``` - -### Setup Steps - -1. **Create OIDC Provider** (if not exists): - ```bash - aws iam create-open-id-connect-provider \ - --url https://token.actions.githubusercontent.com \ - --thumbprint-list 6938fd4d98bab03faadb97b34396831e3780aea1 \ - --client-id-list sts.amazonaws.com - ``` - -2. **Create IAM Role** with the trust policy above -3. **Create S3 Bucket** for session storage -4. **Add GitHub Secrets**: - - `AWS_ROLE_ARN`: The created role ARN - - `AGENT_SESSIONS_BUCKET`: The S3 bucket name - -## Security - -### ⚠️ Important Security Considerations - -**This workflow should only be used with trusted sources and should use AWS guardrails to help avoid prompt injection risks.** - -### Security Features - -#### Authorization Controls -- **Collaborator Verification**: Only users with write access get auto-approval -- **Manual Approval Gates**: Unknown users require manual approval via GitHub environments -- **Permission Separation**: Read and write operations isolated in separate jobs - -#### AWS Security -- **OIDC Authentication**: No long-lived credentials stored in GitHub -- **Minimal Permissions**: Inline session policy limits access to required resources only -- **Temporary Credentials**: Each execution gets fresh, time-limited AWS credentials. 
You can further limit these by updating the `strands-agent-runner` "Configure AWS credentials" step, and set the `role-duration-seconds` value -- **Resource Scoping**: S3 access limited to specific session bucket - -#### Prompt Injection Mitigation -- **Trusted Sources Only**: Implement strict user authorization -- **AWS Guardrails**: Use AWS Bedrock guardrails to filter malicious prompts -- **Input Validation**: Validate and sanitize all user inputs -- **Execution Isolation**: Separate read and write phases prevent unauthorized modifications - -## Configuration - -### GitHub Secrets - -| Secret | Description | Example | -|--------|-------------|---------| -| `AWS_ROLE_ARN` | IAM role for AWS access | `arn:aws:iam::123456789012:role/GitHubActionsRole` | -| `STRANDS_SESSION_BUCKET` | S3 bucket for sessions | `my-strands-sessions-bucket` | - -### Environment Variables - -The actions use these environment variables during execution: - -| Variable | Purpose | Set By | -|----------|---------|--------| -| `GITHUB_WRITE` | Permission level indicator | Action | -| `SESSION_ID` | Agent session identifier | Workflow | -| `S3_SESSION_BUCKET` | Session storage location | Input | -| `STRANDS_TOOL_CONSOLE_MODE` | Tool execution mode | Action | -| `BYPASS_TOOL_CONSENT` | Automated tool approval | Action | - -## Usage Examples - -### Basic Task Implementation - -Comment on an issue: -``` -/strands Implement a new user authentication feature with JWT tokens -``` - -### Task Refinement - -Comment on an issue with unclear requirements: -``` -/strands refine Please help clarify the requirements for this feature -``` - -### Manual Execution - -Use workflow dispatch with: -- **issue_id**: `123` -- **command**: `Implement the requested feature` -- **session_id**: `optional-session-id` - -### Advanced Usage - -``` -/strands implement Create a REST API endpoint for user management with the following requirements: -1. CRUD operations for users -2. JWT authentication -3. Input validation -4. 
Unit tests with 90% coverage -5. OpenAPI documentation -``` - ---- - -**Note**: This system is designed for trusted environments. Always review security implications before deployment and implement appropriate guardrails for your use case. diff --git a/.github/actions/strands-agent-runner/action.yml b/.github/actions/strands-agent-runner/action.yml deleted file mode 100644 index d0e93effe..000000000 --- a/.github/actions/strands-agent-runner/action.yml +++ /dev/null @@ -1,179 +0,0 @@ -name: 'Strands Agent Runner' -description: 'Execute a Strands agent with the given prompts and configuration' -inputs: - ref: - description: 'ref to checkout' - required: true - system_prompt: - description: 'System prompt for the agent' - required: true - session_id: - description: 'Session ID for the agent execution' - required: true - task_prompt: - description: 'Task prompt for the agent' - required: true - aws_role_arn: - description: 'AWS IAM role ARN for authentication' - required: true - sessions_bucket: - description: 'S3 bucket for session storage' - required: true - write_permission: - description: 'If this action runs with write permission. If this is false, you should run the `strands-write-executor` action after this one with write permission.' 
- required: true - default: 'false' - -runs: - using: 'composite' - steps: - # Checkout main repo .github directory - - name: Checkout repository - uses: actions/checkout@v5 - with: - sparse-checkout: | - .github - - # Copy the .github directory to the runner temp directory so the branch content cant overwrite the scripts executed here - - name: Copy .github to safe directory - shell: bash - run: | - mkdir -p ${{ runner.temp }}/strands-agent-runner - cp -r .github ${{ runner.temp }}/strands-agent-runner - - # Checkout the branch repo to stage the directory for the agent - - name: Checkout repository - uses: actions/checkout@v5 - with: - ref: ${{ inputs.ref }} - - - name: Setup Node.js - uses: actions/setup-node@v6 - with: - node-version: '20' - - - name: Install dependencies - # If we have package.json then install the dependencies - this is for compatibility in multiple repos - if: hashFiles('package.json') != '' - shell: bash - run: npm install - continue-on-error: true # This step's failure will not stop the workflow - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: '3.13' - - - name: Install uv - uses: astral-sh/setup-uv@v3 - with: - enable-cache: true - cache-dependency-glob: '${{ runner.temp }}/strands-agent-runner/.github/scripts/python/requirements.txt' - - - name: Install Strands Agents - shell: bash - run: | - echo "📦 Installing from requirements.txt" - uv pip install --system -r ${{ runner.temp }}/strands-agent-runner/.github/scripts/python/requirements.txt --quiet - - - name: Configure Git - shell: bash - run: | - git config --global user.name "Strands Agent" - git config --global user.email "217235299+strands-agent@users.noreply.github.com" - git config --global core.pager cat - PAGER=cat - - - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v4 - with: - role-to-assume: ${{ inputs.aws_role_arn }} - role-session-name: GitHubActions-StrandsAgent-${{ github.run_id }} - aws-region: us-west-2 - 
mask-aws-account-id: true - inline_session_policy: >- - { - "Version": "2012-10-17", - "Statement": [ - { - "Sid":"Bedrock Access", - "Effect": "Allow", - "Action": [ - "bedrock:InvokeModelWithResponseStream", - "bedrock:InvokeModel" - ], - "Resource": "*" - }, { - "Effect": "Allow", - "Action": [ - "s3:PutObject", - "s3:GetObject", - "s3:DeleteObject", - ], - "Resource": [ - "arn:aws:s3:::strands-typescript-project-sessions/*", - ] - }, { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": [ - "arn:aws:s3:::strands-typescript-project-sessions", - ] - } - ] - } - - - - name: Execute strands command - shell: bash - env: - # Write Permission - GITHUB_WRITE: ${{ inputs.write_permission }} - - # GitHub Configuration - GITHUB_TOKEN: ${{ github.token }} - GITHUB_REPOSITORY: ${{ github.repository }} - - # Task Configuration - INPUT_TASK: ${{ inputs.task_prompt }} - INPUT_SYSTEM_PROMPT: ${{ inputs.system_prompt }} - - # AWS Configuration - AWS_REGION: 'us-west-2' - - # Session Manager - S3_SESSION_BUCKET: ${{ inputs.sessions_bucket }} - SESSION_ID: ${{ inputs.session_id }} - - # Strands Env Vars - STRANDS_TOOL_CONSOLE_MODE: 'enabled' - BYPASS_TOOL_CONSENT: 'true' - run: | - uv run --no-project ${{ runner.temp }}/strands-agent-runner/.github/scripts/python/agent_runner.py - - - name: Capture repository state - shell: bash - run: | - mkdir -p .artifact - if git diff --quiet HEAD@{upstream} && git diff --quiet --cached; then - echo "📭 No changes to capture" - else - echo "📝 Capturing entire repository state" - tar -czf .artifact/repository_state.tar.gz --exclude='.artifact' . 
- fi - - - name: Upload repository state artifact - uses: actions/upload-artifact@v4 - with: - name: repository-state - path: .artifact/repository_state.tar.gz - retention-days: 1 - if-no-files-found: ignore - - - name: Upload artifact for write operations - uses: actions/upload-artifact@v4 - with: - name: write-operations - path: .artifact/write_operations.jsonl - retention-days: 1 - if-no-files-found: ignore \ No newline at end of file diff --git a/.github/actions/strands-write-executor/action.yml b/.github/actions/strands-write-executor/action.yml deleted file mode 100644 index 3417c3140..000000000 --- a/.github/actions/strands-write-executor/action.yml +++ /dev/null @@ -1,147 +0,0 @@ -name: 'Strands Write Executor' -description: 'Execute write GitHub operations from JSONL artifact files during workflow execution' -inputs: - ref: - description: 'Ref to push changes to' - required: true - issue_id: - description: 'Issue ID for fallback operations' - required: false - -runs: - using: 'composite' - steps: - - # Push code changes before running write commands in case we need to create a pull request - # Pull requests cannot be created if a branch has no diff with main, so push changes first, then create pr - - name: Log if ref equals main - shell: bash - run: | - if [ "${{ inputs.ref }}" = "${{ github.event.repository.default_branch }}" ]; then - echo "🚫 Ref is default - skipping push operations to prevent direct commits to default branch" - else - echo "✅ Ref is '${{ inputs.ref }}' - push operations will proceed" - fi - - - name: Download repository state artifact - if: inputs.ref != github.event.repository.default_branch - uses: actions/download-artifact@v4 - with: - name: repository-state - path: ${{ runner.temp }} - continue-on-error: true - - - name: Apply Artifact and Push changes - if: inputs.ref != github.event.repository.default_branch - shell: bash - env: - GITHUB_TOKEN: ${{ github.token }} - run: | - - if [ -f "$RUNNER_TEMP/repository_state.tar.gz" ]; 
then - echo "📝 Applying repository state" - mkdir -p "$RUNNER_TEMP/temp_git_repo" - tar -xzf "$RUNNER_TEMP/repository_state.tar.gz" -C "$RUNNER_TEMP/temp_git_repo" - rm "$RUNNER_TEMP/repository_state.tar.gz" - - echo "📁 Changing to repository directory" - ORIGINAL_DIRECTORY=$(pwd) - cd "$RUNNER_TEMP/temp_git_repo" - - # Configure Git - git config --local user.name "Strands Agent" - git config --local user.email "217235299+strands-agent@users.noreply.github.com" - git config --local core.pager cat - # We need to overwrite this since this is currently set by the previous readonly workflow artifact - # Overwrite this value with the current token that allows us to push the commit - git config --local http."https://github.com/".extraheader "AUTHORIZATION: basic $(echo -n x-access-token:${{ github.token }}| base64)" - - # Fetch the remote repository - git fetch origin ${{ inputs.ref }} - - # Stage and commit any changes first - if [ -n "$(git status --porcelain)" ]; then - echo "📝 Changes detected, staging all files" - git add -A - echo "📝 Committing changes" - git commit -m "Additional changes from write operations" -n - fi - - # Push if there are differences from remote - if ! git diff --quiet HEAD origin/${{ inputs.ref }}; then - echo "📝 Differences from remote:" - git diff HEAD origin/${{ inputs.ref }} - echo "📤 Pushing changes to ${{ inputs.ref }}" - git push --force origin ${{ inputs.ref }} - else - echo "📭 No changes to push" - fi - - # Change back and clean up - cd $ORIGINAL_DIRECTORY - rm -rf "$RUNNER_TEMP/temp_git_repo" - fi - - - name: Download artifact with write operations - uses: actions/download-artifact@v4 - with: - name: write-operations - continue-on-error: true - - - name: Check if write operations artifact exists - id: check-write-ops - shell: bash - run: | - if [ -f "write_operations.jsonl" ]; then - echo "✅ Write operations artifact exists! Continuing to execute commands!" 
- cp -r write_operations.jsonl ${{ runner.temp }} - echo "exists=true" >> $GITHUB_OUTPUT - else - echo "❌ Write operations artifact does not exist. Stopping execution." - echo "exists=false" >> $GITHUB_OUTPUT - fi - - - name: Checkout repo to temp dir - if: steps.check-write-ops.outputs.exists == 'true' - uses: actions/checkout@v5 - with: - sparse-checkout: | - .github - - - name: Set up Python - if: steps.check-write-ops.outputs.exists == 'true' - uses: actions/setup-python@v4 - with: - python-version: '3.13' - - - name: Install uv - if: steps.check-write-ops.outputs.exists == 'true' - uses: astral-sh/setup-uv@v3 - with: - enable-cache: true - cache-dependency-glob: ./.github/scripts/python/requirements.txt - - - name: Install dependencies - if: steps.check-write-ops.outputs.exists == 'true' - shell: bash - run: | - echo "📦 Installing from requirements.txt" - uv pip install --system -r ./.github/scripts/python/requirements.txt --quiet - - - name: Execute write operations - if: steps.check-write-ops.outputs.exists == 'true' - shell: bash - env: - GITHUB_TOKEN: ${{ github.token }} - GITHUB_REPOSITORY: ${{ github.repository }} - - # Strands Env Vars - STRANDS_TOOL_CONSOLE_MODE: 'enabled' - BYPASS_TOOL_CONSENT: 'true' - run: | - echo "🚀 Strands Write Executor - Processing write operations" - if [ -n "${{ inputs.issue_id }}" ]; then - python ./.github/scripts/python/write_executor.py "${{ runner.temp }}/write_operations.jsonl" --issue-id "${{ inputs.issue_id }}" - else - python ./.github/scripts/python/write_executor.py "${{ runner.temp }}/write_operations.jsonl" - fi diff --git a/.github/agent-sops/task-implementer.sop.md b/.github/agent-sops/task-implementer.sop.md deleted file mode 100644 index cc7aa3330..000000000 --- a/.github/agent-sops/task-implementer.sop.md +++ /dev/null @@ -1,493 +0,0 @@ -# Task Implementer SOP - -## Role - -You are a Task Implementer, and your goal is to implement a task defined in a github issue. 
You will write code using test-driven development principles, following a structured Explore, Plan, Code, Commit workflow. During your implementation, you will write code that follows existing patterns, create comprehensive documentation, generate test cases, create a pull request for review, and iterate on the provided feedback until the pull request is accepted. - -## Steps - -### 1. Setup Task Environment - -Initialize the task environment and discover repository instruction files. - -**Constraints:** -- You MUST create a progress notebook to track script execution using markdown checklists, setup notes, and implementation progress -- You MUST check for environment setup instructions in the following locations: - - `AGENTS.md` - - `DEVELOPMENT.md` - - `CONTRIBUTING.md` - - `README.md` -- You MAY explore more files in the repository if you did not find instructions -- You MUST check the `GITHUB_WRITE` environment variable value to determine if you have github write permission - - If the value is `true`, then you can run git write commands like `add_comment` or run `git push` - - If the value is not `true`, you are running in a read-restricted sandbox. Any write commands you do run will be deferred to run outside the sandbox - - Any staged or unstaged changes will be pushed after you finish executing to the feature branch -- You MUST make a note of environment setup and testing instructions -- You MUST make note of the task number from the issue title -- You MUST make note of the issue number -- You MUST run unit tests to ensure the repository and environment are functional -- You MAY run integration tests if your feature requires new tests to be added -- You MUST comment on the github issue if the tests fail, and use the handoff_to_user tool to get feedback on how to continue. 
-- You MUST check the current branch using `git branch --show-current` -- You MUST create a new feature branch if currently on main branch: - - You MUST use `git checkout -b {BRANCH_NAME}` to create and switch to a new feature branch - - You SHOULD use the BRANCH_NAME pattern `agent-tasks/{ISSUE_NUMBER}` unless this branch already exists - - You MUST make note of the newly created branch name - - You MUST use `git push origin {BRANCH_NAME}` to create the feature branch in remote - - If the push operation is deferred, continue with the workflow and note the deferred status -- You MAY continue on the current branch if not on main branch - - -### 2. Explore Phase - -#### 2.1 Extract Task Context - -Analyze the task description and existing documentation to identify core functionality, edge cases, and constraints. - -**Constraints:** -- You MUST read the issue description -- You MUST investigate any links provided in the feature request - - You MUST note how the information from this link can influence the implementation -- You MUST review any implementation documentation provided by the repository: - - `AGENTS.md` - - `DEVELOPMENT.md` - - `CONTRIBUTING.md` - - `README.md` -- You MAY read existing comments, but focus mostly on the description -- You MUST capture issue metadata (title, labels, status, etc.) - -#### 2.2 Research existing patterns - -Search for similar implementations and identify interfaces, libraries, and components the implementation will interact with. 
- -**Constraints:** -- You MUST analyze the task and identify core functionality, edge cases, and constraints -- You MUST search the repository for relevant code, patterns, and information related to the coding task and note your findings -- You MUST create a dependency map showing how new code will integrate -- You MUST record the identified implementation paths in your notebook -- You SHOULD make note of any ambiguity you have in implementing the task - -#### 2.3 Create Code Context Document - -Compile all findings into a comprehensive code context notebook. - -**Constraints:** -- You MUST update your notebook with requirements, implementation details, patterns, and dependencies -- You MUST ensure your notes are well-structured with clear headings -- You MUST focus on high-level concepts and patterns rather than detailed implementation code -- You MUST NOT include complete code implementations in your notes because documentation should guide implementation, not provide it -- You MUST keep your notes concise and focused on guiding implementation rather than providing the implementation itself -- You SHOULD include a summary section and highlight areas of uncertainty -- You SHOULD use pseudocode or simplified representations when illustrating concepts -- You MAY include targeted code snippets when: - - Demonstrating usage of a specific library or API that's critical to the implementation - - Illustrating a complex pattern or technique that's difficult to describe in words alone - - Showing examples from existing codebase that demonstrate relevant patterns - - Providing reference implementations from official documentation -- You MUST clearly label any included code snippets as examples or references, not as the actual implementation -- You MUST keep any included code snippets brief and focused on the specific concept being illustrated - - -### 3. 
Plan Phase - -#### 3.1 Design Test Strategy - -Create a comprehensive list of test scenarios covering normal operation, edge cases, and error conditions. - -**Constraints:** -- You MUST check for existing testing strategies documented in the repository documentation or your notes -- You MUST cover all acceptance criteria with at least one test scenario -- You MUST define explicit input/output pairs for each test case -- You MUST make note of these test scenarios -- You MUST design tests that will initially fail when run against non-existent implementations -- You MUST NOT create mock implementations during the test design phase because tests should be written based solely on expected behavior, not influenced by implementation details -- You MUST focus on test scenarios and expected behaviors rather than detailed test code in documentation -- You MUST use high-level descriptions of test cases rather than complete test code snippets -- You MAY include targeted test code snippets when: - - Demonstrating a specific testing technique or pattern that's critical to understand - - Illustrating how to use a particular testing framework or library - - Showing examples of similar tests from the existing codebase -- You MUST clearly label any included test code snippets as examples or references -- You SHOULD explain the reasoning behind the proposed test structure - - -#### 3.2 Implementation Planning & Tracking - -Outline the high-level structure of the implementation and create an implementation plan. 
- -**Constraints:** -- You MUST create an implementation plan notebook -- You MUST include all key implementation tasks in the plan -- You SHOULD consider performance, security, and maintainability implications -- You MUST keep implementation planning notes concise and focused on architecture and patterns -- You MUST NOT include detailed code implementations in planning notes because planning should focus on architecture and approach, not specific code -- You MUST use high-level descriptions, UML diagrams, or simplified pseudocode rather than actual implementation code -- You MAY include targeted code snippets when: - - Illustrating a specific design pattern or architectural approach - - Demonstrating API usage that's central to the implementation - - Showing relevant examples from existing codebase or reference implementations - - Clarifying complex interactions between components -- You MUST clearly label any included code snippets as examples or references, not as the actual implementation -- You SHOULD make note of the reasoning behind the proposed implementation structure -- You MUST display the current checklist status after each major implementation step -- You MUST verify all checklist items are complete before finalizing the implementation -- You MUST maintain the implementation checklist in your progress notes using markdown checkbox format - -### 4. Code Phase - -#### 4.1 Implement Test Cases - -Write test cases based on the outlines, following strict TDD principles. 
- -**Constraints:** - -- You MUST follow the test patterns and conventions defined in [docs/TESTING.md](../../docs/TESTING.md) -- You MUST validate that the task environment is set up properly - - If you already created a commit, ensure the latest commit matches the expected hash - - If not, ensure the correct branch is checked out - - As a last resort, you MUST commit your current work to the current branch, then leave a comment on the Task issue or Pull Request for feedback on how to proceed -- You MUST save test implementations to the appropriate test directories in repo_root -- You MUST implement tests for ALL requirements before writing ANY implementation code -- You MUST follow the testing framework conventions used in the existing codebase - - You MUST follow test directory structure patterns - - You MUST follow test file format patterns: - - Follow class vs method test case creating patterns - - Follow mocking patterns - - Reuse existing test helper functions - - You MUST follow test creation rules if they are documented -- You MUST update the plan notes with test implementation details -- You MUST update the implementation checklist to mark test development as complete -- You MUST keep test notes concise and focused on test strategy rather than detailed test code -- You MUST execute tests after writing them to verify they fail as expected -- You MUST document the failure reasons in the TDD notes -- You MUST only seek user input if: - - Tests fail for unexpected reasons that you cannot resolve - - There are structural issues with the test framework - - You encounter environment issues that prevent test execution -- You MAY seek user input by commenting on the issue, and informing the user you are ready for their instruction by using the handoff_to_user tool -- You MUST otherwise continue automatically after verifying expected failures -- You MUST follow the Build Output Management practices defined in the Best Practices section - -#### 4.2 Develop 
Implementation Code - -Write implementation code to pass the tests, focusing on simplicity and correctness first. - -**Constraints:** -- You MUST update your progress in your implementation plan notes -- You MUST follow the strict TDD cycle: RED → GREEN → REFACTOR -- You MUST document each TDD cycle in your progress notes -- You MUST implement only what is needed to make the current test(s) pass -- You MUST follow the coding style and conventions of the existing codebase -- You MUST keep code comments concise and focused on key decisions rather than code details -- You MUST follow YAGNI, KISS, and SOLID principles -- You MAY make note of key implementation decisions including: - - Demonstrating usage of a specific library or API that's critical to the implementation - - Illustrating a complex pattern or technique that's difficult to describe in words alone - - Showing examples from existing codebase that demonstrate relevant patterns - - Explaining a particularly complex algorithm or data structure - - Providing reference implementations from official documentation -- You MUST make note of the reasoning behind implementation choices -- You SHOULD make note of any security considerations in the implementation -- You MUST execute tests after each implementation step to verify they now pass -- You MUST only seek user input if: - - Tests continue to fail after implementation for reasons you cannot resolve - - You encounter a design decision that cannot be inferred from requirements - - Multiple valid implementation approaches exist with significant trade-offs -- You MUST commit your work before seeking user feedback - - You MUST push your work if the `GITHUB_WRITE` environment variable is set to `true` -- You MAY seek user input by commenting on the issue, and informing the user you are ready for their instruction by using the handoff_to_user tool -- You MUST otherwise continue automatically after verifying test results -- You MUST follow the Build Output Management 
practices defined in the Best Practices section - -#### 4.3 Review and Refactor Implementation - -If the implementation is complete, proceed with a self-review of the implementation code to identify opportunities for simplification or improvement. - -**Constraints:** - -- You MUST check that all tasks are complete before proceeding - - if tests fail, you MUST identify the issue and implement a fix - - if builds fail, you MUST identify the issue and implement a fix -- You MUST prioritize readability and maintainability over clever optimizations -- You MUST maintain test passing status throughout refactoring -- You SHOULD make note of simplification in your progress notes -- You SHOULD record significant refactorings in your progress notes -- You MUST return to step 4.2 if refactoring reveals additional implementation needs - -#### 4.4 Review and Refactor Tests - -After reviewing the implementation, review the test code to ensure it follows established patterns and provides adequate coverage. - -**Constraints:** - -- You MUST review your test code according to the guidelines in [docs/TESTING.md](../../docs/TESTING.md). -- You MUST verify tests conform to the testing documentation standards -- You MUST verify tests are readable and maintainable -- You SHOULD refactor tests that are overly complex or duplicative -- You MUST return to step 4.1 if tests need significant restructuring - -**Testing Checklist Verification (REQUIRED):** - -You MUST copy the checklist from [docs/TESTING.md](../../docs/TESTING.md) into your progress notes and explicitly verify each item. For each checklist item, you MUST: - -1. Copy the checklist item verbatim -2. Mark it as `[x]` (pass) or `[ ]` (fail) -3. 
If failed, provide a brief explanation and fix the issue before proceeding - -Example format in your notes: - -```markdown -## Testing Checklist Verification - -- [x] Do the tests use relevant helpers from `__fixtures__` as noted in the "Test Fixtures Reference" section -- [ ] Are tests asserting on the entire object instead of specific fields? → FAILED: test on line 45 asserts individual properties, refactoring now -``` - -You MUST NOT proceed to step 4.5 until ALL checklist items pass. - -#### 4.5 Validate Implementation - -If the implementation meets all requirements and follows established patterns, proceed with this step. Otherwise, return to step 4.2 to fix any issues. - -**Constraints:** -- You MUST address any discrepancies between requirements and implementation -- You MUST execute the relevant test command and verify all implemented tests pass successfully -- You MUST execute the relevant build command and verify builds succeed -- You MUST ensure code coverage meets the requirements for the repository -- You MUST verify all items in the implementation plan have been completed -- You MUST provide the complete test execution output -- You MUST NOT claim implementation is complete if any tests are failing because failing tests indicate the implementation doesn't meet requirements - -**Build Validation:** -- You MUST run appropriate build commands based on the guidance in the repository -- You MUST verify that all dependencies are satisfied -- You MUST follow the Build Output Management practices defined in the Best Practices section - -#### 4.6 Respond to Review Feedback - -If you have received feedback from user reviews or PR comments, address them before proceeding to the commit phase. 
- -**Constraints:** - -- You MAY skip this step if no user feedback has been received yet -- You MUST reply to user review threads with a concise response - - You MUST keep your response to fewer than 3 sentences -- You MUST categorize each piece of feedback as: - - Actionable code changes that can be implemented immediately - - Clarifying questions that require user input - - Suggestions to consider for future iterations -- You MUST implement actionable code changes before proceeding -- You MUST re-run tests after addressing feedback to ensure nothing is broken -- You MUST return to step 4.3 after implementing changes to review the updated code -- You MUST use the handoff_to_user tool if clarification is needed before you can proceed - -### 5. Commit and Pull Request Phase - -If all tests are passing, draft a conventional commit message, perform the git commit, and create/update the pull request. - -**PR Checklist Verification (REQUIRED):** - -Before creating or updating a PR, you MUST copy the checklist from [docs/PR.md](../../docs/PR.md) into your progress notes and explicitly verify each item. For each checklist item, you MUST: - -1. Copy the checklist item verbatim -2. Mark it as `[x]` (pass) or `[ ]` (fail) -3. If failed, revise the PR description until the item passes - -Example format in your notes: - -```markdown -## PR Description Checklist Verification - -- [x] Does the PR description target a Senior Engineer familiar with the project? -- [ ] Does the PR include a "Resolves #" in the body? → FAILED: missing issue reference, adding now -``` - -You MUST NOT create or update the PR until ALL checklist items pass. 
- -**Constraints:** - -- You MUST read and follow the PR description guidelines in [docs/PR.md](../../docs/PR.md) when creating pull requests & commits -- You MUST check that all tasks are complete before proceeding -- You MUST reference your notes for the issue you are creating a pull request for -- You MUST NOT commit changes until builds AND tests have been verified because committing broken code can disrupt the development workflow and introduce bugs into the codebase -- You MUST follow the Conventional Commits specification -- You MUST use `git status` to check which files have been modified -- You MUST use `git add` to stage all relevant files -- You MUST execute the `git commit -m ` command with the prepared commit message -- You MAY use `git push origin ` to push the local branch to the remote if the `GITHUB_WRITE` environment variable is set to `true` - - If the push operation is deferred, continue with PR creation and note the deferred status -- You MUST attempt to create the pull request using the `create_pull_request` tool if it does not exist yet - - If the PR creation is deferred, continue with the workflow and note the deferred status - - You MUST use the task id recorded in your notes, not the issue id -- If the `create_pull_request` tool fails (excluding deferred responses): - - The tool automatically handles fallback by posting a properly URL-encoded manual PR creation link as a comment on the specified fallback issue - - You MUST verify the fallback comment was posted successfully by checking the tool's return message - - You MUST NOT manually construct PR creation URLs since the tool handles URL encoding automatically -- If PR creation succeeds or is deferred: - - You MUST review your notes for any updates to provide on the pull request - - You MAY use the `update_pull_request` tool to update the pull request body or title - - If the update operation is deferred, continue with the workflow and note the deferred status -- You MUST use your 
notebook to record the new commit hash and PR status (created or link provided) - -### 6. Feedback Phase - -#### 6.1 Report Ready for Review - -Ask the user for feedback on the implementation using the handoff_to_user tool. - -**Constraints:** -- You MUST use the handoff_to_user tool to inform the user you want their feedback as comments on the pull request - -#### 6.2 Read User Responses - -Retrieve and analyze the user's responses from the pull request reviews and comments. - -**Constraints:** -- You MUST make note of the pull request number -- You MUST fetch the review and the review comments from the PR using available tools - - You MUST use the list_pr_reviews to list all PR reviews - - You MUST use get_pr_review_comments to list the comments from the review - - You MUST use get_issue_comments to list the comments on the pull request - - You MAY filter the comments to only view the newly updated comments -- You MUST analyze each comment to determine if the request is clear and actionable -- You MUST categorize comments as: - - Clear actionable requests that can be implemented - - Unclear requests that need clarification - - General feedback that doesn't require code changes -- You MUST reply to unclear comments asking for specific clarification - - If comment posting is deferred, continue with the workflow and note the deferred status -- You MUST record your progress and update the implementation plan based on the feedback -- You MUST return to step 6.1 if you needed further clarification - -#### 6.3 Review Implementation Plan - -Based on the user's feedback, review and update your implementation plan. - -**Constraints:** -- You MUST make note of the requested changes from the user -- You MUST update your implementation plan based on the feedback from the user -- You MUST return to step 3 if you need to re-plan your implementation -- You MUST return to step 4 if you only need to make minor fixes -- You MUST NOT close the parent issue - only the 
user should close it after the pull request is merged -- You MUST NOT attempt to merge the pull request -- You MUST use the handoff_to_user tool to inform the user you are ready for clarifying information on the pull request -- You MUST include additional checklist items from [docs/PR.md](../../docs/PR.md) to validate the pull request description is correct after making additional changes - -## Desired Outcome - -* A complete, well-tested code implementation that meets the specified requirements -* A comprehensive test suite that validates the implementation -* Clean, documented code that: - * Follows existing package patterns and conventions - * Prioritizes readability and extensibility - * Avoids over-engineering and over-abstraction - * Is idiomatic and modern in the implementation language -* A well-organized set of implementation artifacts in the pull request description or comments -* Documentation or comments of key design decisions and implementation notes -* Properly committed changes with conventional commit messages - -## Examples - -## Troubleshooting - -### Branch Creation Issues -If feature branch creation fails: -- Move any changes in the `.github` directory to the `.github_temp` directory -- Check for existing branch with same name -- Generate alternative branch name with timestamp -- Ensure git repository is properly initialized -- As a last resort, leave a comment on the Task Issue mentioning the issue you are facing - -### Pull Request Creation Issues -If PR creation fails (excluding deferred responses): -- Verify GitHub authentication and permissions -- Check if remote repository exists and is accessible -- You MUST commit your current work to the branch -- As a last resort, leave a comment on the Task Issue mentioning the issue you are facing - -### Deferred Operations -When GitHub tools or git operations are deferred: -- Continue with the workflow as if the operation succeeded -- Note the deferred status in your progress tracking -- The operations will be 
executed after agent completion -- Do not retry or attempt alternative approaches for deferred operations - -### Build Issues -If builds fail during implementation: -- You SHOULD follow build instructions from DEVELOPMENT.md if available -- You SHOULD verify you're in the correct directory for the build system -- You SHOULD try clean builds before rebuilding when encountering issues -- You SHOULD check for missing dependencies and resolve them -- You SHOULD restart build caches if connection issues occur - -## Best Practices - -### Repository-Specific Instructions -- Always check for DEVELOPMENT.md, AGENTS.md, and README.md in the current repository and follow any instructions provided -- If these don't exist, suggest creating them -- Always follow build commands, testing frameworks, and coding standards as specified - -### Project Structure Detection -- Detect project type by examining files (pyproject.toml, build.gradle, package.json, etc.) -- Check for DEVELOPMENT.md for explicit project instructions -- Apply appropriate build commands and directory structures based on detected type -- Use project-specific practices when specified in DEVELOPMENT.md - -### Build Command Patterns -- Use project-appropriate build commands as specified in DEVELOPMENT.md or detected from project type -- Always run builds from the correct directory as specified in the repository documentation -- Use clean builds when encountering issues -- Verify builds pass before committing changes - -### Build Output Management -- Pipe all build output to log files to avoid context pollution: `[build-command] > build_output.log 2>&1` -- Use targeted search patterns to verify build results instead of displaying full output -- Search for specific success/failure indicators based on build system -- Only display relevant excerpts from build logs when issues are detected -- You MUST not include build logs in your commit and pull request - -### Dependency Management -- Handle dependencies appropriately 
based on project type and DEVELOPMENT.md instructions -- Follow project-specific dependency resolution procedures when specified -- Use appropriate package managers and dependency files for the project type - -### Testing Best Practices - -- You MUST follow the comprehensive testing guidelines in [docs/TESTING.md](../../docs/TESTING.md) -- Follow TDD principles: RED → GREEN → REFACTOR -- Write tests that fail initially, then implement to make them pass -- Use appropriate testing frameworks for the project type or as specified in DEVELOPMENT.md -- Ensure test coverage meets the repository requirements -- Run tests after each implementation step - -### Documentation Organization -- Use consolidated documentation files: context.md, plan.md, progress.md -- Keep documentation separate from implementation code -- Focus on high-level concepts rather than detailed code in documentation -- Use progress tracking with markdown checklists -- Document decisions, assumptions, and challenges - -### Checklist Verification Pattern - -When documentation files contain checklists (e.g., `docs/TESTING.md`, `docs/PR.md`), you MUST: - -1. Copy the entire checklist into your progress notes -2. Explicitly verify each item by marking `[x]` or `[ ]` -3. For any failed items, document the issue and fix it before proceeding -4. Re-verify failed items after fixes until all pass - -This pattern ensures quality gates are not skipped and provides an audit trail of verification. 
- -### Pull Request Best Practices - -- You MUST follow the PR description guidelines in [docs/PR.md](../../docs/PR.md) -- Focus on WHY the change is needed, not HOW it's implemented -- Document public API changes with before/after code examples -- Write for senior engineers familiar with the project -- Skip implementation details, test coverage notes, and line-by-line change lists - -### Git Best Practices -- Commit early and often with descriptive messages -- Follow Conventional Commits specification -- You must create a new commit for each feedback iteration -- You must only push to your feature branch, never main diff --git a/.github/agent-sops/task-refiner.sop.md b/.github/agent-sops/task-refiner.sop.md deleted file mode 100644 index a07c7887e..000000000 --- a/.github/agent-sops/task-refiner.sop.md +++ /dev/null @@ -1,298 +0,0 @@ -# Task Refine SOP - -## Role - -You are a Task Refiner, and your goal is to review the feature request for a task and prepare it for implementation. This task feature request is defined as a github issue. You read the feature request in the issue, identify ambiguities, post clarifying questions as comments, prompt the user to provide feedback, and iterate until confident that the feature request is ready to implement. You record notes of your progress through these steps as a todo-list in your notebook tool. - -## Steps - -### 1. Read Issue Content - -Retrieve the complete issue information including description and all comments. - -**Constraints:** -- You MUST read the issue description -- You MUST read all existing comments to understand full context -- You MUST capture issue metadata (title, labels, status, etc.) - -### 2. Explore Phase -#### 2.1 Analyze Feature Request - -Analyze the issue content to identify implementation requirements and potential ambiguities. 
- -**Constraints:** -- You MUST check for existing documentation in: - - `AGENTS.md` - - `CONTRIBUTING.md` - - `README.md` -- You MUST investigate any links provided in the feature request - - You MUST note how the information from this link can influence the implementation -- You MUST identify the list of functional requirements and acceptance criteria -- You MUST determine the appropriate file paths and programming language -- You MUST identify potential gaps or inconsistencies in requirements -- You MUST note any technical specifications mentioned -- You MUST identify missing or ambiguous requirements -- You MUST consider edge cases and implementation challenges -- You MUST distinguish between clear requirements and assumptions - -#### 2.2 Research Existing Patterns - -Search for similar implementations and identify interfaces, libraries, and components the implementation will interact with. - -**Constraints:** -- You MUST identify the main programming languages and frameworks used -- You MUST search the current repository for relevant code, patterns, and information related to the task -- You MUST locate relevant existing code that relates to the feature request -- You MUST understand the current architecture and design patterns -- You MUST note any existing similar features or related functionality -- You MUST create a dependency map in your notes showing how the new feature will integrate -- You MUST note the identified implementation paths -- You SHOULD understand the build system and deployment process - -#### 2.3 Review Investigation - -After performing the investigation of the feature request and understanding the repository, you will think about the work needed to implement this feature. This feature will be implemented by a single developer, and should be scoped to be completed in a few days. 
You should note any concerns that this task is too large in scope - -**Constraints:** -- You MUST identify the work required to implement this feature -- You MUST review the current state of the repository, and identify any potential issues that might occur during implementation -- You MUST determine if this task is small enough to be implemented in a single Pull Request - - You should think if a single developer can implement this feature in about a week -- You MUST consider test implementation complexities as part of this feature request -- You MUST note if any github workflows are needed, or any changes to existing workflows are needed -- You MUST note any concerns in your notebook - -### 3. Clarification Phase - -#### 3.1 Evaluate Completeness - -Determine if you should ask clarifying questions, or if the task is already in an implementable state given your research. - -**Constraints:** -- You MAY skip to step 4 if you do not have any clarifying questions -- You SHOULD continue to the next step if you have identified questions to ask - -#### 3.2 Generate Clarifying Questions - -Create a numbered list of questions to resolve ambiguities and gather missing information. Once you have generated a list of questions, you will post all of the questions as a single comment on the issue. 
- -**Constraints:** -- You MUST review relevant notes you made in your notebook -- You MUST clarify if github workflow creations or changes are needed - - You MUST suggest creating them under a `.github_temp` directory since you do not have permission to push to `.github` directory -- You MAY ask about any ambiguous functionality -- You MAY clarify technical implementation details -- You MAY ask about user experience expectations -- You MAY ask for user input on edge cases that might not be obvious from the requirements -- You MAY ask clarifying questions regarding information from provided links -- You MAY ask about non-functional requirements that might not be explicitly stated -- You SHOULD group related questions logically -- You MAY include questions about integration with existing systems -- You MAY ask the user if the issue should be broken down into smaller issues - - You SHOULD provide justification for why it should be broken down - - You SHOULD suggest how the issue should be broken down into smaller feature requests -- You SHOULD ask about performance and scalability requirements -- You MUST create a comment with all of your questions on the issue. - - If the comment posting is deferred, continue with the workflow and note the deferred status -- You MUST wrap the comment body in a `<details>` element so it is collapsed by default - - Use a brief, descriptive summary (e.g., "Repository Analysis & Clarifying Questions") - - Place all detailed content inside the `<details>` block - -#### 3.3 Handoff to User for Response - -Use the handoff_to_user tool to inform the user they can reply to the clarifying questions on the issue. - -**Constraints:** -- You MUST use the handoff_to_user tool after posting your questions -- You MUST ask your clarifying questions when handing off to user -- You MUST tell the user to reply to your questions on the issue - -#### 3.4. Read User Responses - -Retrieve and analyze the user's responses from the issue comments. - -**Constraints:** -- You MUST read all new comments since the last check -- You MUST identify which comments contain responses to your questions -- You MUST extract answers and map them to the original questions -- You MUST handle cases where responses are incomplete or unclear -- You SHOULD take notes on how the repository can be updated (e.g. update AGENTS.md, CONTRIBUTING.md, README.md, etc) to clarify ambiguity in the future - -#### 3.5 (Optional) Break Down Task - -Determine from the user's responses if the task should be broken down into sub-tasks. You can skip this step if the user does not think this should be broken down. - -**Constraints:** -- You MUST note any clarifying questions that are needed when breaking down this issue into a smaller task -- You MUST create a notebook for each new sub-issue you plan to create -- You MUST identify any dependencies that are required for the new sub-tasks -- You MUST determine the order of implementation for these new sub-tasks -- You MUST determine a name for each new task -- You MUST number the new sub-tasks based on their parent task number. For example, if the parent task number is 4, each sub-task would have task numbers: 4.1, 4.2, 4.3, ... 
- -#### 3.6 Re-Evaluate Completeness - -Determine if the responses provide sufficient information for implementation - -**Constraints:** -- You MUST assess if all critical questions have been answered -- You MUST identify any remaining ambiguities -- You MUST determine if additional clarification is needed -- You MUST be thorough in your assessment before proceeding -- You SHOULD consider the repository context in your evaluation -- You MUST make note of your decision -- You MAY continue to the next step if you have no more clarifying questions -- You SHOULD make note of your decision to continue -- You MAY return to step 2 if you need to do more research based on the answers the user provided -- You MAY return to step 3.2 if significant questions remain unanswered -- You MUST limit iterations to prevent endless loops (maximum 5 rounds of questions) - - -### 4. Update Task -#### 4.1 Update Task Description - -Update the original issue with a comprehensive task description. - -**Constraints:** -- You MUST edit the original issue description directly - - If the edit operation is deferred, continue with the workflow and note the deferred status -- You MUST preserve the original request context -- You MUST add a clear "Implementation Requirements" section -- You MUST include all clarified specifications -- You MUST document any assumptions made -- You MUST mention any ways to improve clarification in the repository going forward -- You SHOULD include acceptance criteria -- You MUST remove any github workflow requirements if they must be created under the `.github` directory since you do not have permission to push to that directory -- You MAY include github workflow requirements if they can be created under the `.github_temp` directory -- You MUST maintain professional formatting and clarity -- You SHOULD include implementation approach based on repository analysis -- You MAY include sub-tasks as requirements to the parent task description if there are any sub-tasks - 
-#### 4.2 (Optional) Create Sub-Issues - -Create new sub-tasks if you and the user have determined that this task is too complex - -**Constraints:** -- You MUST create a new issue for each sub-task - - If issue creation is deferred, continue with the workflow and note the deferred status -- You MUST create a description with a comprehensive overview of the work required, following the same description format as the parent task -- You MUST add sub-tasks as sub-issues to the parent task's issue using the `add_sub_issue` tool. - - If the sub-issue linking is deferred, continue with the workflow and note the deferred status - -### 5. Record Completion as Comment - -Record that the task review is complete and ready as a comment on the issue. - -**Constraints:** -- You MUST only add a comment on the parent issue if any sub-issues were created - - If comment posting is deferred, continue with the workflow and note the deferred status -- You MUST summarize what was accomplished in your comment -- You MUST confirm in your comment that the issue is ready for implementation, or explain why it is not -- You SHOULD mention any final recommendations or considerations -- You MUST wrap the comment body in a `<details>` element so it is collapsed by default - - Use a brief, descriptive summary (e.g., "Task Refinement Complete") - -## Examples - -### Example Repository Analysis Comment -```markdown -<details>
-<summary>Repository Analysis & Clarifying Questions</summary> - -I've analyzed the repository structure and have some questions to ensure proper implementation: - -### Repository Context -- **Framework**: React with TypeScript frontend, Node.js/Express backend -- **Authentication**: Currently using JWT tokens (found in `/src/auth/`) -- **Database**: PostgreSQL with Prisma ORM -- **Existing Features**: Basic user registration exists in `/src/components/auth/` - -### Clarifying Questions - -#### Integration with Existing Auth System -1. Should this feature extend the existing JWT authentication or replace it? -2. How should this integrate with the current user registration flow? - -#### Database Schema -3. Should we modify the existing `users` table or create new tables? -4. What user data fields are required for this feature? - -#### Frontend Components -5. Should we update existing auth components or create new ones? -6. What should the user interface look like for this feature? - -Please respond when you have a chance. Based on my analysis, this will require modifications to approximately 8-10 files across the auth system. - -</details>
-``` - -### Example Final Issue Description Update -```markdown -# Overview -Add user authentication system to allow users to log in and access protected features. - -## Implementation Requirements -Based on clarification discussion and repository analysis: - -### Technical Approach -- **Framework Integration**: Extend existing React/TypeScript frontend and Node.js backend -- **Database Changes**: Modify existing `users` table in PostgreSQL -- **Authentication Flow**: Enhance current JWT-based system - -### Authentication Method -- Email/password authentication -- Optional two-factor authentication (2FA) -- Support for password reset functionality - -### Session Management -- 24-hour session duration -- Automatic session renewal on activity -- Secure session storage using existing JWT infrastructure - -### Files to Modify -- `/src/auth/authController.js` - Add 2FA logic -- `/src/components/auth/LoginForm.tsx` - Update UI -- `/src/models/User.js` - Add 2FA fields -- `/prisma/schema.prisma` - Database schema updates -- `/src/middleware/auth.js` - Session management - -### Acceptance Criteria -- [ ] Users can register with email/password -- [ ] Users can log in and log out -- [ ] Sessions expire after 24 hours of inactivity -- [ ] Password reset functionality works -- [ ] 2FA can be enabled/disabled by user -- [ ] Integration tests pass -- [ ] Existing auth functionality remains intact -``` - -## Troubleshooting - -### Missing Issue: -If the issue does not exist: -1. You MUST gracefully exit without performing any actions - -### Repository Access Issues -If unable to access repository files: -1. Verify repository permissions and authentication -2. Check if the repository is private or has restricted access -3. Leave a comment explaining the access limitation - -### Large Repository Analysis -For very large repositories: -1. Focus on key directories related to the feature -2. Use search functionality to find relevant code patterns -3. 
Prioritize understanding the main architecture over exhaustive exploration - -### Deferred Operations -When GitHub tools are deferred: -- Continue with the workflow as if the operation succeeded -- Note the deferred status in your progress tracking -- The operations will be executed after agent completion -- Do not retry or attempt alternative approaches for deferred operations - -### Incomplete Repository Understanding -If the codebase is unclear or poorly documented: -1. Ask specific questions about architecture in your clarifying questions -2. Request documentation or guidance from the repository maintainers -3. Make reasonable assumptions and document them clearly diff --git a/.github/agent-sops/task-release-notes.sop.md b/.github/agent-sops/task-release-notes.sop.md deleted file mode 100644 index e32a0f2eb..000000000 --- a/.github/agent-sops/task-release-notes.sop.md +++ /dev/null @@ -1,772 +0,0 @@ -# Release Notes Generator SOP - -## Role - -You are a Release Notes Generator, and your goal is to create high-quality release notes highlighting Major Features and Major Bug Fixes for a software project. Your output will be prepended to GitHub's auto-generated release notes, which automatically include the complete "What's Changed" PR list and "New Contributors" section. - -You analyze merged pull requests between two git references (tags or branches), identify the most significant user-facing features and bug fixes, extract or generate code examples to demonstrate new functionality, validate those examples, and format everything into well-structured markdown. Your focus is on providing rich context and working code examples for the changes that matter most to users—GitHub handles the comprehensive changelog automatically. - -**Important**: You are executing in an ephemeral environment. Any files you create (test files, notes, etc.) will be discarded after execution. 
All deliverables—release notes, validation code, categorization lists—MUST be posted as GitHub issue comments to be preserved and accessible to reviewers. - -## Key Principles - -These principles apply throughout the entire workflow and are referenced by name in later sections. - -### Principle 1: Ephemeral Environment -You are executing in an ephemeral environment. All deliverables MUST be posted as GitHub issue comments to be preserved. - -### Principle 2: PR Descriptions May Be Stale -PR descriptions are written at PR creation and may become outdated after code review. Reviewers often request structural changes, API modifications, or feature adjustments that are implemented but NOT reflected in the original description. You MUST cross-reference descriptions with review comments and treat merged code as the source of truth. - -### Principle 3: Validation Is Mandatory -You MUST attempt to validate EVERY code example with behavioral tests. The engineer review fallback is only for cases where you have genuinely tried and failed with documented evidence. - -### Principle 4: Never Remove Features -You MUST NOT remove a feature from release notes because validation failed. Always include a code sample—either validated or marked for engineer review. - -## Steps - -### 1. Setup and Input Processing - -#### 1.1 Accept Git References - -Parse the input to identify the two git references (tags or branches) to compare. 
- -**Constraints:** -- You MUST accept two git references as input (e.g., `v1.0.0` and `v1.1.0`, or `release/1.0` and `release/1.1`) -- You MUST validate that both references are provided -- You MUST track the base reference (older) and head reference (newer) for use throughout the workflow -- You SHOULD use semantic version tags when available (e.g., `v1.14.0`, `v1.15.0`) -- You MAY accept branch names if tags are not available - -#### 1.2 Check for Existing GitHub Release - -Check if a release (draft or non-draft) already exists with auto-generated PR information. - -**Constraints:** -- You MUST first check if a release exists for the target version using the GitHub API: `GET /repos/:owner/:repo/releases` -- You MUST check if the release body contains GitHub's auto-generated "What's Changed" section -- If a release with PR list exists: - - You MUST parse the PR list from the existing release body - - You MUST extract PR numbers, titles, authors, and links from the markdown - - You SHOULD skip Step 1.3 (Query GitHub API for PRs) since the PR list is already available -- If no release exists or it lacks PR information: - - You MUST proceed to Step 1.3 to query for PRs manually -- You SHOULD note in the categorization comment whether you used existing release data or queried manually - -#### 1.3 Query GitHub API for PRs (if needed) - -Retrieve merged pull requests between the two git references when no release exists. 
- -**Constraints:** -- You SHOULD skip this step if PR information was obtained from an existing release in Step 1.2 -- You MUST query the GitHub API to get commits between the two references: `GET /repos/:owner/:repo/compare/:base...:head` -- You MUST extract the list of merged pull requests from the commit history -- You MUST retrieve the full list even if there are many PRs (handle pagination) -- You SHOULD track the total number of PRs found for reporting in the categorization comment -- You MAY need to filter for only merged PRs if the comparison includes unmerged commits - -#### 1.4 Retrieve PR Metadata - -For each PR identified (from release or API query), fetch additional metadata needed for categorization. - -**Constraints:** -- If PR information came from a release, you already have: - - PR number and title - - Author username - - Link to the PR -- You MUST retrieve additional metadata for PRs being considered for Major Features or Major Bug Fixes: - - PR description/body (essential for understanding the change) - - PR labels (if any) - - PR review comments and conversation threads (per **Principle 2**) -- You SHOULD retrieve for Major Feature candidates: - - Files changed in the PR (to find code examples) -- You MUST retrieve PR review comments for Major Feature and Major Bug Fix candidates to identify post-description changes -- You SHOULD minimize API calls by only fetching detailed metadata for PRs that appear significant based on title/prefix -- You MUST track this data for use in categorization and release notes generation - -### 2. PR Analysis and Categorization - -#### 2.1 Analyze PR Titles and Prefixes - -Extract categorization signals from PR titles using conventional commit prefixes. 
- -**Constraints:** -- You MUST check each PR title for conventional commit prefixes: - - `feat:` or `feature:` - Feature additions - - `fix:` - Bug fixes - - `refactor:` - Code refactoring - - `docs:` - Documentation changes - - `test:` - Test additions/changes - - `chore:` - Maintenance tasks - - `ci:` - CI/CD changes - - `perf:` - Performance improvements -- You MUST use these prefixes as initial categorization signals -- You SHOULD record the prefix-based category for each PR -- You MAY encounter PRs without conventional commit prefixes - -#### 2.2 Analyze PR Descriptions and Review Comments - -Use LLM analysis to understand the significance and user impact of each change. - -**Constraints:** -- You MUST read and analyze the PR description for each PR -- Per **Principle 2**, you MUST also review PR comments and review threads to identify changes made after the initial description: - - Look for reviewer comments requesting changes to the implementation - - Look for author responses confirming changes were made - - Look for "LGTM" or approval comments that reference specific modifications - - Pay special attention to comments about API changes, renamed methods, or restructured code -- You MUST treat the actual merged code as the source of truth when descriptions conflict with review feedback -- You MUST assess the user-facing impact of the change: - - Does it introduce new functionality users will interact with? - - Does it fix a bug that users experienced? - - Is it purely internal with no user-visible changes? -- You MUST identify if the change introduces breaking changes -- You SHOULD identify if the PR includes code examples in its description (but verify they match the final implementation) -- You SHOULD note any links to documentation or related issues -- You MAY consider the size and complexity of the change - -#### 2.3 Categorize PRs - -Combine prefix analysis and LLM analysis to categorize each PR appropriately. 
- -**Constraints:** -- You MUST categorize each PR into one of these categories: - - **Major Features**: Significant new functionality or enhancements that users should know about - - New APIs, methods, or classes - - New capabilities or workflows - - Significant feature enhancements - - User-facing changes with clear value - - **Major Bug Fixes**: Critical bug fixes that impact user experience - - Fixes for broken functionality - - Security fixes - - Data corruption fixes - - Performance issue resolutions - - **Minor Changes**: Everything else - - Internal refactoring without user-visible changes - - Documentation-only changes - - Test-only changes - - Minor fixes or typos - - Dependency updates without feature impact - - CI/CD changes - - Code style changes -- You MUST prioritize user impact over technical classification -- You MUST use BOTH prefix signals AND description analysis to make the final decision -- You SHOULD be conservative - when in doubt, classify as "Minor Changes" -- You SHOULD limit "Major Features" to approximately 3-8 items per release -- You SHOULD limit "Major Bug Fixes" to approximately 0-5 items per release -- You MUST record your categorization decisions (these will be posted as a GitHub comment in Step 2.4) - -#### 2.4 Confirm Categorization with User - -Present the categorized PRs to the user for review and confirmation. 
- -**Constraints:** -- You MUST present the categorization to the user for review before proceeding -- You MUST format the categorization as a numbered list organized by category: - - **Major Features** (with PR numbers and titles) - - **Major Bug Fixes** (with PR numbers and titles) - - **Minor Changes** (with PR numbers and titles, or just count if >20) -- You MUST make it easy for the user to recategorize items by providing clear instructions -- You SHOULD present the list in a format that allows easy reordering (e.g., "To move PR#123 to Major Features, reply with: 'Move #123 to Major Features'") -- You MUST post this categorization as a comment on the GitHub issue -- You MUST use the handoff_to_user tool to request review -- You MUST wait for user confirmation or recategorization before proceeding -- You SHOULD update your categorization based on user feedback -- You MAY iterate on categorization if the user requests changes -- When the user promotes a PR to "Major Features" that was not previously in that category: - - You MUST perform Step 3 (Code Snippet Extraction) for the newly promoted PR - - You MUST perform Step 4 (Code Validation) for any code snippets extracted or generated - - You MUST include the validation code for newly promoted features in the Validation Comment (Step 6.1) - -### 3. Code Snippet Extraction and Generation - -**Note**: This phase applies only to PRs categorized as "Major Features". Bug fixes typically do not require code examples. - -#### 3.1 Search for Existing Code Examples - -Search merged PRs for existing code that demonstrates the new feature. 
- -**Constraints:** -- You MUST search each Major Feature PR for existing code examples in: - - Test files (especially integration tests or example tests) - these are most reliable as they reflect the final implementation - - Example applications or scripts in `examples/` directory - - Code snippets in the PR description (but verify per **Principle 2**) - - Documentation updates that include code examples - - README updates with usage examples -- You MUST cross-reference any examples from PR descriptions with: - - Review comments that may have requested API changes - - The actual merged code to ensure the example is still accurate - - Test files which reflect the working implementation -- You MUST prioritize test files that show real usage of the feature (these are validated against the final code) -- You SHOULD look for the simplest, most focused examples -- You SHOULD prefer examples that are already validated (from test files) -- You MAY examine multiple PRs if a feature spans several PRs - -#### 3.2 Extract Code from PRs - -When suitable examples are found, extract them for use in release notes. - -**Constraints:** -- You MUST extract the most relevant and focused code snippet -- You MUST simplify extracted code for release notes: - - Remove unnecessary imports - - Remove test scaffolding and setup code - - Remove assertions and test-specific code - - Keep only the core usage demonstration -- You MUST ensure extracted code is syntactically complete (balanced braces, valid syntax) -- You SHOULD keep examples under 20 lines when possible -- You SHOULD focus on the "happy path" usage -- You MAY need to extract from multiple locations and combine them - -#### 3.3 Generate New Snippets When Needed - -When existing examples are insufficient, generate new code snippets. 
- -**Constraints:** -- You MUST generate new snippets when: - - No suitable examples exist in the PR - - Existing code is too complex or specific - - Existing code doesn't clearly demonstrate the feature -- You MUST keep generated snippets minimal and focused -- You MUST use the appropriate programming language for the project -- You MUST ensure generated code follows the project's coding patterns -- You SHOULD base generated code on the actual API changes in the PR -- You SHOULD include only necessary imports -- You SHOULD demonstrate the most common use case -- You MAY include brief inline comments to clarify usage - -### 4. Code Validation - -**Note**: This phase is REQUIRED for all code snippets (extracted or generated) that will appear in Major Features sections. Per **Principle 3**, you MUST attempt validation for every example. - -#### 4.1 Validation Requirements - -Validation tests MUST verify the actual behavior of the feature, not just syntax correctness. A test that only checks whether code parses or imports succeed is NOT valid validation. - -**Available Testing Resources:** -- **Amazon Bedrock**: You have access to Bedrock models for testing. Use Bedrock when a feature requires a real model provider. 
-- **Project test fixtures**: The project includes mocked model providers and test utilities (commonly in `tests/fixtures/`, `__mocks__/`, or similar) -- **Integration test patterns**: Examine integration test directories (commonly in `tests_integ/` or `test/integ`) for patterns that test real model interactions - -**Features that genuinely cannot be validated (rare):** -- Features requiring paid third-party API credentials with no mock option AND no Bedrock alternative -- Features requiring specific hardware (GPU, TPU) -- Features requiring live network access to specific external services that cannot be mocked - -**Constraints:** -- You MUST create a temporary test file for each code snippet -- You MUST place test files in an appropriate test directory based on the project structure -- You MUST include all necessary imports and setup code in the test file -- You MUST wrap the snippet in a proper test case -- You MUST include assertions that verify the feature's actual behavior: - - Assert that outputs match expected values - - Assert that state changes occur as expected - - Assert that callbacks/hooks are invoked correctly - - Assert that return types and structures are correct -- You MUST NOT write tests that only verify: - - Code parses without syntax errors - - Imports succeed - - Objects can be instantiated without checking behavior - - Functions can be called without checking results -- You SHOULD use the project's testing framework -- You SHOULD mock external dependencies (APIs, databases) but still verify behavior with mocks -- You MAY need to setup test fixtures that enable behavioral verification -- You MAY include additional test code that doesn't appear in the release notes - -**Example of GOOD validation** (verifies behavior) - adapt syntax to project language: -```python -def test_structured_output_validation(): - """Verify that structured output actually validates against the schema.""" - from pydantic import BaseModel - - class 
UserResponse(BaseModel): - name: str - age: int - - agent = Agent(model=mock_model, output_schema=UserResponse) - result = agent("Get user info") - - # Behavioral assertions - verify the feature works - assert isinstance(result.output, UserResponse) - assert hasattr(result.output, 'name') - assert hasattr(result.output, 'age') - assert isinstance(result.output.age, int) -``` - -**Example of BAD validation** (only verifies syntax) - adapt syntax to project language: -```python -def test_structured_output_syntax(): - """BAD: This only verifies the code runs without errors.""" - from pydantic import BaseModel - - class UserResponse(BaseModel): - name: str - age: int - - # BAD: No assertions about behavior - agent = Agent(model=mock_model, output_schema=UserResponse) - # BAD: Just calling without checking results proves nothing - agent("Get user info") -``` - -#### 4.2 Validation Workflow - -For each Major Feature, follow this workflow in order: - -1. **Write a test file** with behavioral assertions -2. **Run the test** using the project's test framework -3. **If it fails**, try these approaches in order: - - Try using Bedrock instead of other model providers - - Try installing missing dependencies - - Try mocking external services - - Try using project test fixtures (e.g., mocked model providers) - - Try simplifying the example -4. **Document each attempt** and its result in the Validation Comment -5. 
**Only after documented failures** can you use the engineer review fallback - -**Constraints:** -- You MUST run the appropriate test command for the project (e.g., `npm test`, `pytest`, `go test`) -- You MUST verify that the test passes successfully -- You MUST verify that assertions actually executed (not skipped or short-circuited) -- You MUST check that the code compiles without errors in compiled languages -- You MUST ensure tests include meaningful assertions about feature behavior -- You SHOULD run type checking if applicable (e.g., `npm run type-check`, `mypy`) -- You SHOULD review test output to confirm behavioral assertions passed -- You MAY need to adjust imports or setup code if tests fail - -**Installing Dependencies:** -- You MUST attempt to install missing dependencies when tests fail due to import errors -- You SHOULD check the project's dependency manifest (`pyproject.toml`, `package.json`, `Cargo.toml`, etc.) for optional dependency groups -- You SHOULD use the project's package manager to install dependencies (e.g., `pip install`, `npm install`, `cargo add`) -- For projects with optional extras, use the appropriate syntax (e.g., `pip install -e ".[extra]"` for Python, `npm install --save-dev` for Node.js) -- You SHOULD only fall back to mocking if the dependency cannot be installed (e.g., requires paid API keys, proprietary software) - -**Example of mocking external dependencies** - adapt syntax to project language: -```python -def test_custom_http_client(): - """Verify custom HTTP client is passed to the provider.""" - from unittest.mock import Mock, patch - - custom_client = Mock() - - with patch('strands.models.openai.OpenAI') as mock_openai: - from strands.models.openai import OpenAIModel - model = OpenAIModel(http_client=custom_client) - - # Verify the custom client was passed - mock_openai.assert_called_once() - call_kwargs = mock_openai.call_args[1] - assert call_kwargs.get('http_client') == custom_client -``` - -#### 4.3 Engineer Review 
Fallback - -When validation genuinely fails after documented attempts, use this fallback. Per **Principle 4**, you MUST still include the feature with a code sample. - -**Required proof before using this fallback:** -1. Created an actual test file (show the code in the validation comment) -2. Ran the test and received an actual error (show the error message) -3. Tried at least ONE alternative approach (Bedrock, mocking, simplified example) -4. Documented each attempt and its failure reason - -**Constraints:** -- You MUST NOT mark examples as needing validation without actually attempting validation first -- You MUST NOT use vague reasons like "complex setup required" - be specific about what you tried and what error you got -- You MUST show your test code and error messages in the Validation Comment -- You MUST try Bedrock for any feature that works with multiple model providers before giving up -- You MUST try mocking for provider-specific features before giving up -- You MUST document all validation attempts (successful AND failed) in the Validation Comment -- You MUST preserve the test file content to include in the GitHub issue comment (Step 6.1) -- You MUST note in the validation comment what specific behavior each test verifies -- You MAY delete temporary test files after capturing their content, as the environment is ephemeral - -**Process when validation genuinely fails:** -1. **Extract a code sample from the PR** - Use code from: - - The PR description's code examples - - Test files added in the PR - - The actual implementation (simplified for readability) - - Documentation updates in the PR -2. **Include the sample in the release notes** with a clear callout that it needs engineer validation -3. **Document the validation attempts and failures** in the Validation Comment (Step 6.1) - -**Format for unvalidated code examples:** -```markdown -### Feature Name - [PR#123](link) - -Description of the feature and its impact. 
- -\`\`\`python -# ⚠️ NEEDS ENGINEER VALIDATION -# Validation attempted: [describe test created and error received] -# Alternative attempts: [what else you tried and why it failed] - -# Code sample extracted from PR description/tests -from strands import Agent -from strands.models.openai import OpenAIModel - -model = OpenAIModel(http_client=custom_client) -agent = Agent(model=model) -\`\`\` -``` - -### 5. Release Notes Formatting - -#### 5.1 Format Major Features Section - -Create the Major Features section with concise descriptions and code examples. - -**Constraints:** -- You MUST create a section with heading: `## Major Features` -- You MUST create a subsection for each major feature using heading: `### Feature Name - [PR#123](link)` -- You MUST include the PR number and link in the feature heading -- You MUST write a concise description of 2-3 sentences that explains what the feature does and why it matters -- You MUST NOT use bullet points or lists in feature descriptions—use prose only -- You MUST NOT write lengthy multi-paragraph explanations -- You MUST include a code block demonstrating the feature using the project's programming language -- You MUST use proper syntax highlighting for the project's language -- You SHOULD keep code examples under 20 lines -- You SHOULD include inline comments in code examples only when necessary for clarity -- You MAY include multiple code examples if the feature has distinct use cases -- You MAY include a single closing sentence after the code example (e.g., documentation link or brief note) -- You MAY reference multiple PRs if a feature spans several PRs: `### Feature Name - [PR#123](link), [PR#124](link)` - -**Example format**: -```markdown -### Structured Output via Agentic Loop - [PR#943](https://github.com/org/repo/pull/943) - -Agents can now validate responses against predefined schemas with configurable retry behavior for non-conforming outputs. 
- -\`\`\`python -from strands import Agent -from pydantic import BaseModel - -class Response(BaseModel): - answer: str - -agent = Agent(output_schema=Response) -result = agent("What is 2+2?") -print(result.output.answer) -\`\`\` - -See the [Structured Output docs](https://docs.example.com/structured-output) for configuration options. -``` - -#### 5.2 Format Major Bug Fixes Section - -Create the Major Bug Fixes section highlighting critical fixes (if any exist). - -**Constraints:** -- You MUST create this section only if there are critical bug fixes -- You MUST create a section with heading: `## Major Bug Fixes` -- You MUST add a horizontal rule before this section: `---` -- You MUST format each bug fix as a bullet list item: `- **Fix Title** - [PR#123](link)` -- You MUST write a brief explanation (1-2 sentences) after each bullet that describes: - - What was broken - - What impact it had on users - - What is now fixed -- You SHOULD order fixes by severity or user impact -- You SHOULD keep descriptions concise but informative -- You MAY skip this section entirely if there are no major bug fixes - -**Example format**: -```markdown ---- - -## Major Bug Fixes - -- **Guardrails Redaction Fix** - [PR#1072](https://github.com/org/repo/pull/1072) - Fixed input/output message redaction when `guardrails_trace="enabled_full"`, ensuring sensitive data is properly protected in traces. - -- **Tool Result Block Redaction** - [PR#1080](https://github.com/org/repo/pull/1080) - Properly redact tool result blocks to prevent conversation corruption when using content filtering or PII redaction. -``` - -#### 5.3 End with Separator - -Add a horizontal rule to separate your content from GitHub's auto-generated sections. 
- -**Constraints:** -- You MUST end your release notes with a horizontal rule: `---` -- This visually separates your curated content from GitHub's auto-generated "What's Changed" and "New Contributors" sections -- You MUST NOT include a "Full Changelog" link—GitHub adds this automatically - -### 6. Output Delivery - -Per **Principle 1**, all deliverables must be posted as GitHub issue comments. - -**Comment Structure**: Post exactly three comments on the GitHub issue: -1. **Validation Comment** (first): Contains all validation code for all features in one batched comment -2. **Release Notes Comment** (second): Contains the final formatted release notes -3. **Exclusions Comment** (third): Documents any features that were excluded and why - -This ordering allows reviewers to see the validation evidence, review the release notes, and understand any exclusion decisions. - -**Iteration Comments**: If the user requests changes after the initial comments are posted: -- Post additional validation comments for any re-validated code -- Post updated release notes as new comments (do not edit previous comments) -- This creates an audit trail of changes and validations - -#### 6.1 Post Validation Code Comment - -Batch all validation code into a single GitHub issue comment. - -**Constraints:** -- You MUST post ONE comment containing validation attempts for ALL Major Features -- You MUST show test code for EVERY feature - both successful and failed attempts -- You MUST NOT post separate comments for each feature's validation -- You MUST post this comment BEFORE the release notes comment -- You MUST include all test files created during validation (Step 4) in this single comment -- You MUST document what specific behavior each test verifies (not just "validates the code works") -- You MUST NOT reference local file paths—the ephemeral environment will be destroyed -- You MUST clearly label this comment as "Code Validation Tests" -- You SHOULD use collapsible `
<details>` sections to organize validation code by feature -- You SHOULD include a brief description of what behavior is being verified for each test - -**Format:** -```markdown -## Code Validation Tests - -The following test code was used to validate the code examples in the release notes. - -<details>
-<summary>✅ Validated: Feature Name 1</summary> - -**Behavior verified:** This test confirms that the new `output_schema` parameter causes the agent to return a validated Pydantic model instance with the correct field types. - -\`\`\`python -[Full test file for feature 1 with behavioral assertions] -\`\`\` - -**Test output:** PASSED - -</details>
- -<details>
-<summary>⚠️ Could Not Validate: Feature Name 2</summary> - -**Attempt 1: Direct test with mocked model** -\`\`\`python -[Test code that was attempted] -\`\`\` -**Error received:** -\`\`\` -[Actual error message from running the test] -\`\`\` - -**Attempt 2: Test with Bedrock** -\`\`\`python -[Alternative test code attempted] -\`\`\` -**Error received:** -\`\`\` -[Actual error message] -\`\`\` - -**Conclusion:** Could not validate because [specific reason based on actual errors]. Code sample in release notes extracted from PR description. - -</details>
-``` - -#### 6.2 Post Release Notes Comment - -Post the formatted release notes as a single GitHub issue comment. - -**Constraints:** -- You MUST post ONE comment containing the complete release notes -- You MUST post this comment AFTER the validation comment -- You MUST use the `add_issue_comment` tool to post the comment -- You MUST include Major Features, Major Bug Fixes (if any), and a trailing separator (`---`) -- You MUST NOT expect users to access any local files—everything must be in the comment -- You SHOULD add a brief introductory line (e.g., "## Release Notes for v1.15.0") -- You MAY use markdown formatting in the comment -- If comment posting is deferred, continue with the workflow and note the deferred status - -#### 6.3 Post Exclusions Comment - -Document any features with unvalidated code samples and any other notable decisions. - -**Constraints:** -- You MUST post this comment as the FINAL comment on the GitHub issue -- You MUST include this comment if ANY of the following occurred: - - A Major Feature has an unvalidated code sample (marked for engineer review) - - A feature's scope or description was significantly different from the PR description - - You relied on review comments rather than the PR description to understand a feature -- You MUST clearly explain the reasoning for each unvalidated sample -- You SHOULD include this comment even if all code samples were validated, with a simple note: "All code samples were successfully validated. No engineer review required." -- You MUST NOT skip this comment—it provides critical transparency for reviewers - -**Format:** -```markdown -## Release Notes Review Notes - -The following items require attention during review: - -### ⚠️ Features with Unvalidated Code Samples - -These features have code samples extracted from PRs but could not be automatically validated. 
An engineer must verify these examples before publishing: - -- **PR#123 - Feature Title**: - - Code source: PR description / test files / implementation - - Validation attempted: [what you tried] - - Failure reason: [why it failed, e.g., "requires OpenAI API credentials", "complex multi-service integration"] - - Action needed: Engineer should verify the code sample works as shown - -### Description vs. Implementation Discrepancies -- **PR#101 - Feature Title**: PR description stated [X] but review comments and final implementation show [Y]. Release notes reflect the actual merged behavior. -``` - -#### 6.4 Handle User Feedback on Release Notes - -When the user requests changes to the release notes after they have been posted, re-validate as needed. - -**Constraints:** -- You MUST re-run validation (Step 4) when the user requests changes that affect code examples: - - Modified code snippets - - New code examples for features that previously had none - - Replacement examples for features -- You MUST perform full extraction (Step 3) and validation (Step 4) when the user requests: - - Adding a new feature to the release notes that wasn't previously included - - Promoting a bug fix to include a code example -- You MUST NOT make changes to code examples without re-validating them -- You MUST post updated validation code as a new comment when re-validation occurs -- You MUST post the revised release notes as a new comment (do not edit previous comments) -- You SHOULD note in the updated release notes comment what changed from the previous version -- You MAY skip re-validation only for changes that do not affect code: - - Wording changes to descriptions - - Fixing typos - - Reordering features - - Removing features (no validation needed for removal) - -## Examples - -### Example 1: Complete Release Notes - -```markdown -## Major Features - -### Managed MCP Connections - [PR#895](https://github.com/org/repo/pull/895) - -MCP Connections via ToolProviders allow the Agent to 
manage connection lifecycles automatically, eliminating the need for manual context managers. This experimental interface simplifies MCP tool integration significantly. - -\`\`\`python -from strands import Agent -from strands.tools import MCPToolProvider - -provider = MCPToolProvider(server_config) -agent = Agent(tools=[provider]) -result = agent("Use the MCP tools") -\`\`\` - -See the [MCP docs](https://docs.example.com/mcp) for details. - -### Custom HTTP Client Support - [PR#1366](https://github.com/org/repo/pull/1366) - -OpenAI model provider now accepts a custom HTTP client, enabling proxy configuration, custom timeouts, and request logging. - -\`\`\`python -# ⚠️ NEEDS ENGINEER VALIDATION -# Validation attempted: mocked OpenAI client, received import error -# Alternative attempts: Bedrock (not applicable - OpenAI-specific) - -from strands.models.openai import OpenAIModel -import httpx - -custom_client = httpx.Client(proxy="http://proxy.example.com:8080") -model = OpenAIModel(client_args={"http_client": custom_client}) -\`\`\` - ---- - -## Major Bug Fixes - -- **Guardrails Redaction Fix** - [PR#1072](https://github.com/strands-agents/sdk-python/pull/1072) - Fixed input/output message redaction when `guardrails_trace="enabled_full"`, ensuring sensitive data is properly protected in traces. - -- **Tool Result Block Redaction** - [PR#1080](https://github.com/strands-agents/sdk-python/pull/1080) - Properly redact tool result blocks to prevent conversation corruption when using content filtering or PII redaction. - -- **Orphaned Tool Use Fix** - [PR#1123](https://github.com/strands-agents/sdk-python/pull/1123) - Fixed broken conversations caused by orphaned `toolUse` blocks, improving reliability when tools fail or are interrupted. - ---- -``` - -Note: The trailing `---` separates your content from GitHub's auto-generated "What's Changed" and "New Contributors" sections that follow. 
- -## Troubleshooting - -### Missing or Invalid Git References - -If one or both git references are missing or invalid: -1. Verify the references exist in the repository using `git ls-remote --tags` or `git ls-remote --heads` -2. Check if the user provided branch names vs. tag names -3. Leave a comment on the issue explaining which reference is invalid -4. Use the handoff_to_user tool to request clarification - -### GitHub API Rate Limiting - -If you encounter GitHub API rate limit errors: -1. Check the rate limit status using the `X-RateLimit-Remaining` header -2. If rate limited, note the `X-RateLimit-Reset` timestamp -3. Consider reducing the number of API calls by batching requests -4. Leave a comment on the issue explaining the rate limit issue -5. Use the handoff_to_user tool to inform the user - -### Code Validation Failures - -Follow the validation workflow in Section 4.2. If all attempts fail, use the engineer review fallback per Section 4.3. Per **Principle 4**, always include a code sample. - -### Large PR Sets (>100 PRs) - -If there are many PRs between the references: -1. Consider whether the git references are correct (e.g., not comparing main to an ancient tag) -2. Focus categorization efforts on the most significant changes -3. Be more selective about what qualifies as a "Major Feature" or "Major Bug Fix" - -### No PRs Found Between References - -If no PRs are found: -1. Verify that the base and head references are in the correct order (base should be older) -2. Check if the references are the same -3. Verify that there are actually commits between the references -4. Check if a release exists that might have the PR list -5. Leave a comment on the issue explaining the situation -6. Use the handoff_to_user tool to request clarification - -### Release Parsing Issues - -If the release body cannot be parsed correctly: -1. Check if the format matches GitHub's standard auto-generated format -2. 
Look for the "What's Changed" heading and bullet list format: `* PR title by @author in URL` -3. If parsing fails, fall back to querying the GitHub API directly (Step 1.3) -4. Note in the categorization comment that you fell back to API queries - -### Deferred Operations - -When GitHub tools or git operations are deferred (GITHUB_WRITE=false): -- Continue with the workflow as if the operation succeeded -- Note the deferred status in your progress tracking -- The operations will be executed after agent completion -- Do not retry or attempt alternative approaches for deferred operations - -### Stale PR Descriptions - -Per **Principle 2**: Review PR comments for context on what changed, examine merged code (especially test files), and use test files as the authoritative source for code examples. - -## Desired Outcome - -* Focused release notes highlighting Major Features and Major Bug Fixes with concise descriptions (2-3 sentences, no bullet points) -* Code examples for ALL major features - either validated or marked for engineer review -* Validated code examples have passing behavioral tests -* Unvalidated code examples are clearly marked with the engineer validation warning and extracted from PR sources -* Well-formatted markdown that renders properly on GitHub -* Release notes posted as a comment on the GitHub issue for review -* Review notes comment documenting any features with unvalidated code samples that need engineer attention - -**Important**: Your generated release notes will be prepended to GitHub's auto-generated release notes. GitHub automatically generates: -- "What's Changed" section listing all PRs with authors and links -- "New Contributors" section acknowledging first-time contributors -- "Full Changelog" comparison link - -You should NOT include these sections—focus exclusively on Major Features and Major Bug Fixes that benefit from detailed descriptions and code examples. Minor changes (refactors, docs, tests, chores, etc.) 
will be covered by GitHub's automatic changelog. \ No newline at end of file diff --git a/.github/scripts/javascript/process-input.cjs b/.github/scripts/javascript/process-input.cjs deleted file mode 100644 index 395e37b64..000000000 --- a/.github/scripts/javascript/process-input.cjs +++ /dev/null @@ -1,141 +0,0 @@ -// This file assumes that its run from an environment that already has github and core imported: -// const github = require('@actions/github'); -// const core = require('@actions/core'); - -const fs = require('fs'); - -async function getIssueInfo(github, context, inputs) { - const issueId = context.eventName === 'workflow_dispatch' - ? inputs.issue_id - : context.payload.issue.number.toString(); - const commentBody = context.payload.comment?.body || ''; - const command = context.eventName === 'workflow_dispatch' - ? inputs.command - : (commentBody.startsWith('/strands') ? commentBody.slice('/strands'.length).trim() : ''); - - console.log(`Event: ${context.eventName}, Issue ID: ${issueId}, Command: "${command}"`); - - const issue = await github.rest.issues.get({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: issueId - }); - - return { issueId, command, issue }; -} - -async function determineBranch(github, context, issueId, mode, isPullRequest) { - let branchName = 'main'; - - if (mode === 'implementer' && !isPullRequest) { - branchName = `agent-tasks/${issueId}`; - - const mainRef = await github.rest.git.getRef({ - owner: context.repo.owner, - repo: context.repo.repo, - ref: 'heads/main' - }); - - try { - await github.rest.git.createRef({ - owner: context.repo.owner, - repo: context.repo.repo, - ref: `refs/heads/${branchName}`, - sha: mainRef.data.object.sha - }); - console.log(`Created branch ${branchName}`); - } catch (error) { - if (error.status === 422 || error.message?.includes('already exists')) { - console.log(`Branch ${branchName} already exists`); - } else { - throw error; - } - } - } else if (isPullRequest) { - const pr 
= await github.rest.pulls.get({ - owner: context.repo.owner, - repo: context.repo.repo, - pull_number: issueId - }); - branchName = pr.data.head.ref; - } - - return branchName; -} - -function buildPrompts(mode, issueId, isPullRequest, command, branchName, inputs) { - const sessionId = inputs.session_id || (mode === 'implementer' - ? `${mode}-${branchName}`.replace(/[\/\\]/g, '-') - : `${mode}-${issueId}`); - - const scriptFiles = { - 'implementer': '.github/agent-sops/task-implementer.sop.md', - 'refiner': '.github/agent-sops/task-refiner.sop.md', - 'release-notes': '.github/agent-sops/task-release-notes.sop.md' - }; - - const scriptFile = scriptFiles[mode] || scriptFiles['refiner']; - const systemPrompt = fs.readFileSync(scriptFile, 'utf8'); - - // Extract the user's feedback/instructions after the mode keyword - // e.g., "release-notes Move #123 to Major Features" -> "Move #123 to Major Features" - const modeKeywords = { - 'release-notes': /^(?:release-notes|release notes)\s*/i, - 'implementer': /^implement\s*/i, - 'refiner': /^refine\s*/i - }; - - const modePattern = modeKeywords[mode]; - const userFeedback = modePattern ? command.replace(modePattern, '').trim() : command.trim(); - - let prompt = (isPullRequest) - ? 
'The pull request id is:' - : 'The issue id is:'; - prompt += `${issueId}\n`; - - // If there's any user feedback beyond the command keyword, include it as the main instruction, - // otherwise default to "review and continue" - prompt += userFeedback || 'review and continue'; - - return { sessionId, systemPrompt, prompt }; -} - -module.exports = async (context, github, core, inputs) => { - try { - const { issueId, command, issue } = await getIssueInfo(github, context, inputs); - - const isPullRequest = !!issue.data.pull_request; - - // Determine mode based on explicit command first, then context - let mode; - if (command.startsWith('release-notes') || command.startsWith('release notes')) { - mode = 'release-notes'; - } else if (command.startsWith('implement')) { - mode = 'implementer'; - } else if (command.startsWith('refine')) { - mode = 'refiner'; - } else { - // Default behavior when no explicit command: PR -> implementer, Issue -> refiner - mode = isPullRequest ? 'implementer' : 'refiner'; - } - console.log(`Is PR: ${isPullRequest}, Command: "${command}", Mode: ${mode}`); - - const branchName = await determineBranch(github, context, issueId, mode, isPullRequest); - console.log(`Building prompts - mode: ${mode}, issue: ${issueId}, is PR: ${isPullRequest}`); - - const { sessionId, systemPrompt, prompt } = buildPrompts(mode, issueId, isPullRequest, command, branchName, inputs); - - console.log(`Session ID: ${sessionId}`); - console.log(`Task prompt: "${prompt}"`); - - core.setOutput('branch_name', branchName); - core.setOutput('session_id', sessionId); - core.setOutput('system_prompt', systemPrompt); - core.setOutput('prompt', prompt); - - } catch (error) { - const errorMsg = `Failed: ${error.message}`; - console.error(errorMsg); - core.setFailed(errorMsg); - } -}; diff --git a/.github/scripts/python/agent_runner.py b/.github/scripts/python/agent_runner.py deleted file mode 100644 index 1f772241c..000000000 --- a/.github/scripts/python/agent_runner.py +++ 
/dev/null @@ -1,163 +0,0 @@ -#!/usr/bin/env python3 -""" -Strands GitHub Agent Runner -A portable agent runner for use in GitHub Actions across different repositories. -""" - -import json -import os -import sys -from typing import Any - -from strands import Agent -from strands.agent.conversation_manager import SlidingWindowConversationManager -from strands.session import S3SessionManager -from strands.models.bedrock import BedrockModel -from botocore.config import Config - -from strands_tools import http_request, shell - -# Import local GitHub tools we need -from github_tools import ( - add_issue_comment, - create_issue, - create_pull_request, - get_issue, - get_issue_comments, - get_pull_request, - get_pr_review_and_comments, - list_issues, - list_pull_requests, - reply_to_review_comment, - update_issue, - update_pull_request, -) - -# Import local tools we need -from handoff_to_user import handoff_to_user -from notebook import notebook -from str_replace_based_edit_tool import str_replace_based_edit_tool - -# Strands configuration constants -STRANDS_MODEL_ID = "global.anthropic.claude-opus-4-5-20251101-v1:0" -STRANDS_MAX_TOKENS = 64000 -STRANDS_BUDGET_TOKENS = 8000 -STRANDS_REGION = "us-west-2" - -# Default values for environment variables used only in this file -DEFAULT_SYSTEM_PROMPT = "You are an autonomous GitHub agent powered by Strands Agents SDK." 
- -def _get_all_tools() -> list[Any]: - return [ - # File editing - str_replace_based_edit_tool, - - # System tools - shell, - http_request, - - # GitHub issue tools - create_issue, - get_issue, - update_issue, - list_issues, - add_issue_comment, - get_issue_comments, - - # GitHub PR tools - create_pull_request, - get_pull_request, - update_pull_request, - list_pull_requests, - get_pr_review_and_comments, - reply_to_review_comment, - - # Agent tools - notebook, - handoff_to_user, - ] - - -def run_agent(query: str): - """Run the agent with the provided query.""" - try: - # Get tools and create model - tools = _get_all_tools() - - # Create Bedrock model with inlined configuration - additional_request_fields = {} - additional_request_fields["anthropic_beta"] = ["interleaved-thinking-2025-05-14"] - - additional_request_fields["thinking"] = { - "type": "enabled", - "budget_tokens": STRANDS_BUDGET_TOKENS - } - - model = BedrockModel( - model_id=STRANDS_MODEL_ID, - max_tokens=STRANDS_MAX_TOKENS, - region_name=STRANDS_REGION, - boto_client_config=Config( - read_timeout=900, - connect_timeout=900, - retries={"max_attempts": 3, "mode": "adaptive"}, - ), - additional_request_fields=additional_request_fields, - cache_prompt="default", - cache_tools="default", - ) - system_prompt = os.getenv("INPUT_SYSTEM_PROMPT", DEFAULT_SYSTEM_PROMPT) - session_id = os.getenv("SESSION_ID") - s3_bucket = os.getenv("S3_SESSION_BUCKET") - s3_prefix = os.getenv("GITHUB_REPOSITORY", "") - - if s3_bucket and session_id: - print(f"🤖 Using session manager with session ID: {session_id}") - session_manager = S3SessionManager( - session_id=session_id, - bucket=s3_bucket, - prefix=s3_prefix, - ) - else: - raise ValueError("Both SESSION_ID and S3_SESSION_BUCKET must be set") - - # Create agent - agent = Agent( - model=model, - system_prompt=system_prompt, - tools=tools, - session_manager=session_manager, - ) - - print("Processing user query...") - result = agent(query) - - print(f"\n\nAgent Result 🤖\nStop 
Reason: {result.stop_reason}\nMessage: {json.dumps(result.message, indent=2)}") - except Exception as e: - error_msg = f"❌ Agent execution failed: {e}" - print(error_msg) - raise e - - -def main() -> None: - """Main entry point for the agent runner.""" - try: - # Prefer INPUT_TASK env var (avoids shell escaping issues), fall back to CLI args - task = os.getenv("INPUT_TASK", "").strip() - if not task and len(sys.argv) > 1: - task = " ".join(sys.argv[1:]).strip() - if not task: - raise ValueError("Task is required (via INPUT_TASK env var or CLI argument)") - print(f"🤖 Running agent with task: {task}") - - run_agent(task) - - except Exception as e: - error_msg = f"Fatal error: {e}" - print(error_msg) - - sys.exit(1) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/.github/scripts/python/github_tools.py b/.github/scripts/python/github_tools.py deleted file mode 100644 index 8826b4611..000000000 --- a/.github/scripts/python/github_tools.py +++ /dev/null @@ -1,843 +0,0 @@ -"""GitHub repository management tool for Strands Agents. - -This module provides comprehensive GitHub repository operations including issues, -pull requests, comments, and repository management. Supports full GitHub API -integration with rich console output and error handling. - -Key Features: -1. List and manage issues and pull requests -2. Add comments to issues and PRs -3. Create, update, and manage issues -4. Create, update, and manage pull requests -5. Get detailed information for specific issues/PRs -6. Manage PR reviews and review comments -7. Get issue and PR comment threads -8. Check GitHub token permissions for repositories -9. Rich console output with formatted tables -10. 
Automatic fallback to GITHUB_REPOSITORY environment variable - -Usage Examples: -```python -from strands import Agent -from tools.github_tools import list_issues, add_comment, create_issue, _check_token_permissions - -agent = Agent(tools=[list_issues, add_comment, create_issue]) - -# Check token permissions -has_write = _check_token_permissions("ghp_token123", "owner/repo") - -# List open issues in repository -result = agent.tool.list_issues(state="open", repo="owner/repo") - -# Add comment to an issue -result = agent.tool.add_comment( - issue_number=42, - comment_text="Great idea! I'll work on this.", - repo="owner/repo" -) - -# Create a new issue -result = agent.tool.create_issue( - title="Bug: Application crashes on startup", - body="Description of the issue with steps to reproduce...", - repo="owner/repo" -) - -# List pull requests -result = agent.tool.list_pull_requests(state="open", repo="owner/repo") - -# Get specific issue details -result = agent.tool.get_issue(issue_number=123, repo="owner/repo") - -# Update pull request -result = agent.tool.update_pull_request( - pr_number=456, - title="Updated PR title", - body="Updated description", - repo="owner/repo" -) -``` -""" - -import os -import traceback -from datetime import datetime -from functools import wraps -import json -from typing import Any, TypedDict -from urllib.parse import urlencode, quote - -import requests -from rich import box -from rich.markup import escape -from rich.panel import Panel -from rich.table import Table -from strands import tool -from strands_tools.utils import console_util - -console = console_util.create() - - -class GitHubOperation(TypedDict): - """Type definition for GitHub operation records in JSONL files.""" - timestamp: str - function: str - args: list[Any] - kwargs: dict[str, Any] - - -def log_inputs(func): - """Decorator to log function inputs in a blue panel.""" - @wraps(func) - def wrapper(*args, **kwargs): - # Get function name and format it nicely - func_name = 
func.__name__.replace('_', ' ').title() - - # Format parameters - params = [] - for k, v in kwargs.items(): - if isinstance(v, str) and len(v) > 50: - params.append(f"{k}='{v[:50]}...'") - else: - params.append(f"{k}='{v}'") - - console.print(Panel(", ".join(params), title=f"[bold blue]{func_name}", border_style="blue")) - return func(*args, **kwargs) - return wrapper - - -def _github_request( - method: str, endpoint: str, repo: str | None = None, data: dict | None = None, params: dict | None = None, should_raise: bool = False -) -> dict[str, Any] | str: - """Make a GitHub API request with common error handling. - - Args: - method: HTTP method (GET, POST, PATCH, etc.) - endpoint: API endpoint path (e.g., "pulls", "issues/123") - repo: Repository in "owner/repo" format - data: JSON data for request body - params: Query parameters for the request - - Returns: - Response JSON or error string - """ - if repo is None: - repo = os.environ.get("GITHUB_REPOSITORY") - if not repo: - return "Error: GITHUB_REPOSITORY environment variable not found" - - token = os.environ.get("GITHUB_TOKEN", "") - if not token: - return "Error: GITHUB_TOKEN environment variable not found" - - url = f"https://api.github.com/repos/{repo}/{endpoint}" - headers = { - "Authorization": f"Bearer {token}", - "Accept": "application/vnd.github.v3+json", - } - - try: - if method.upper() == "GET": - response = requests.get(url, headers=headers, params=params, timeout=30) - elif method.upper() == "POST": - response = requests.post(url, headers=headers, json=data, params=params, timeout=30) - else: - response = requests.request(method, url, headers=headers, json=data, params=params, timeout=30) - response.raise_for_status() - return response.json() # type: ignore[no-any-return] - except Exception as e: - if should_raise: - raise e - return f"Error {e!s}" - - -def check_should_call_write_api_or_record(func): - """Decorator that checks if a write api should be called, or if the tool should record to JSONL.""" 
- @wraps(func) - def wrapper(*args, **kwargs): - try: - if not _should_call_write_api(): - # Record the tool request to JSONL file - record_entry: GitHubOperation = { - "timestamp": datetime.utcnow().isoformat() + "Z", - "function": func.__name__, - "args": args, - "kwargs": kwargs - } - - os.makedirs(".artifact", exist_ok=True) - with open(".artifact/write_operations.jsonl", "a") as f: - f.write(json.dumps(record_entry) + "\n") - - # Generate and return deferred message - params = dict(kwargs) - if args: - # Map positional args to parameter names from function signature - import inspect - sig = inspect.signature(func) - param_names = list(sig.parameters.keys()) - for i, arg in enumerate(args): - if i < len(param_names): - params[param_names[i]] = arg - - deferred_msg = _generate_deferred_message(func.__name__, params) - console.print(Panel(escape(deferred_msg), title="[bold yellow]Operation Deferred", border_style="yellow")) - return deferred_msg - except Exception as e: - error_msg = f"Error checking permissions: {e!s}" - console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red")) - return error_msg - - return func(*args, **kwargs) - return wrapper - - -def _generate_deferred_message(operation_name: str, params: dict[str, Any]) -> str: - """Generate a consistent deferred message for write operations. 
- - Args: - operation_name: Name of the operation being deferred - params: Parameters that would have been used for the operation - - Returns: - Formatted deferred message string - """ - if not params: - return f"Operation deferred: {operation_name}" - - # Format parameters, truncating long values - param_strs = [] - for key, value in params.items(): - if isinstance(value, str) and len(value) > 50: - param_strs.append(f"{key}='{value[:50]}...'") - elif isinstance(value, str): - param_strs.append(f"{key}='{value}'") - else: - param_strs.append(f"{key}={value}") - - return f"Operation deferred: {operation_name} - {', '.join(param_strs)}" - - -def _should_call_write_api() -> bool: - """Checks if GITHUB_WRITE environment variable is set to true. - - Returns: - bool: True if GITHUB_WRITE is set to 'true', False otherwise - """ - return os.environ.get("GITHUB_WRITE", "").lower() == "true" - - -# ============================================================================= -# WRITE FUNCTIONS (Functions that modify GitHub resources) -# ============================================================================= - -@tool -@log_inputs -@check_should_call_write_api_or_record -def create_issue(title: str, body: str = "", repo: str | None = None) -> str: - """Creates a new issue in the specified repository. 
- - Args: - title: The issue title - body: The issue body (optional) - repo: GitHub repository in the format "owner/repo" (optional; falls back to env var) - - Returns: - Result of the operation - """ - result = _github_request("POST", "issues", repo, {"title": title, "body": body}) - if isinstance(result, str): - console.print(Panel(escape(result), title="[bold red]Error", border_style="red")) - return result - - message = f"Issue created: #{result['number']} - {result['html_url']}" - console.print(Panel(escape(message), title="[bold green]Success", border_style="green")) - return message - - -@tool -@log_inputs -@check_should_call_write_api_or_record -def update_issue( - issue_number: int, - title: str | None = None, - body: str | None = None, - state: str | None = None, - repo: str | None = None, -) -> str: - """Updates an issue's title, body, or state. - - Args: - issue_number: The issue number - title: New title (optional) - body: New body (optional) - state: New state - "open" or "closed" (optional) - repo: GitHub repository in the format "owner/repo" (optional; falls back to env var) - - Returns: - Result of the operation - """ - data = {} - if title is not None: - data["title"] = title - if body is not None: - data["body"] = body - if state is not None: - data["state"] = state - - if not data: - error_msg = "Error: At least one field (title, body, or state) must be provided" - console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red")) - return error_msg - - result = _github_request("PATCH", f"issues/{issue_number}", repo, data) - if isinstance(result, str): - console.print(Panel(escape(result), title="[bold red]Error", border_style="red")) - return result - - message = f"Issue updated: #{result['number']} - {result['html_url']}" - console.print(Panel(escape(message), title="[bold green]Success", border_style="green")) - return message - - -@tool -@log_inputs -@check_should_call_write_api_or_record -def 
add_issue_comment(issue_number: int, comment_text: str, repo: str | None = None) -> str: - """Adds a comment to an issue or pull request in the specified repository or GITHUB_REPOSITORY environment variable. - - Args: - issue_number: The issue or PR number to comment on - comment_text: The comment text - repo: GitHub repository in the format "owner/repo" (optional; falls back to env var) - - Returns: - Result of the operation - """ - result = _github_request("POST", f"issues/{issue_number}/comments", repo, {"body": comment_text}) - if isinstance(result, str): - console.print(Panel(escape(result), title="[bold red]Error", border_style="red")) - return result - - message = f"Comment added successfully: {result['html_url']} (created: {result['created_at']})" - console.print(Panel(escape(message), title="[bold green]Success", border_style="green")) - return message - - -@tool -@log_inputs -@check_should_call_write_api_or_record -def create_pull_request(title: str, head: str, base: str, body: str = "", repo: str | None = None, fallback_issue_id: int | None = None) -> str: - """Creates a new pull request, or optionally comments on the fallback_issue_id for a link to create a pull request. 
- - Args: - title: The PR title - head: The branch containing changes - base: The branch to merge into - body: The PR body (optional) - repo: GitHub repository in the format "owner/repo" (optional; falls back to env var) - fallback_issue_id: Issue ID to comment on if PR creation fails with an error (optional) - - Returns: - Result of the operation - """ - try: - result = _github_request( - "POST", - "pulls", - repo, - {"title": title, "head": head, "base": base, "body": body}, - should_raise=True - ) - - if isinstance(result, str): - console.print(Panel(escape(result), title="[bold red]Error", border_style="red")) - return result - - - message = f"Pull request created: #{result['number']} - {result['html_url']}" - console.print(Panel(escape(message), title="[bold green]Success", border_style="green")) - return message - - except Exception as e: - if fallback_issue_id is not None: - agent_message = "Failed to create pull request, commenting on issue instead." - console.print(Panel(escape(agent_message), title="[bold yellow]Fallback", border_style="yellow")) - repo_name = repo or os.environ.get("GITHUB_REPOSITORY", "") - query_params = urlencode({ - 'quick_pull': '1', - 'title': title, - 'body': body - }, quote_via=quote) - pr_link = f"https://github.com/{repo_name}/compare/{base}...{head}?{query_params}" - fallback_comment = f"Unable to create pull request via API. You can create it manually by clicking [here]({pr_link})." 
- add_issue_comment(fallback_issue_id, fallback_comment, repo) - return f"Unable to create pull request via API - posted a manual creation link as a comment on issue #{fallback_issue_id}" - else: - error_msg = f"Error: {e!s}" - console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red")) - return error_msg - - -@tool -@log_inputs -@check_should_call_write_api_or_record -def update_pull_request( - pr_number: int, - title: str | None = None, - body: str | None = None, - base: str | None = None, - repo: str | None = None, -) -> str: - """Updates a pull request's title, body, or base branch. - - Args: - pr_number: The pull request number - title: New title (optional) - body: New body (optional) - base: New base branch (optional) - repo: GitHub repository in the format "owner/repo" (optional; falls back to env var) - - Returns: - Result of the operation - """ - data = {} - if title is not None: - data["title"] = title - if body is not None: - data["body"] = body - if base is not None: - data["base"] = base - - if not data: - error_msg = "Error: At least one field (title, body, or base) must be provided" - console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red")) - return error_msg - - result = _github_request("PATCH", f"pulls/{pr_number}", repo, data) - if isinstance(result, str): - console.print(Panel(escape(result), title="[bold red]Error", border_style="red")) - return result - - message = f"Pull request updated: #{result['number']} - {result['html_url']}" - console.print(Panel(escape(message), title="[bold green]Success", border_style="green")) - return message - - -@tool -@log_inputs -@check_should_call_write_api_or_record -def reply_to_review_comment(pr_number: int, comment_id: int, reply_text: str, repo: str | None = None) -> str: - """Replies to a pull request review comment. 
- - Args: - pr_number: The pull request number - comment_id: The review comment ID to reply to - reply_text: The reply text - repo: GitHub repository in the format "owner/repo" (optional; falls back to env var) - - Returns: - Result of the operation - """ - result = _github_request("POST", f"pulls/{pr_number}/comments/{comment_id}/replies", repo, {"body": reply_text}) - if isinstance(result, str): - console.print(Panel(escape(result), title="[bold red]Error", border_style="red")) - return result - - message = f"Reply added to review comment: {result['html_url']}" - reply_details = f"Reply: {reply_text}\nURL: {result['html_url']}" - console.print(Panel(escape(reply_details), title="[bold green]✅ Reply Added", border_style="green")) - return message - - -# ============================================================================= -# READ FUNCTIONS (Functions that only read GitHub resources) -# ============================================================================= - -@tool -@log_inputs -def get_issue(issue_number: int, repo: str | None = None) -> str: - """Gets details of a specific issue. 
- - Args: - issue_number: The issue number - repo: GitHub repository in the format "owner/repo" (optional; falls back to env var) - - Returns: - Issue details - """ - result = _github_request("GET", f"issues/{issue_number}", repo) - if isinstance(result, str): - console.print(Panel(escape(result), title="[bold red]Error", border_style="red")) - return result - - details = ( - f"#{result['number']} - {result['title']}\n" - f"State: {result['state']}\n" - f"Author: {result['user']['login']}\n" - f"URL: {result['html_url']}\n\n{result['body']}" - ) - console.print( - Panel( - escape(details), - title=f"[bold green]📋 Issue #{result['number']}", - border_style="blue", - ) - ) - return details - - -@tool -@log_inputs -def list_issues(state: str = "open", repo: str | None = None) -> str: - """Lists issues from the specified GitHub repository or GITHUB_REPOSITORY environment variable. - - Args: - state: Filter issues by state: "open", "closed", or "all" (default: "open") - repo: GitHub repository in the format "owner/repo" (optional; falls back to env var) - - Returns: - String representation of the issues - """ - result = _github_request("GET", "issues", repo, params={"state": state}) - if isinstance(result, str): - console.print(Panel(escape(result), title="[bold red]Error", border_style="red")) - return result - - # Filter out pull requests from issues list - issues = [issue for issue in result if "pull_request" not in issue] - if not issues: - message = f"No {state} issues found in {repo or os.environ.get('GITHUB_REPOSITORY')}" - console.print(Panel(escape(message), title="[bold yellow]Info", border_style="yellow")) - return message - - table = Table(title=f"🐛 Issues ({state})", box=box.DOUBLE) - table.add_column("Issue #", style="cyan") - table.add_column("Title", style="white") - table.add_column("Author", style="green") - table.add_column("URL", style="blue") - - for issue in issues: - table.add_row( - f"#{issue['number']}", # type: ignore[index] - issue["title"], # 
type: ignore[index] - issue["user"]["login"], # type: ignore[index] - issue["html_url"], # type: ignore[index] - ) - - console.print(table) - - output = f"Issues ({state}) in {repo or os.environ.get('GITHUB_REPOSITORY')}:\n" - for issue in issues: - output += f"#{issue['number']} - {issue['title']} by {issue['user']['login']} - {issue['html_url']}\n" # type: ignore[index] - return output - - -@tool -@log_inputs -def get_issue_comments(issue_number: int, repo: str | None = None, since: str | None = None) -> str: - """Gets all comments for a specific issue. - - Args: - issue_number: The issue number - repo: GitHub repository in the format "owner/repo" (optional; falls back to env var) - since: ISO 8601 timestamp to filter comments updated after this date (optional) - - Returns: - List of comments - """ - params = {"since": since} if since else None - result = _github_request("GET", f"issues/{issue_number}/comments", repo, params=params) - if isinstance(result, str): - console.print(Panel(escape(result), title="[bold red]Error", border_style="red")) - return result - - if not result: - message = f"No comments found for issue #{issue_number}" + (f" updated after {since}" if since else "") - console.print(Panel(escape(message), title="[bold yellow]Info", border_style="yellow")) - return message - - output = f"Comments for issue #{issue_number}:\n" - for comment in result: - output += f"{comment['user']['login']} - updated: {comment['updated_at']}\n{comment['body']}\n\n" # type: ignore[index] - - console.print(Panel(escape(output), title=f"[bold green]💬 Issue #{issue_number} Comments", border_style="blue")) - return output - - -@tool -@log_inputs -def get_pull_request(pr_number: int, repo: str | None = None) -> str: - """Gets details of a specific pull request. 
- - Args: - pr_number: The pull request number - repo: GitHub repository in the format "owner/repo" (optional; falls back to env var) - - Returns: - Pull request details - """ - result = _github_request("GET", f"pulls/{pr_number}", repo) - if isinstance(result, str): - console.print(Panel(escape(result), title="[bold red]Error", border_style="red")) - return result - - details = ( - f"#{result['number']} - {result['title']}\n" - f"State: {result['state']}\n" - f"Author: {result['user']['login']}\n" - f"Head: {result['head']['ref']} -> Base: {result['base']['ref']}\n" - f"URL: {result['html_url']}\n\n{result['body']}" - ) - console.print( - Panel( - escape(details), - title=f"[bold green]🔀 PR #{result['number']}", - border_style="blue", - ) - ) - return details - - -@tool -@log_inputs -def list_pull_requests(state: str = "open", repo: str | None = None) -> str: - """Lists pull requests from the specified GitHub repository or GITHUB_REPOSITORY environment variable. - - Args: - state: Filter PRs by state: "open", "closed", or "all" (default: "open") - repo: GitHub repository in the format "owner/repo" (optional; falls back to env var) - - Returns: - String representation of the pull requests - """ - result = _github_request("GET", "pulls", repo, params={"state": state}) - if isinstance(result, str): - console.print(Panel(escape(result), title="[bold red]Error", border_style="red")) - return result - - if not result: - message = f"No {state} pull requests found in {repo or os.environ.get('GITHUB_REPOSITORY')}" - console.print(Panel(escape(message), title="[bold yellow]Info", border_style="yellow")) - return message - - table = Table(title=f"🔀 Pull Requests ({state})", box=box.DOUBLE) - table.add_column("PR #", style="cyan") - table.add_column("Title", style="white") - table.add_column("Author", style="green") - table.add_column("URL", style="blue") - - for pr in result: - table.add_row(f"#{pr['number']}", pr["title"], pr["user"]["login"], pr["html_url"]) # type: 
ignore[index] - - console.print(table) - - output = f"Pull Requests ({state}) in {repo or os.environ.get('GITHUB_REPOSITORY')}:\n" - for pr in result: - output += f"#{pr['number']} - {pr['title']} by {pr['user']['login']} - {pr['html_url']}\n" # type: ignore[index] - return output - - -@tool -@log_inputs -def get_pr_review_and_comments(pr_number: int, show_resolved: bool = False, repo: str | None = None, since: str | None = None) -> str: - """Gets all review threads and comments for a PR. - - Args: - pr_number: The pull request number - repo: GitHub repository in the format "owner/repo" (optional; falls back to env var) - show_resolved: Whether to include resolved review threads (default: False) - since: ISO 8601 timestamp to filter comments/threads updated after this date (optional) - - Returns: - Formatted review threads and comments - """ - if repo is None: - repo = os.environ.get("GITHUB_REPOSITORY") - if not repo: - return "Error: GITHUB_REPOSITORY environment variable not found" - - token = os.environ.get("GITHUB_TOKEN", "") - if not token: - return "Error: GITHUB_TOKEN environment variable not found" - - owner, repo_name = repo.split("/") - - query = """ - query($owner: String!, $name: String!, $number: Int!) 
{ - repository(owner: $owner, name: $name) { - pullRequest(number: $number) { - reviewThreads(first: 100) { - nodes { - isResolved - comments(first: 100) { - nodes { - id - fullDatabaseId - author { login } - body - updatedAt - path - line - startLine - diffHunk - replyTo { id } - pullRequestReview { - id - body - author { login } - updatedAt - } - } - } - } - } - comments(first: 100) { - nodes { - author { login } - body - updatedAt - } - } - } - } - } - """ - - variables = {"owner": owner, "name": repo_name, "number": pr_number} - - try: - response = requests.post( - "https://api.github.com/graphql", - headers={"Authorization": f"Bearer {token}"}, - json={"query": query, "variables": variables}, - timeout=30 - ) - response.raise_for_status() - data = response.json() - - if "errors" in data: - return f"GraphQL Error: {data['errors']}" - - pr_data = data["data"]["repository"]["pullRequest"] - - # Filter by since if provided - if since: - cutoff = datetime.fromisoformat(since.replace('Z', '+00:00')) - - # Filter review threads - if any comment in thread is newer, include entire thread - filtered_threads = [] - for thread in pr_data["reviewThreads"]["nodes"]: - has_newer_comment = any(datetime.fromisoformat(c['updatedAt'].replace('Z', '+00:00')) > cutoff - for c in thread["comments"]["nodes"]) - if has_newer_comment: - filtered_threads.append(thread) - pr_data["reviewThreads"]["nodes"] = filtered_threads - - # Filter general comments - pr_data["comments"]["nodes"] = [c for c in pr_data["comments"]["nodes"] - if datetime.fromisoformat(c['updatedAt'].replace('Z', '+00:00')) > cutoff] - - output = f"Review threads and comments for PR #{pr_number}:\n\n" - - # Group review threads by review ID - review_threads = {} - for thread in pr_data["reviewThreads"]["nodes"]: - if not show_resolved and thread["isResolved"]: - continue - - if thread["comments"]["nodes"]: - first_comment = thread["comments"]["nodes"][0] - review_id = first_comment.get("pullRequestReview", 
{}).get("id", "N/A") - - if review_id not in review_threads: - review_threads[review_id] = { - "review_data": first_comment.get("pullRequestReview", {}), - "threads": [] - } - - review_threads[review_id]["threads"].append(thread) - - # Display grouped review threads - for review_id, review_info in review_threads.items(): - review_data = review_info['review_data'] - output += f"📝 Review [Review ID: {review_id}]\n" - - # Always show review author and timestamps - if review_data.get('author'): - output += f" 👤 Review by {review_data['author']['login']} (updated: {review_data['updatedAt']})\n" - - # Show top-level review comment if it exists - if review_data.get('body'): - output += f" 📋 Review Comment:\n" - output += f" {review_data['body']}\n" - output += "\n" - - # Show all threads for this review - for thread in review_info["threads"]: - first_comment = thread["comments"]["nodes"][0] - line_info = f":{first_comment['line']}" if first_comment.get('line') else " (Comment on file)" - status = "✅ RESOLVED" if thread["isResolved"] else "🔄 OPEN" - - output += f" 📍 Thread ({status}): {first_comment['path']}{line_info}\n" - - # Show code context right after thread header - if first_comment.get('diffHunk') and first_comment.get('line'): - diff_lines = first_comment['diffHunk'].split('\n') - current_new_line = 0 - target_line = first_comment['line'] - start_line = first_comment.get('startLine') or target_line - - output += f" Code context (lines {start_line}-{target_line}):\n" - - for diff_line in diff_lines: - if diff_line.startswith('@@'): - parts = diff_line.split(' ') - if len(parts) >= 3: - new_start = parts[2].split(',')[0][1:] - current_new_line = int(new_start) - 1 - elif diff_line.startswith('+'): - current_new_line += 1 - if start_line <= current_new_line <= target_line: - output += f" +{current_new_line}: {diff_line[1:]}\n" - elif diff_line.startswith('-'): - pass - elif diff_line.startswith(' '): - current_new_line += 1 - if start_line <= current_new_line <= 
target_line: - output += f" {current_new_line}: {diff_line[1:]}\n" - output += "\n" - - # Group comments by reply relationships - comments = thread["comments"]["nodes"] - root_comments = [c for c in comments if not c.get('replyTo')] - - for root_comment in root_comments: - output += f" 💬 {root_comment['author']['login']} (updated: {root_comment['updatedAt']}) [Comment ID: {root_comment['fullDatabaseId']}]:\n" - output += f" {root_comment['body']}\n" - - # Find and show replies to this comment - replies = [c for c in comments if c.get('replyTo') and c['replyTo'].get('id') == root_comment['id']] - if replies: - for reply in replies: - output += f" ↳ {reply['author']['login']} (updated: {reply['updatedAt']}):\n" - output += f" {reply['body']}\n" - - output += "\n" - output += "\n" - - # General comments - if pr_data["comments"]["nodes"]: - for comment in pr_data["comments"]["nodes"]: - output += f"💬 Comment\n" - output += f" 👤 Comment by {comment['author']['login']} (updated: {comment['updatedAt']})\n" - output += f" 📝 Comment:\n" - output += f" {comment['body']}\n\n" - - console.print(Panel(escape(output), title=f"[bold green]PR #{pr_number} Review Data", border_style="blue")) - return output - - except Exception as e: - error_msg = f"Error: {e!s}\n\nStack trace:\n{traceback.format_exc()}" - console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red")) - return error_msg diff --git a/.github/scripts/python/handoff_to_user.py b/.github/scripts/python/handoff_to_user.py deleted file mode 100644 index 07ad331f1..000000000 --- a/.github/scripts/python/handoff_to_user.py +++ /dev/null @@ -1,34 +0,0 @@ -from rich.markup import escape -from rich.panel import Panel -from strands import tool -from strands.types.tools import ToolContext -from strands_tools.utils import console_util - -@tool(context=True) -def handoff_to_user(message: str, tool_context: ToolContext) -> str: - """ - Hand off control to the user with a message. 
- - Args: - message: The message to give to the user - - Returns: - The users response after handing back control - """ - console = console_util.create() - - console.print( - Panel( - escape(message), - title="[bold yellow]🤝 Handoff to User", - border_style="yellow", - ) - ) - - request_state = { - "stop_event_loop": True - } - tool_context.invocation_state["request_state"] = request_state - - # Return an empty string as this will break out of the event loop - return "" \ No newline at end of file diff --git a/.github/scripts/python/notebook.py b/.github/scripts/python/notebook.py deleted file mode 100644 index 0b5ba2ace..000000000 --- a/.github/scripts/python/notebook.py +++ /dev/null @@ -1,337 +0,0 @@ -"""Notebook management tool for Strands Agents. - -This module provides comprehensive notebook operations for managing text-based notebooks -within agent workflows. Enables persistent note-taking, documentation, and context -preservation across agent sessions. - -Key Features: -1. Create and manage multiple named notebooks -2. Write content using string replacement or line insertion -3. Read entire notebooks or specific line ranges -4. List all available notebooks with metadata -5. Clear notebook contents when needed -6. Rich console output with formatted panels and tables -7. Agent state persistence for session continuity - -Usage Examples: -```python -from strands import Agent -from tools.notebook import notebook - -agent = Agent(tools=[notebook]) - -# Create a new notebook with initial content -result = agent.tool.notebook( - mode="create", - name="research_notes", - new_str="# Research Notes\n\nKey findings and observations..." 
-) - -# Write to notebook using line insertion -result = agent.tool.notebook( - mode="write", - name="research_notes", - insert_line=-1, # Append to end - new_str="- Important discovery about AI behavior patterns" -) - -# Read specific lines from notebook -result = agent.tool.notebook( - mode="read", - name="research_notes", - read_range=[1, 5] # Read first 5 lines -) - -# Replace text in notebook -result = agent.tool.notebook( - mode="write", - name="research_notes", - old_str="[ ] Todo item", - new_str="[x] Completed todo item" -) - -# List all notebooks -result = agent.tool.notebook(mode="list") - -# Clear notebook contents -result = agent.tool.notebook(mode="clear", name="research_notes") -``` -""" - -from typing import Any, Literal - -from rich import box -from rich.markup import escape -from rich.panel import Panel -from rich.table import Table -from strands import ToolContext, tool -from strands_tools.utils import console_util - - -@tool(context=True) -def notebook( - mode: Literal["create", "list", "read", "write", "clear"], - name: str = "default", - read_range: list[int] | None = None, - old_str: str | None = None, - new_str: str | None = None, - insert_line: str | int | None = None, - tool_context: ToolContext | None = None, -) -> str: - """ - Notebook tool for managing text notebooks. - - This tool provides a comprehensive interface for creating, reading, writing, listing, - and deleting text notebooks. Start writing notes in the default notebook which is avaiable - from the start, or create new notebooks to record notes on additional topics or tasks. - - Command Details: - -------------- - 1. write: - • Supports two types of write operations: - - String replacement: Uses old_str and new_str parameters - - Line insertion: Uses insert_line and new_str parameters - - 2. read: - • Reads contents of a notebook - • Supports reading specific line numbers with read_range parameter - - 3. 
create: - • Creates a new notebook with the specified name - • Optionally initializes with content using new_str parameter - • Defaults to empty content if new_str not provided - - 4. list: - • Lists all available notebook names - • Returns comma-separated list of notebook names - - 5. clear: - • Clears the contents of a notebook - - Args: - mode: The operation to perform: `create`, `list`, `read`, `write`, `clear`. - name: Name of the notebook to operate on. Defaults to "default". - read_range: Optional parameter of `view` command. Line range to show [start, end]. Supports negative indices. - old_str: String to replace in write mode when doing text replacement. - new_str: New string for replacement or insertion operations. - insert_line: Line number (int) or search text (str) for insertion point in write mode. - Supports negative indices. - - Returns: - Dict containing status and response content in the format: - { - "status": "success|error", - "content": [{"text": "Response message"}] - } - - Success case: Returns details about the operation performed - Error case: Returns information about what went wrong - - Examples: - 1. Create a notebook: - notebook(mode="create", name="notes") - - 2. List all notebooks: - notebook(mode="list") - - 3. Read entire notebook: - notebook(mode="read", name="notes") - - 4. Read specific lines: - notebook(mode="read", name="notes", read_range=[1, 5]) - - 5. Replace text: - notebook(mode="write", name="notes", old_str="[] Update the calendar", new_str="[x] Update the calendar") - - 6. Insert text after line 5: - notebook(mode="write", name="notes", insert_line=5, new_str="inserted text") - - 7. Insert text at end of notebook: - notebook(mode="write", name="notes", insert_line=-1, new_str="Appended text") - - 7. Insert text after finding a line: - notebook(mode="write", name="notes", insert_line="def function", new_str="# comment") - - 8. 
Clear notebook: - notebook(mode="clear", name="notes") - """ - console = console_util.create() - if tool_context is None: - raise ValueError("Tool context is required") - agent = tool_context.agent - - if agent.state.get("notebooks") is None: - agent.state.set("notebooks", {"default": ""}) - - notebooks: dict[str, Any] = agent.state.get("notebooks") - - if mode == "create": - notebooks[name] = new_str if new_str else "" - message = f"Created notebook '{name}'" + (" with specified content" if new_str else " (empty)") - console.print( - Panel( - escape(message + f":\n{new_str}" if new_str else ""), - title="[bold green]Success", - border_style="green", - ) - ) - agent.state.set("notebooks", notebooks) - return message - - elif mode == "list": - table = Table(title="📚 Available Notebooks", box=box.DOUBLE) - table.add_column("Name", style="cyan") - table.add_column("Lines", style="yellow") - table.add_column("Status", style="green") - - for nb_name in notebooks.keys(): - line_count = len(notebooks[nb_name].split("\n")) if notebooks[nb_name] else 0 - status = "Empty" if line_count == 0 else "Has content" - table.add_row(nb_name, str(line_count), status) - - console.print(table) - return f"Notebooks: {', '.join(notebooks.keys())}" - - elif mode == "read": - if name not in notebooks: - error_msg = f"Notebook '{name}' not found" - console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red")) - raise ValueError(error_msg) - - content = notebooks[name] - if read_range: - lines = content.split("\n") - start, end = read_range - # Handle negative indices - if start < 0: - start = len(lines) + start + 1 - if end < 0: - end = len(lines) + end + 1 - - selected_lines = [] - for line_num in range(start, end + 1): - if 1 <= line_num <= len(lines): - selected_lines.append(f"{line_num}: {lines[line_num - 1]}") - - result = "\n".join(selected_lines) if selected_lines else "No valid lines found" - console.print( - Panel( - escape(result), - title=f"[bold green]📖 
{name} (lines {start}-{end})", - border_style="blue", - ) - ) - return result - - result = content if content else f"Notebook '{name}' is empty" - console.print(Panel(escape(result), title=f"[bold green]📖 {name}", border_style="blue")) - return result - - elif mode == "write": - if name not in notebooks: - error_msg = f"Notebook '{name}' not found" - console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red")) - raise ValueError(error_msg) - - # String replacement - if old_str is not None and new_str is not None: - if old_str not in notebooks[name]: - error_msg = f"String '{old_str}' not found in notebook '{name}'" - console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red")) - raise ValueError(error_msg) - - notebooks[name] = notebooks[name].replace(old_str, new_str) - agent.state.set("notebooks", notebooks) - - # Create git-style diff - old_lines = old_str.split("\n") - new_lines = new_str.split("\n") - diff_lines = [] - - for line in old_lines: - diff_lines.append(f"[red]-{escape(line)}[/red]") - for line in new_lines: - diff_lines.append(f"[green]+{escape(line)}[/green]") - - diff_content = "\n".join(diff_lines) - console.print(Panel(diff_content, title="[bold yellow]📝 Diff", border_style="yellow")) - - message = f"Replaced text in notebook '{name}'" - console.print(Panel(escape(message), title="[bold green]Success", border_style="green")) - return message - - # Line insertion - elif insert_line is not None and new_str is not None: - lines = notebooks[name].split("\n") - - # Check if string represents a number first - if isinstance(insert_line, str): - try: - insert_line = int(insert_line) - except ValueError: - pass # Keep as string for text search - - if isinstance(insert_line, str): - line_num = -1 - for i, line in enumerate(lines): - if insert_line in line: - line_num = i - break - if line_num == -1: - error_msg = f"Text '{insert_line}' not found in notebook '{name}'" - console.print( - Panel( - 
escape(error_msg), - title="[bold red]Error", - border_style="red", - ) - ) - raise ValueError(error_msg) - else: - # Handle negative indices - if insert_line < 0: - line_num = len(lines) + insert_line - else: - line_num = insert_line - 1 - - if 0 <= line_num <= len(lines): - lines.insert(line_num + 1, new_str) - notebooks[name] = "\n".join(lines) - agent.state.set("notebooks", notebooks) - message = f"Inserted text at line {line_num + 2} in notebook '{name}'" - console.print( - Panel( - escape(message), - title="[bold green]Success", - border_style="green", - ) - ) - console.print( - Panel( - escape(notebooks[name]), - title=f"[bold blue]📝 {name} Content", - border_style="blue", - ) - ) - return message - else: - error_msg = f"Line number {insert_line} out of range" - console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red")) - raise ValueError(error_msg) - - # No valid operation provided - else: - error_msg = "No valid write operation specified" - console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red")) - raise ValueError(error_msg) - - elif mode == "clear": - if name not in notebooks: - error_msg = f"Notebook '{name}' not found" - console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red")) - raise ValueError(error_msg) - notebooks[name] = "" - agent.state.set("notebooks", notebooks) - message = f"Cleared notebook '{name}'" - console.print(Panel(escape(message), title="[bold green]Success", border_style="green")) - return message diff --git a/.github/scripts/python/requirements.txt b/.github/scripts/python/requirements.txt deleted file mode 100644 index 1ca2770ff..000000000 --- a/.github/scripts/python/requirements.txt +++ /dev/null @@ -1,8 +0,0 @@ -# Strands packages - only what we need -strands-agents -strands-agents-tools - -# Additional dependencies for our specific tools -colorama -rich -requests>=2.28.0 \ No newline at end of file diff --git 
a/.github/scripts/python/str_replace_based_edit_tool.py b/.github/scripts/python/str_replace_based_edit_tool.py deleted file mode 100644 index 69c92c206..000000000 --- a/.github/scripts/python/str_replace_based_edit_tool.py +++ /dev/null @@ -1,230 +0,0 @@ -"""Text editor tool for Strands Agents. - -A minimal implementation of Claude's text editor tool that supports: -- view: Read file contents or list directory contents -- str_replace: Replace text in files -- create: Create new files -- insert: Insert text at specific line numbers - -Based on Claude's text_editor_20250728 specification. -""" - -from pathlib import Path -from typing import List, Optional - -from rich.markup import escape -from rich.panel import Panel -from strands import tool -from strands_tools.utils import console_util - -console = console_util.create() - - -@tool -def str_replace_based_edit_tool( - command: str, - path: str, - old_str: str | None = None, - new_str: str | None = None, - file_text: str | None = None, - insert_line: str | None = None, - view_range: list[int] | None = None, -) -> str: - """Text editor tool for viewing and modifying files. 
- - Args: - command: The command to execute ("view", "str_replace", "create", "insert") - path: Path to the file or directory - old_str: Text to replace (for str_replace command) - new_str: Replacement text (for str_replace and insert commands) - file_text: Content for new file (for create command) - insert_line: Line number to insert after (for insert command) - view_range: [start_line, end_line] for viewing specific lines (for view command) - - Returns: - Result of the operation - """ - try: - console.print(Panel(f"Command: {command}, Path: {path}", title="[bold blue]Text Editor", border_style="blue")) - - if command == "view": - return _handle_view(path, view_range) - elif command == "str_replace": - if old_str is None or new_str is None: - error_msg = "Error: str_replace requires both old_str and new_str parameters" - console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red")) - return error_msg - return _handle_str_replace(path, old_str, new_str) - elif command == "create": - if file_text is None: - error_msg = "Error: create requires file_text parameter" - console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red")) - return error_msg - return _handle_create(path, file_text) - elif command == "insert": - if new_str is None or insert_line is None: - error_msg = "Error: insert requires both new_str and insert_line parameters" - console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red")) - return error_msg - return _handle_insert(path, new_str, insert_line) - else: - error_msg = f"Error: Unknown command '{command}'" - console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red")) - return error_msg - except Exception as e: - error_msg = f"Error: {str(e)}" - console.print(Panel(escape(error_msg), title="[bold red]Error", border_style="red")) - return error_msg - - -def _handle_view(path: str, view_range: Optional[List[int]] = None) -> str: - """Handle view command to 
read files or list directories.""" - path_obj = Path(path) - - if not path_obj.exists(): - return f"Error: Path '{path}' does not exist" - - if path_obj.is_dir(): - # List directory contents - try: - items = [] - for item in sorted(path_obj.iterdir()): - if item.is_dir(): - items.append(f"{item.name}/") - else: - items.append(item.name) - return "\n".join(items) - except PermissionError: - return f"Error: Permission denied accessing directory '{path}'" - - elif path_obj.is_file(): - # Read file contents - try: - with open(path_obj, 'r', encoding='utf-8') as f: - lines = f.readlines() - - # Apply view_range if specified - if view_range: - start_line, end_line = view_range - # Convert to 0-based indexing - start_idx = max(0, start_line - 1) if start_line > 0 else 0 - end_idx = len(lines) if end_line == -1 else min(len(lines), end_line) - lines = lines[start_idx:end_idx] - start_line_num = start_idx + 1 - else: - start_line_num = 1 - - # Add line numbers - numbered_lines = [] - for i, line in enumerate(lines): - line_num = start_line_num + i - numbered_lines.append(f"{line_num}: {line.rstrip()}") - - return "\n".join(numbered_lines) - except UnicodeDecodeError: - return f"Error: Cannot read '{path}' - file appears to be binary" - except PermissionError: - return f"Error: Permission denied reading file '{path}'" - - else: - return f"Error: '{path}' is not a regular file or directory" - - -def _handle_str_replace(path: str, old_str: str, new_str: str) -> str: - """Handle str_replace command to replace text in a file.""" - path_obj = Path(path) - - if not path_obj.exists(): - return f"Error: File '{path}' does not exist" - - if not path_obj.is_file(): - return f"Error: '{path}' is not a file" - - try: - # Read file content - with open(path_obj, 'r', encoding='utf-8') as f: - content = f.read() - - # Check if old_str exists - if old_str not in content: - return f"Error: Text '{old_str}' not found in file" - - # Count occurrences - count = content.count(old_str) - if count 
> 1: - return f"Error: Text '{old_str}' appears {count} times in file. Please be more specific." - - # Replace text - new_content = content.replace(old_str, new_str) - - # Write back to file - with open(path_obj, 'w', encoding='utf-8') as f: - f.write(new_content) - - success_msg = f"Successfully replaced text in '{path}'" - console.print(Panel(escape(success_msg), title="[bold green]Success", border_style="green")) - return success_msg - - except UnicodeDecodeError: - return f"Error: Cannot modify '{path}' - file appears to be binary" - except PermissionError: - return f"Error: Permission denied modifying file '{path}'" - - -def _handle_create(path: str, file_text: str) -> str: - """Handle create command to create a new file.""" - path_obj = Path(path) - - # Create parent directories if they don't exist - path_obj.parent.mkdir(parents=True, exist_ok=True) - - try: - with open(path_obj, 'w', encoding='utf-8') as f: - f.write(file_text) - - success_msg = f"Successfully created file '{path}'" - console.print(Panel(escape(success_msg), title="[bold green]Success", border_style="green")) - return success_msg - - except PermissionError: - return f"Error: Permission denied creating file '{path}'" - - -def _handle_insert(path: str, new_str: str, insert_line: int) -> str: - """Handle insert command to insert text at a specific line.""" - path_obj = Path(path) - - if not path_obj.exists(): - return f"Error: File '{path}' does not exist" - - if not path_obj.is_file(): - return f"Error: '{path}' is not a file" - - try: - # Read file lines - with open(path_obj, 'r', encoding='utf-8') as f: - lines = f.readlines() - - # Insert new text - if insert_line == 0: - # Insert at beginning - lines.insert(0, new_str + '\n') - elif insert_line >= len(lines): - # Insert at end - lines.append(new_str + '\n') - else: - # Insert after specified line (1-based indexing) - lines.insert(insert_line, new_str + '\n') - - # Write back to file - with open(path_obj, 'w', encoding='utf-8') as f: - 
f.writelines(lines) - - success_msg = f"Successfully inserted text in '{path}' at line {insert_line + 1}" - console.print(Panel(escape(success_msg), title="[bold green]Success", border_style="green")) - return success_msg - - except UnicodeDecodeError: - return f"Error: Cannot modify '{path}' - file appears to be binary" - except PermissionError: - return f"Error: Permission denied modifying file '{path}'" \ No newline at end of file diff --git a/.github/scripts/python/write_executor.py b/.github/scripts/python/write_executor.py deleted file mode 100755 index 6d3b6b84d..000000000 --- a/.github/scripts/python/write_executor.py +++ /dev/null @@ -1,152 +0,0 @@ -#!/usr/bin/env python3 -"""Write Executor Script for GitHub Operations. - -This script reads JSONL artifact files containing deferred GitHub operations -and executes them using functions from github_tools.py. It's designed to run -after the strands-agent-runner to publish any write commands or commits. -""" - -import argparse -import json -import logging -import os -from pathlib import Path -from typing import Any, Dict - -from github_tools import GitHubOperation - -# Import write only github_tools functions for dynamic execution -from github_tools import ( - create_issue, - update_issue, - add_issue_comment, - create_pull_request, - update_pull_request, - reply_to_review_comment, -) - -# Configure structured logging -logging.basicConfig( - format="%(levelname)s | %(name)s | %(message)s", - handlers=[logging.StreamHandler()], - level=logging.INFO -) -logger = logging.getLogger("write_executor") - - -def get_function_mapping() -> Dict[str, Any]: - """Get mapping of function names to actual functions.""" - return { - create_issue.tool_name: create_issue, - update_issue.tool_name: update_issue, - add_issue_comment.tool_name: add_issue_comment, - create_pull_request.tool_name: create_pull_request, - update_pull_request.tool_name: update_pull_request, - reply_to_review_comment.tool_name: reply_to_review_comment, - } 
- - -def process_jsonl_file(file_path: Path, default_issue_id: int | None = None): - """Process JSONL file and execute operations. - - Args: - file_path: Path to the JSONL artifact file - default_issue_id: Default issue ID to use for fallback operations - - Returns: - Tuple of (total_operations, successful_operations, failed_operations) - """ - function_map = get_function_mapping() - - logger.info(f"Starting JSONL processing: {file_path}") - total_ops = 0 - with open(file_path, 'r') as f: - for line_num, line in enumerate(f, 1): - line = line.strip() - if not line: - continue - - total_ops += 1 - logger.info(f"Processing operation {total_ops} (line {line_num})") - - try: - # Parse JSONL entry - operation: GitHubOperation = json.loads(line) - func_name = operation.get("function") - args = operation.get('args', []) - kwargs = operation.get('kwargs', {}) - - if not func_name: - logger.error(f"Line {line_num}: Missing function name") - continue - - # Get function from mapping - if func_name not in function_map: - logger.error(f"Line {line_num}: Unknown function '{func_name}'") - continue - - func = function_map[func_name] - - # Set default issue ID for create_pull_request if not already set - if func_name == "create_pull_request" and default_issue_id and not kwargs.get("fallback_issue_id"): - kwargs["fallback_issue_id"] = default_issue_id - - # Execute function - logger.info(f"Executing {func_name} with args={args}, kwargs={kwargs}") - result = func(*args, **kwargs) - - logger.info(f"Line {line_num}: Operation {func_name} completed successfully") - logger.info(f"Function output: {str(result)}") - - except Exception as e: - logger.error(f"Line {line_num}: Execution error - {e}") - - - logger.info(f"JSONL processing completed.") - - -def main(): - """Main entry point for the write executor script.""" - parser = argparse.ArgumentParser( - description="Execute deferred GitHub operations from JSONL artifact files" - ) - parser.add_argument( - "artifact_file", - help="Path 
to JSONL artifact file containing deferred operations" - ) - parser.add_argument( - "--issue-id", - type=int, - help="Default issue ID to use for fallback operations" - ) - - args = parser.parse_args() - artifact_path = Path(args.artifact_file) - - logger.info(f"Write executor started with artifact file: {artifact_path}") - if args.issue_id: - logger.info(f"Default issue ID set to: {args.issue_id}") - - # Check if file exists - if not artifact_path.exists(): - logger.warning(f"Artifact file not found: {artifact_path}") - logger.warning("No deferred operations to execute") - return - - # Check if file is empty - if artifact_path.stat().st_size == 0: - logger.info("Artifact file is empty") - logger.info("No deferred operations to execute") - return - - # Set environment to enable write operations - os.environ['GITHUB_WRITE'] = 'true' - logger.info("GitHub write mode enabled") - - logger.info(f"Processing deferred operations from: {artifact_path}") - - # Process the JSONL file - process_jsonl_file(artifact_path, args.issue_id) - -if __name__ == "__main__": - main() diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 397f4300d..65c785f30 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -2,40 +2,26 @@ name: Secure Integration test on: pull_request_target: - branches: main + branches: [main] + merge_group: # Run tests in merge queue + types: [checks_requested] jobs: authorization-check: + name: Check access permissions: read-all runs-on: ubuntu-latest outputs: - approval-env: ${{ steps.collab-check.outputs.result }} + approval-env: ${{ steps.auth.outputs.result }} steps: - - name: Collaborator Check - uses: actions/github-script@v8 - id: collab-check + - name: Check Authorization + id: auth + uses: strands-agents/devtools/authorization-check@main with: - result-encoding: string - script: | - try { - const permissionResponse = await 
github.rest.repos.getCollaboratorPermissionLevel({ - owner: context.repo.owner, - repo: context.repo.repo, - username: context.payload.pull_request.user.login, - }); - const permission = permissionResponse.data.permission; - const hasWriteAccess = ['write', 'admin'].includes(permission); - if (!hasWriteAccess) { - console.log(`User ${context.payload.pull_request.user.login} does not have write access to the repository (permission: ${permission})`); - return "manual-approval" - } else { - console.log(`Verifed ${context.payload.pull_request.user.login} has write access. Auto Approving PR Checks.`) - return "auto-approve" - } - } catch (error) { - console.log(`${context.payload.pull_request.user.login} does not have write access. Requiring Manual Approval to run PR Checks.`) - return "manual-approval" - } + skip-check: ${{ github.event_name == 'merge_group' }} + username: ${{ github.event.pull_request.user.login || 'invalid' }} + allowed-roles: 'triage,write,admin' + check-access-and-checkout: runs-on: ubuntu-latest needs: authorization-check diff --git a/.github/workflows/strands-command.yml b/.github/workflows/strands-command.yml index 803f19e48..6c3328192 100644 --- a/.github/workflows/strands-command.yml +++ b/.github/workflows/strands-command.yml @@ -23,93 +23,40 @@ on: jobs: authorization-check: if: startsWith(github.event.comment.body, '/strands') || github.event_name == 'workflow_dispatch' + name: Check access permissions: read-all runs-on: ubuntu-latest outputs: - approval-env: ${{ steps.collab-check.outputs.result || steps.auto-approve.outputs.result }} + approval-env: ${{ steps.auth.outputs.result }} steps: - - name: Collaborator Check - if: github.event_name != 'workflow_dispatch' - uses: actions/github-script@v8 - id: collab-check + - name: Check Authorization + id: auth + uses: strands-agents/devtools/authorization-check@main with: - result-encoding: string - script: | - try { - const permissionResponse = await 
github.rest.repos.getCollaboratorPermissionLevel({ - owner: context.repo.owner, - repo: context.repo.repo, - username: context.payload.comment.user.login, - }); - const permission = permissionResponse.data.permission; - const hasWriteAccess = ['write', 'admin'].includes(permission); - if (!hasWriteAccess) { - console.log(`User ${context.payload.comment.user.login} does not have write access to the repository (permission: ${permission})`); - return "manual-approval" - } else { - console.log(`Verified ${context.payload.comment.user.login} has write access. Auto Approving strands command.`) - return "auto-approve" - } - } catch (error) { - console.log(`${context.payload.comment.user.login} does not have write access. Requiring Manual Approval to run strands command.`) - return "manual-approval" - } - - - name: Auto-approve for workflow dispatch - if: github.event_name == 'workflow_dispatch' - id: auto-approve - uses: actions/github-script@v8 - with: - result-encoding: string - script: | - return "auto-approve" + skip-check: ${{ github.event_name == 'workflow_dispatch' }} + username: ${{ github.event.comment.user.login || 'invalid' }} + allowed-roles: 'triage,write,admin' setup-and-process: needs: [authorization-check] environment: ${{ needs.authorization-check.outputs.approval-env }} permissions: + # Needed to create a branch for the Implementer Agent contents: write + # These both are needed to add the `strands-running` label to issues and prs issues: write pull-requests: write runs-on: ubuntu-latest - outputs: - branch: ${{ steps.process.outputs.branch_name }} - session_id: ${{ steps.process.outputs.session_id }} - system_prompt: ${{ steps.process.outputs.system_prompt }} - prompt: ${{ steps.process.outputs.prompt }} steps: - - name: Add strands-running label - uses: actions/github-script@v8 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - script: | - await github.rest.issues.addLabels({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: 
${{ inputs.issue_id || github.event.issue.number }}, - labels: ['strands-running'] - }); - - - name: Checkout repository - uses: actions/checkout@v6 + - name: Parse input + id: parse + uses: strands-agents/devtools/strands-command/actions/strands-input-parser@main with: - sparse-checkout: | - .github + issue_id: ${{ inputs.issue_id }} + command: ${{ inputs.command }} + session_id: ${{ inputs.session_id }} - # Outputs: branch_name, session_id, system_prompt, prompt - - name: Process input - id: process - uses: actions/github-script@v8 - with: - script: | - const processInput = require('./.github/scripts/javascript/process-input.cjs'); - await processInput(context, github, core, { - issue_id: '${{ inputs.issue_id }}', - command: '${{ inputs.command }}', - session_id: '${{ inputs.session_id }}' - }); - - execute-readonly: + execute-readonly-agent: needs: [setup-and-process] permissions: contents: read @@ -119,66 +66,26 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 60 steps: - - name: Checkout repository - uses: actions/checkout@v6 - with: - sparse-checkout: | - .github + + # Add any steps here to set up the environment for the Agent in your repo + # setup node, setup python, or any other dependencies - name: Run Strands Agent id: agent-runner - uses: ./.github/actions/strands-agent-runner + uses: strands-agents/devtools/strands-command/actions/strands-agent-runner@main with: - system_prompt: ${{ needs.setup-and-process.outputs.system_prompt }} - session_id: ${{ needs.setup-and-process.outputs.session_id }} - task_prompt: ${{ needs.setup-and-process.outputs.prompt }} aws_role_arn: ${{ secrets.AWS_ROLE_ARN }} sessions_bucket: ${{ secrets.AGENT_SESSIONS_BUCKET }} write_permission: 'false' - ref: ${{ needs.setup-and-process.outputs.branch }} - execute-write: - needs: [setup-and-process, execute-readonly] + finalize: + needs: [setup-and-process, execute-readonly-agent] permissions: contents: write issues: write pull-requests: write - id-token: write # Required for OIDC 
runs-on: ubuntu-latest timeout-minutes: 30 steps: - - name: Checkout repository - uses: actions/checkout@v6 - with: - sparse-checkout: | - .github - - name: Execute write operations - uses: ./.github/actions/strands-write-executor - with: - ref: ${{ needs.setup-and-process.outputs.branch }} - issue_id: ${{ inputs.issue_id || github.event.issue.number }} - - - cleanup: - needs: [authorization-check, setup-and-process, execute-readonly, execute-write] - if: always() - permissions: - issues: write - pull-requests: write - runs-on: ubuntu-latest - steps: - - name: Remove strands-running label - uses: actions/github-script@v8 - with: - script: | - try { - await github.rest.issues.removeLabel({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: ${{ inputs.issue_id || github.event.issue.number }}, - name: 'strands-running' - }); - } catch (error) { - console.log('Label removal failed (may not exist):', error.message); - } + uses: strands-agents/devtools/strands-command/actions/strands-finalize@main From 138750c57a6ecf5ddba5ae706209b0c2736ebd52 Mon Sep 17 00:00:00 2001 From: Kihyeon Myung <51226101+kevmyung@users.noreply.github.com> Date: Tue, 27 Jan 2026 06:01:57 -1000 Subject: [PATCH 090/279] feat(bedrock): add automatic prompt caching support (#1438) --- src/strands/models/__init__.py | 3 +- src/strands/models/bedrock.py | 54 ++++++++++++- src/strands/models/model.py | 15 +++- tests/strands/models/test_bedrock.py | 114 ++++++++++++++++++++++++++- 4 files changed, 181 insertions(+), 5 deletions(-) diff --git a/src/strands/models/__init__.py b/src/strands/models/__init__.py index d5f88d09a..be6a96549 100644 --- a/src/strands/models/__init__.py +++ b/src/strands/models/__init__.py @@ -7,12 +7,13 @@ from . 
import bedrock, model from .bedrock import BedrockModel -from .model import Model +from .model import CacheConfig, Model __all__ = [ "bedrock", "model", "BedrockModel", + "CacheConfig", "Model", ] diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 567a2e147..a3cea7cfe 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -29,7 +29,7 @@ from ..types.streaming import CitationsDelta, StreamEvent from ..types.tools import ToolChoice, ToolSpec from ._validation import validate_config_keys -from .model import Model +from .model import CacheConfig, Model logger = logging.getLogger(__name__) @@ -73,7 +73,8 @@ class BedrockConfig(TypedDict, total=False): additional_args: Any additional arguments to include in the request additional_request_fields: Additional fields to include in the Bedrock request additional_response_field_paths: Additional response field paths to extract - cache_prompt: Cache point type for the system prompt + cache_prompt: Cache point type for the system prompt (deprecated, use cache_config) + cache_config: Configuration for prompt caching. Use CacheConfig(strategy="auto") for automatic caching. cache_tools: Cache point type for tools guardrail_id: ID of the guardrail to apply guardrail_trace: Guardrail trace mode. Defaults to enabled. @@ -99,6 +100,7 @@ class BedrockConfig(TypedDict, total=False): additional_request_fields: dict[str, Any] | None additional_response_field_paths: list[str] | None cache_prompt: str | None + cache_config: CacheConfig | None cache_tools: str | None guardrail_id: str | None guardrail_trace: Literal["enabled", "disabled", "enabled_full"] | None @@ -172,6 +174,15 @@ def __init__( logger.debug("region=<%s> | bedrock client created", self.client.meta.region_name) + @property + def _supports_caching(self) -> bool: + """Whether this model supports prompt caching. + + Returns True for Claude models on Bedrock. 
+ """ + model_id = self.config.get("model_id", "").lower() + return "claude" in model_id or "anthropic" in model_id + @override def update_config(self, **model_config: Unpack[BedrockConfig]) -> None: # type: ignore """Update the Bedrock Model configuration with the provided arguments. @@ -322,6 +333,33 @@ def _get_additional_request_fields(self, tool_choice: ToolChoice | None) -> dict return {"additionalModelRequestFields": additional_fields} + def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None: + """Inject a cache point at the end of the last assistant message. + + Args: + messages: List of messages to inject cache point into (modified in place). + """ + if not messages: + return + + last_assistant_idx: int | None = None + for msg_idx, msg in enumerate(messages): + content = msg.get("content", []) + for block_idx, block in reversed(list(enumerate(content))): + if "cachePoint" in block: + del content[block_idx] + logger.warning( + "msg_idx=<%s>, block_idx=<%s> | stripped existing cache point (auto mode manages cache points)", + msg_idx, + block_idx, + ) + if msg.get("role") == "assistant": + last_assistant_idx = msg_idx + + if last_assistant_idx is not None and messages[last_assistant_idx].get("content"): + messages[last_assistant_idx]["content"].append({"cachePoint": {"type": "default"}}) + logger.debug("msg_idx=<%s> | added cache point to last assistant message", last_assistant_idx) + def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: """Format messages for Bedrock API compatibility. 
@@ -330,6 +368,7 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: - Eagerly filtering content blocks to only include Bedrock-supported fields - Ensuring all message content blocks are properly formatted for the Bedrock API - Optionally wrapping the last user message in guardrailConverseContent blocks + - Injecting cache points when cache_config is set with strategy="auto" Args: messages: List of messages to format @@ -396,6 +435,17 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: "Filtered DeepSeek reasoningContent content blocks from messages - https://api-docs.deepseek.com/guides/reasoning_model#multi-round-conversation" ) + # Inject cache point into cleaned_messages (not original messages) if cache_config is set + cache_config = self.config.get("cache_config") + if cache_config and cache_config.strategy == "auto": + if self._supports_caching: + self._inject_cache_point(cleaned_messages) + else: + logger.warning( + "model_id=<%s> | cache_config is enabled but this model does not support caching", + self.config.get("model_id"), + ) + return cleaned_messages def _should_include_tool_result_status(self) -> bool: diff --git a/src/strands/models/model.py b/src/strands/models/model.py index e6630f807..550ee22e9 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -3,7 +3,8 @@ import abc import logging from collections.abc import AsyncGenerator, AsyncIterable -from typing import Any, TypeVar +from dataclasses import dataclass +from typing import Any, Literal, TypeVar from pydantic import BaseModel @@ -16,6 +17,18 @@ T = TypeVar("T", bound=BaseModel) +@dataclass +class CacheConfig: + """Configuration for prompt caching. + + Attributes: + strategy: Caching strategy to use. + - "auto": Automatically inject cachePoint at optimal positions + """ + + strategy: Literal["auto"] = "auto" + + class Model(abc.ABC): """Abstract base class for Agent model providers. 
diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 833b14729..e92018f35 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -12,7 +12,7 @@ import strands from strands import _exception_notes -from strands.models import BedrockModel +from strands.models import BedrockModel, CacheConfig from strands.models.bedrock import ( _DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_MODEL_ID, @@ -2241,3 +2241,115 @@ async def test_format_request_with_guardrail_latest_message(model): # Latest user message image should also be wrapped assert "guardContent" in formatted_messages[2]["content"][1] assert formatted_messages[2]["content"][1]["guardContent"]["image"]["format"] == "png" + + +def test_supports_caching_true_for_claude(bedrock_client): + """Test that supports_caching returns True for Claude models.""" + model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0") + assert model._supports_caching is True + + model2 = BedrockModel(model_id="anthropic.claude-3-haiku-20240307-v1:0") + assert model2._supports_caching is True + + +def test_supports_caching_false_for_non_claude(bedrock_client): + """Test that supports_caching returns False for non-Claude models.""" + model = BedrockModel(model_id="amazon.nova-pro-v1:0") + assert model._supports_caching is False + + +def test_inject_cache_point_adds_to_last_assistant(bedrock_client): + """Test that _inject_cache_point adds cache point to last assistant message.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") + ) + + cleaned_messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there!"}]}, + {"role": "user", "content": [{"text": "How are you?"}]}, + ] + + model._inject_cache_point(cleaned_messages) + + assert len(cleaned_messages[1]["content"]) == 2 + assert "cachePoint" in 
cleaned_messages[1]["content"][-1] + assert cleaned_messages[1]["content"][-1]["cachePoint"]["type"] == "default" + + +def test_inject_cache_point_no_assistant_message(bedrock_client): + """Test that _inject_cache_point does nothing when no assistant message exists.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") + ) + + cleaned_messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + model._inject_cache_point(cleaned_messages) + + assert len(cleaned_messages) == 1 + assert len(cleaned_messages[0]["content"]) == 1 + + +def test_inject_cache_point_skipped_for_non_claude(bedrock_client): + """Test that cache point injection is skipped for non-Claude models.""" + model = BedrockModel(model_id="amazon.nova-pro-v1:0", cache_config=CacheConfig(strategy="auto")) + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Response"}]}, + ] + + formatted = model._format_bedrock_messages(messages) + + assert len(formatted[1]["content"]) == 1 + assert "cachePoint" not in formatted[1]["content"][0] + + +def test_format_bedrock_messages_does_not_mutate_original(bedrock_client): + """Test that _format_bedrock_messages does not mutate original messages.""" + import copy + + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") + ) + + original_messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there!"}]}, + {"role": "user", "content": [{"text": "How are you?"}]}, + ] + + messages_before = copy.deepcopy(original_messages) + formatted = model._format_bedrock_messages(original_messages) + + assert original_messages == messages_before + assert "cachePoint" not in original_messages[1]["content"][-1] + assert "cachePoint" in formatted[1]["content"][-1] + + +def 
test_inject_cache_point_strips_existing_cache_points(bedrock_client): + """Test that _inject_cache_point strips existing cache points and adds new one at correct position.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") + ) + + # Messages with existing cache points in various positions + cleaned_messages = [ + {"role": "user", "content": [{"text": "Hello"}, {"cachePoint": {"type": "default"}}]}, + {"role": "assistant", "content": [{"text": "First response"}, {"cachePoint": {"type": "default"}}]}, + {"role": "user", "content": [{"text": "Follow up"}]}, + {"role": "assistant", "content": [{"text": "Second response"}]}, + ] + + model._inject_cache_point(cleaned_messages) + + # All old cache points should be stripped + assert len(cleaned_messages[0]["content"]) == 1 # user: only text + assert len(cleaned_messages[1]["content"]) == 1 # first assistant: only text + + # New cache point should be at end of last assistant message + assert len(cleaned_messages[3]["content"]) == 2 + assert "cachePoint" in cleaned_messages[3]["content"][-1] From 27b9bc3acc8953f5832ce1e5e6060510567f8029 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 27 Jan 2026 22:16:08 +0200 Subject: [PATCH 091/279] feat(hooks): add retry mechanism for tool calls (#1556) --- src/strands/hooks/events.py | 18 +- src/strands/tools/executors/_executor.py | 186 ++++++------ .../strands/tools/executors/test_executor.py | 279 ++++++++++++++++++ tests_integ/test_tool_retry_hook.py | 69 +++++ 4 files changed, 467 insertions(+), 85 deletions(-) create mode 100644 tests_integ/test_tool_retry_hook.py diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 1faa8a917..ad40dfd7f 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -158,6 +158,18 @@ class AfterToolCallEvent(HookEvent): Note: This event uses reverse callback ordering, meaning callbacks registered later will be invoked first 
during cleanup. + Tool Retrying: + When ``retry`` is set to True by a hook callback, the tool executor will + discard the current tool result and invoke the tool again. This has important + implications for streaming consumers: + + - ToolStreamEvents (intermediate streaming events) from the discarded tool execution + will have already been emitted to callers before the retry occurs. Agent invokers + consuming streamed events should be prepared to handle this scenario, potentially + by tracking retry state or implementing idempotent event processing + - ToolResultEvent is NOT emitted for discarded attempts - only the final attempt's + result is emitted and added to the conversation history + Attributes: selected_tool: The tool that was invoked. It may be None if tool lookup failed. tool_use: The tool parameters that were passed to the tool invoked. @@ -165,6 +177,9 @@ class AfterToolCallEvent(HookEvent): result: The result of the tool invocation. Either a ToolResult on success or an Exception if the tool execution failed. cancel_message: The cancellation message if the user cancelled the tool call. + retry: Whether to retry the tool invocation. Can be set by hook callbacks + to trigger a retry. When True, the current result is discarded and the + tool is called again. Defaults to False. 
""" selected_tool: AgentTool | None @@ -173,9 +188,10 @@ class AfterToolCallEvent(HookEvent): result: ToolResult exception: Exception | None = None cancel_message: str | None = None + retry: bool = False def _can_write(self, name: str) -> bool: - return name == "result" + return name in ["result", "retry"] @property def should_reverse_callbacks(self) -> bool: diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 6d58c5c75..ef000fbd6 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -148,109 +148,127 @@ async def _stream( } ) - before_event, interrupts = await ToolExecutor._invoke_before_tool_call_hook( - agent, tool_func, tool_use, invocation_state - ) - - if interrupts: - yield ToolInterruptEvent(tool_use, interrupts) - return - - if before_event.cancel_tool: - cancel_message = ( - before_event.cancel_tool if isinstance(before_event.cancel_tool, str) else "tool cancelled by user" + # Retry loop for tool execution - hooks can set after_event.retry = True to retry + while True: + before_event, interrupts = await ToolExecutor._invoke_before_tool_call_hook( + agent, tool_func, tool_use, invocation_state ) - yield ToolCancelEvent(tool_use, cancel_message) - cancel_result: ToolResult = { - "toolUseId": str(tool_use.get("toolUseId")), - "status": "error", - "content": [{"text": cancel_message}], - } + if interrupts: + yield ToolInterruptEvent(tool_use, interrupts) + return - after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( - agent, None, tool_use, invocation_state, cancel_result, cancel_message=cancel_message - ) - yield ToolResultEvent(after_event.result) - tool_results.append(after_event.result) - return - - try: - selected_tool = before_event.selected_tool - tool_use = before_event.tool_use - invocation_state = before_event.invocation_state - - if not selected_tool: - if tool_func == selected_tool: - logger.error( - "tool_name=<%s>, 
available_tools=<%s> | tool not found in registry", - tool_name, - list(agent.tool_registry.registry.keys()), - ) - else: - logger.debug( - "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", - tool_name, - str(tool_use.get("toolUseId")), - ) + if before_event.cancel_tool: + cancel_message = ( + before_event.cancel_tool if isinstance(before_event.cancel_tool, str) else "tool cancelled by user" + ) + yield ToolCancelEvent(tool_use, cancel_message) - result: ToolResult = { + cancel_result: ToolResult = { "toolUseId": str(tool_use.get("toolUseId")), "status": "error", - "content": [{"text": f"Unknown tool: {tool_name}"}], + "content": [{"text": cancel_message}], } after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( - agent, selected_tool, tool_use, invocation_state, result + agent, None, tool_use, invocation_state, cancel_result, cancel_message=cancel_message ) yield ToolResultEvent(after_event.result) tool_results.append(after_event.result) return - if structured_output_context.is_enabled: - kwargs["structured_output_context"] = structured_output_context - async for event in selected_tool.stream(tool_use, invocation_state, **kwargs): - # Internal optimization; for built-in AgentTools, we yield TypedEvents out of .stream() - # so that we don't needlessly yield ToolStreamEvents for non-generator callbacks. - # In which case, as soon as we get a ToolResultEvent we're done and for ToolStreamEvent - # we yield it directly; all other cases (non-sdk AgentTools), we wrap events in - # ToolStreamEvent and the last event is just the result. 
- - if isinstance(event, ToolInterruptEvent): - yield event - return - - if isinstance(event, ToolResultEvent): - # below the last "event" must point to the tool_result - event = event.tool_result - break - if isinstance(event, ToolStreamEvent): - yield event - else: - yield ToolStreamEvent(tool_use, event) + try: + selected_tool = before_event.selected_tool + tool_use = before_event.tool_use + invocation_state = before_event.invocation_state + + if not selected_tool: + if tool_func == selected_tool: + logger.error( + "tool_name=<%s>, available_tools=<%s> | tool not found in registry", + tool_name, + list(agent.tool_registry.registry.keys()), + ) + else: + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", + tool_name, + str(tool_use.get("toolUseId")), + ) + + result: ToolResult = { + "toolUseId": str(tool_use.get("toolUseId")), + "status": "error", + "content": [{"text": f"Unknown tool: {tool_name}"}], + } + + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + agent, selected_tool, tool_use, invocation_state, result + ) + # Check if retry requested for unknown tool error + # Use getattr because BidiAfterToolCallEvent doesn't have retry attribute + if getattr(after_event, "retry", False): + logger.debug("tool_name=<%s> | retry requested, retrying tool call", tool_name) + continue + yield ToolResultEvent(after_event.result) + tool_results.append(after_event.result) + return + if structured_output_context.is_enabled: + kwargs["structured_output_context"] = structured_output_context + async for event in selected_tool.stream(tool_use, invocation_state, **kwargs): + # Internal optimization; for built-in AgentTools, we yield TypedEvents out of .stream() + # so that we don't needlessly yield ToolStreamEvents for non-generator callbacks. 
+ # In which case, as soon as we get a ToolResultEvent we're done and for ToolStreamEvent + # we yield it directly; all other cases (non-sdk AgentTools), we wrap events in + # ToolStreamEvent and the last event is just the result. + + if isinstance(event, ToolInterruptEvent): + yield event + return + + if isinstance(event, ToolResultEvent): + # below the last "event" must point to the tool_result + event = event.tool_result + break + + if isinstance(event, ToolStreamEvent): + yield event + else: + yield ToolStreamEvent(tool_use, event) + + result = cast(ToolResult, event) - result = cast(ToolResult, event) + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + agent, selected_tool, tool_use, invocation_state, result + ) - after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( - agent, selected_tool, tool_use, invocation_state, result - ) + # Check if retry requested (getattr for BidiAfterToolCallEvent compatibility) + if getattr(after_event, "retry", False): + logger.debug("tool_name=<%s> | retry requested, retrying tool call", tool_name) + continue - yield ToolResultEvent(after_event.result) - tool_results.append(after_event.result) + yield ToolResultEvent(after_event.result) + tool_results.append(after_event.result) + return - except Exception as e: - logger.exception("tool_name=<%s> | failed to process tool", tool_name) - error_result: ToolResult = { - "toolUseId": str(tool_use.get("toolUseId")), - "status": "error", - "content": [{"text": f"Error: {str(e)}"}], - } + except Exception as e: + logger.exception("tool_name=<%s> | failed to process tool", tool_name) + error_result: ToolResult = { + "toolUseId": str(tool_use.get("toolUseId")), + "status": "error", + "content": [{"text": f"Error: {str(e)}"}], + } - after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( - agent, selected_tool, tool_use, invocation_state, error_result, exception=e - ) - yield ToolResultEvent(after_event.result) - 
tool_results.append(after_event.result) + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + agent, selected_tool, tool_use, invocation_state, error_result, exception=e + ) + # Check if retry requested (getattr for BidiAfterToolCallEvent compatibility) + if getattr(after_event, "retry", False): + logger.debug("tool_name=<%s> | retry requested after exception, retrying tool call", tool_name) + continue + yield ToolResultEvent(after_event.result) + tool_results.append(after_event.result) + return @staticmethod async def _stream_with_trace( diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 8139fbf66..78e35c2aa 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -4,6 +4,7 @@ import pytest import strands +from strands.experimental.hooks.events import BidiAfterToolCallEvent from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent from strands.interrupt import Interrupt from strands.telemetry.metrics import Trace @@ -479,3 +480,281 @@ async def test_executor_stream_updates_invocation_state_with_agent( # Verify that the invocation_state was updated with the agent assert "agent" in empty_invocation_state assert empty_invocation_state["agent"] is agent + + +@pytest.mark.asyncio +async def test_executor_stream_no_retry_set(executor, agent, tool_results, invocation_state, alist): + """Test default behavior when retry is not set - tool executes once.""" + call_count = {"count": 0} + + @strands.tool(name="counting_tool") + def counting_tool(): + call_count["count"] += 1 + return f"attempt_{call_count['count']}" + + agent.tool_registry.register_tool(counting_tool) + + tool_use: ToolUse = {"name": "counting_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + + # Tool should be called exactly once + assert call_count["count"] == 
1 + + # Single result event with first attempt's content + assert len(tru_events) == 1 + assert tru_events[0].tool_result == {"toolUseId": "1", "status": "success", "content": [{"text": "attempt_1"}]} + + # tool_results should contain the result + assert len(tool_results) == 1 + assert tool_results[0] == {"toolUseId": "1", "status": "success", "content": [{"text": "attempt_1"}]} + + +@pytest.mark.asyncio +async def test_executor_stream_retry_true(executor, agent, tool_results, invocation_state, alist): + """Test that retry=True causes tool re-execution.""" + call_count = {"count": 0} + + @strands.tool(name="counting_tool") + def counting_tool(): + call_count["count"] += 1 + return f"attempt_{call_count['count']}" + + agent.tool_registry.register_tool(counting_tool) + + # Set retry=True on first call only + def retry_once(event): + if isinstance(event, AfterToolCallEvent) and call_count["count"] == 1: + event.retry = True + return event + + agent.hooks.add_callback(AfterToolCallEvent, retry_once) + + tool_use: ToolUse = {"name": "counting_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + + # Tool should be called twice due to retry + assert call_count["count"] == 2 + + # Only final result is yielded (first attempt's result was discarded) + assert len(tru_events) == 1 + assert tru_events[0].tool_result == {"toolUseId": "1", "status": "success", "content": [{"text": "attempt_2"}]} + + # tool_results only contains the final result + assert len(tool_results) == 1 + assert tool_results[0] == {"toolUseId": "1", "status": "success", "content": [{"text": "attempt_2"}]} + + +@pytest.mark.asyncio +async def test_executor_stream_retry_true_emits_events_from_both_attempts( + executor, agent, tool_results, invocation_state, alist +): + """Test that ToolStreamEvents from discarded attempt ARE emitted, but ToolResultEvent is NOT. 
+ + This validates the documented behavior: 'Streaming events from the discarded + tool execution will have already been emitted to callers before the retry occurs.' + + Key distinction: + - ToolStreamEvent (intermediate): Yielded immediately, visible from BOTH attempts + - ToolResultEvent (final): Only yielded for the final attempt, discarded on retry + """ + call_count = {"count": 0} + + @strands.tool(name="streaming_tool") + def streaming_tool(): + return "unused" + + # Provide streaming implementation (same pattern as exception_tool fixture) + async def tool_stream(_tool_use, _invocation_state, **kwargs): + call_count["count"] += 1 + yield f"streaming_from_attempt_{call_count['count']}" + yield ToolResultEvent( + {"toolUseId": "1", "status": "success", "content": [{"text": f"result_{call_count['count']}"}]} + ) + + streaming_tool.stream = tool_stream + agent.tool_registry.register_tool(streaming_tool) + + # Set retry=True on first call + def retry_once(event): + if isinstance(event, AfterToolCallEvent) and call_count["count"] == 1: + event.retry = True + return event + + agent.hooks.add_callback(AfterToolCallEvent, retry_once) + + tool_use: ToolUse = {"name": "streaming_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + + # Tool called twice + assert call_count["count"] == 2 + + # Streaming events from BOTH attempts are emitted (documented behavior) + stream_events = [e for e in tru_events if isinstance(e, ToolStreamEvent)] + assert len(stream_events) == 2 + assert stream_events[0] == ToolStreamEvent(tool_use, "streaming_from_attempt_1") + assert stream_events[1] == ToolStreamEvent(tool_use, "streaming_from_attempt_2") + + # Only final ToolResultEvent is emitted + result_events = [e for e in tru_events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0].tool_result["content"][0]["text"] == "result_2" + + 
+@pytest.mark.asyncio +async def test_executor_stream_retry_false(executor, agent, tool_results, invocation_state, alist): + """Test that explicitly setting retry=False does not retry.""" + call_count = {"count": 0} + + @strands.tool(name="counting_tool") + def counting_tool(): + call_count["count"] += 1 + return f"attempt_{call_count['count']}" + + agent.tool_registry.register_tool(counting_tool) + + # Explicitly set retry=False + def no_retry(event): + if isinstance(event, AfterToolCallEvent): + event.retry = False + return event + + agent.hooks.add_callback(AfterToolCallEvent, no_retry) + + tool_use: ToolUse = {"name": "counting_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + + # Tool should be called exactly once + assert call_count["count"] == 1 + + # Single result event + assert len(tru_events) == 1 + assert tru_events[0].tool_result == {"toolUseId": "1", "status": "success", "content": [{"text": "attempt_1"}]} + + # tool_results should contain the result + assert len(tool_results) == 1 + assert tool_results[0] == {"toolUseId": "1", "status": "success", "content": [{"text": "attempt_1"}]} + + +@pytest.mark.asyncio +async def test_executor_stream_bidi_event_no_retry_attribute(executor, agent, tool_results, invocation_state, alist): + """Test that BidiAfterToolCallEvent (which lacks retry attribute) doesn't cause retry. + + This tests the getattr(after_event, "retry", False) fallback for events without retry. 
+ """ + call_count = {"count": 0} + + @strands.tool(name="counting_tool") + def counting_tool(): + call_count["count"] += 1 + return f"attempt_{call_count['count']}" + + agent.tool_registry.register_tool(counting_tool) + + tool_use: ToolUse = {"name": "counting_tool", "toolUseId": "1", "input": {}} + result: strands.types.tools.ToolResult = { + "toolUseId": "1", + "status": "success", + "content": [{"text": "attempt_1"}], + } + + # Create a BidiAfterToolCallEvent (which has no retry attribute) + bidi_event = BidiAfterToolCallEvent( + agent=agent, + selected_tool=counting_tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=result, + ) + + # Patch _invoke_after_tool_call_hook to return BidiAfterToolCallEvent + async def mock_after_hook(*args, **kwargs): + return bidi_event, [] + + with unittest.mock.patch.object(ToolExecutor, "_invoke_after_tool_call_hook", mock_after_hook): + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + tru_events = await alist(stream) + + # Tool should be called once - no retry since BidiAfterToolCallEvent has no retry attr + assert call_count["count"] == 1 + + # Result should be returned + assert len(tru_events) == 1 + + +@pytest.mark.asyncio +async def test_executor_stream_retry_after_exception(executor, agent, tool_results, invocation_state, alist): + """Test that retry=True works when tool raises an exception. + + Covers the exception path retry check. 
+ """ + call_count = {"count": 0} + + @strands.tool(name="flaky_tool") + def flaky_tool(): + call_count["count"] += 1 + if call_count["count"] == 1: + raise RuntimeError("First call fails") + return "success" + + agent.tool_registry.register_tool(flaky_tool) + + # Retry once on error (check result status, not exception attribute) + def retry_on_error(event): + if isinstance(event, AfterToolCallEvent) and event.result.get("status") == "error" and call_count["count"] == 1: + event.retry = True + return event + + agent.hooks.add_callback(AfterToolCallEvent, retry_on_error) + + tool_use: ToolUse = {"name": "flaky_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + tru_events = await alist(stream) + + # Tool called twice (1 exception + 1 success) + assert call_count["count"] == 2 + + # Final result is success + assert len(tru_events) == 1 + assert tru_events[0].tool_result["status"] == "success" + + +@pytest.mark.asyncio +async def test_executor_stream_retry_after_unknown_tool(executor, agent, tool_results, invocation_state, alist): + """Test that retry=True triggers retry loop for unknown tool. + + Covers the unknown tool path retry check. Tool lookup happens before retry loop, + so even after retry the tool remains unknown - this test verifies the retry + mechanism is triggered, not that it resolves the unknown tool. 
+ """ + hook_call_count = {"count": 0} + + # Retry once on first unknown tool error + def retry_once_on_unknown(event): + if isinstance(event, AfterToolCallEvent): + hook_call_count["count"] += 1 + # Retry only on first call + if hook_call_count["count"] == 1: + event.retry = True + return event + + agent.hooks.add_callback(AfterToolCallEvent, retry_once_on_unknown) + + tool_use: ToolUse = {"name": "nonexistent_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + tru_events = await alist(stream) + + # Hook called twice (retry was triggered) + assert hook_call_count["count"] == 2 + + # Final result is still error (tool remains unknown after retry) + assert len(tru_events) == 1 + assert tru_events[0].tool_result["status"] == "error" + assert "Unknown tool" in tru_events[0].tool_result["content"][0]["text"] diff --git a/tests_integ/test_tool_retry_hook.py b/tests_integ/test_tool_retry_hook.py new file mode 100644 index 000000000..3e35ff5e6 --- /dev/null +++ b/tests_integ/test_tool_retry_hook.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +"""Integration tests for tool retry hook mechanism. + +Tests that setting AfterToolCallEvent.retry=True causes tool re-execution. +Uses direct tool invocation to test the executor-level retry, not model behavior. +""" + +from strands import Agent, tool +from strands.hooks import AfterToolCallEvent + + +def test_tool_retry_hook_causes_reexecution(): + """Test that setting retry=True on AfterToolCallEvent causes tool re-execution. + + Verifies: + 1. Tool is called again when retry=True + 2. Hook receives AfterToolCallEvent for BOTH attempts + 3. Same tool_use_id is used (proves executor retry, not model re-calling) + """ + state = {"call_count": 0} + + @tool(name="flaky_tool") + def flaky_tool(message: str) -> str: + """A tool that fails once then succeeds. + + Args: + message: A message to include in the response. 
+ """ + state["call_count"] += 1 + if state["call_count"] == 1: + raise RuntimeError("First call fails") + return f"Success on attempt {state['call_count']}" + + hook_calls: list[dict] = [] + + def retry_on_first_error(event: AfterToolCallEvent) -> None: + tool_use_id = str(event.tool_use.get("toolUseId", "")) + hook_calls.append( + { + "tool_use_id": tool_use_id, + "status": event.result.get("status"), + "attempt": state["call_count"], + } + ) + + # Retry once on error + if event.result.get("status") == "error" and state["call_count"] == 1: + event.retry = True + + agent = Agent(tools=[flaky_tool]) + agent.hooks.add_callback(AfterToolCallEvent, retry_on_first_error) + + # Direct tool invocation bypasses model - tests executor retry mechanism + result = agent.tool.flaky_tool(message="test") + + # Tool was called twice (1 failure + 1 success) + assert state["call_count"] == 2 + + # Hook received AfterToolCallEvent for BOTH attempts + assert len(hook_calls) == 2 + assert hook_calls[0]["status"] == "error" + assert hook_calls[0]["attempt"] == 1 + assert hook_calls[1]["status"] == "success" + assert hook_calls[1]["attempt"] == 2 + + # Both calls used the same tool_use_id (executor retry, not new model call) + assert hook_calls[0]["tool_use_id"] == hook_calls[1]["tool_use_id"] + + assert result["status"] == "success" From 4d0ffe84b97b5acec48686ac3b9e2c2bd0fdbf53 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Tue, 27 Jan 2026 17:54:51 -0500 Subject: [PATCH 092/279] feat(tools): move ToolProvider out of experimental namespace (#1567) Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> --- AGENTS.md | 4 +- src/strands/agent/agent.py | 2 +- src/strands/experimental/bidi/agent/agent.py | 2 +- src/strands/experimental/tools/__init__.py | 21 ++++- src/strands/tools/__init__.py | 2 + src/strands/tools/mcp/mcp_client.py | 8 +- src/strands/tools/registry.py | 2 +- .../{experimental => }/tools/tool_provider.py | 2 +- 
.../tools/test_tool_provider_alias.py | 83 +++++++++++++++++++ tests/strands/tools/test_registry.py | 3 +- .../tools/test_registry_tool_provider.py | 2 +- 11 files changed, 114 insertions(+), 17 deletions(-) rename src/strands/{experimental => }/tools/tool_provider.py (97%) create mode 100644 tests/strands/experimental/tools/test_tool_provider_alias.py diff --git a/AGENTS.md b/AGENTS.md index 8b4394cc5..71e83835d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -55,6 +55,7 @@ strands-agents/ │ ├── tools/ # Tool system │ │ ├── decorator.py # @tool decorator │ │ ├── tools.py # Tool base classes +│ │ ├── tool_provider.py # ToolProvider interface │ │ ├── registry.py # Tool registration │ │ ├── loader.py # Dynamic tool loading │ │ ├── watcher.py # Hot reload @@ -139,8 +140,7 @@ strands-agents/ │ │ │ ├── context_providers/ │ │ │ ├── core/ │ │ │ └── handlers/ -│ │ └── tools/ # Experimental tools -│ │ └── tool_provider.py +│ │ └── tools/ # Experimental tools (deprecation shims) │ │ │ ├── __init__.py # Public API exports │ ├── interrupt.py # Interrupt handling diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index cacc69ece..e2ac3aa71 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -31,7 +31,7 @@ from ..tools._tool_helpers import generate_missing_tool_result_content if TYPE_CHECKING: - from ..experimental.tools import ToolProvider + from ..tools import ToolProvider from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from ..hooks import ( AfterInvocationEvent, diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 11bea96e5..8c68e780e 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -25,11 +25,11 @@ from ....tools.executors import ConcurrentToolExecutor from ....tools.executors._executor import ToolExecutor from ....tools.registry import ToolRegistry +from ....tools.tool_provider 
import ToolProvider from ....tools.watcher import ToolWatcher from ....types.content import Message, Messages from ....types.tools import AgentTool from ...hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent -from ...tools import ToolProvider from .._async import _TaskGroup, stop_all from ..models.model import BidiModel from ..types.agent import BidiAgentInput diff --git a/src/strands/experimental/tools/__init__.py b/src/strands/experimental/tools/__init__.py index ad693f8ac..a23b7a10c 100644 --- a/src/strands/experimental/tools/__init__.py +++ b/src/strands/experimental/tools/__init__.py @@ -1,5 +1,22 @@ """Experimental tools package.""" -from .tool_provider import ToolProvider +import warnings +from typing import Any -__all__ = ["ToolProvider"] +_DEPRECATED_NAMES = {"ToolProvider"} + + +def __getattr__(name: str) -> Any: + if name in _DEPRECATED_NAMES: + from ...tools import ToolProvider + + warnings.warn( + f"{name} has been moved to production. Use {name} from strands.tools instead.", + DeprecationWarning, + stacklevel=2, + ) + return ToolProvider + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__: list[str] = [] diff --git a/src/strands/tools/__init__.py b/src/strands/tools/__init__.py index c61f79748..ada49369d 100644 --- a/src/strands/tools/__init__.py +++ b/src/strands/tools/__init__.py @@ -5,6 +5,7 @@ from .decorator import tool from .structured_output import convert_pydantic_to_tool_spec +from .tool_provider import ToolProvider from .tools import InvalidToolUseNameException, PythonAgentTool, normalize_schema, normalize_tool_spec __all__ = [ @@ -14,4 +15,5 @@ "normalize_schema", "normalize_tool_spec", "convert_pydantic_to_tool_spec", + "ToolProvider", ] diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 1aff22a1e..833d55e07 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -40,11 +40,11 @@ from pydantic import AnyUrl 
from typing_extensions import Protocol, TypedDict -from ...experimental.tools import ToolProvider from ...types import PaginatedList from ...types.exceptions import MCPClientInitializationError, ToolProviderException from ...types.media import ImageFormat from ...types.tools import AgentTool, ToolResultContent, ToolResultStatus +from ..tool_provider import ToolProvider from .mcp_agent_tool import MCPAgentTool from .mcp_instrumentation import mcp_instrumentation from .mcp_types import MCPToolResult, MCPTransport @@ -106,10 +106,6 @@ class MCPClient(ToolProvider): The connection runs in a background thread to avoid blocking the main application thread while maintaining communication with the MCP service. When structured content is available from MCP tools, it will be returned as the last item in the content array of the ToolResult. - - Warning: - This class implements the experimental ToolProvider interface and its methods - are subject to change. """ def __init__( @@ -207,7 +203,7 @@ def start(self) -> "MCPClient": raise MCPClientInitializationError("the client initialization failed") from e return self - # ToolProvider interface methods (experimental, as ToolProvider is experimental) + # ToolProvider interface methods async def load_tools(self, **kwargs: Any) -> Sequence[AgentTool]: """Load and return tools from the MCP server. diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index f9787a182..a5e4132bb 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -19,9 +19,9 @@ from typing_extensions import TypedDict from .._async import run_async -from ..experimental.tools import ToolProvider from ..tools.decorator import DecoratedFunctionTool from ..types.tools import AgentTool, ToolSpec +from . 
import ToolProvider from .loader import load_tool_from_string, load_tools_from_module from .tools import _COMPOSITION_KEYWORDS, PythonAgentTool, normalize_schema, normalize_tool_spec diff --git a/src/strands/experimental/tools/tool_provider.py b/src/strands/tools/tool_provider.py similarity index 97% rename from src/strands/experimental/tools/tool_provider.py rename to src/strands/tools/tool_provider.py index c40d1b572..002c57d73 100644 --- a/src/strands/experimental/tools/tool_provider.py +++ b/src/strands/tools/tool_provider.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from ...types.tools import AgentTool + from ..types.tools import AgentTool class ToolProvider(ABC): diff --git a/tests/strands/experimental/tools/test_tool_provider_alias.py b/tests/strands/experimental/tools/test_tool_provider_alias.py new file mode 100644 index 000000000..58a2b9e20 --- /dev/null +++ b/tests/strands/experimental/tools/test_tool_provider_alias.py @@ -0,0 +1,83 @@ +"""Tests to verify that experimental ToolProvider alias works with deprecation warning. + +This test module ensures that the experimental ToolProvider alias maintains +backwards compatibility and can be used interchangeably with the actual +ToolProvider type from strands.tools. 
+""" + +import sys + +import pytest + +from strands.tools import ToolProvider + + +def test_experimental_alias_is_same_type(): + """Verify that experimental ToolProvider alias is identical to the actual type.""" + from strands.experimental.tools import ToolProvider as ExperimentalToolProvider + + assert ExperimentalToolProvider is ToolProvider + + +def test_deprecation_warning_on_import(captured_warnings): + """Verify that importing ToolProvider from experimental emits deprecation warning.""" + # Clear the module from cache to trigger fresh import + if "strands.experimental.tools" in sys.modules: + del sys.modules["strands.experimental.tools"] + + # Clear any existing warnings + captured_warnings.clear() + + # Import from experimental - this should trigger the warning + from strands.experimental import tools + + _ = tools.ToolProvider + + assert len(captured_warnings) >= 1 + warning = captured_warnings[0] + assert issubclass(warning.category, DeprecationWarning) + assert "ToolProvider" in str(warning.message) + assert "strands.tools" in str(warning.message) + + +def test_deprecation_warning_on_direct_import(captured_warnings): + """Verify that direct import from experimental.tools emits deprecation warning.""" + # Clear the module from cache to trigger fresh import + if "strands.experimental.tools" in sys.modules: + del sys.modules["strands.experimental.tools"] + + # Clear any existing warnings + captured_warnings.clear() + + # Direct import - this should trigger the warning + from strands.experimental.tools import ToolProvider as _ # noqa: F401 + + assert len(captured_warnings) >= 1 + warning = captured_warnings[0] + assert issubclass(warning.category, DeprecationWarning) + assert "ToolProvider" in str(warning.message) + assert "strands.tools" in str(warning.message) + + +def test_attribute_error_on_unknown_attribute(): + """Verify that accessing unknown attributes raises AttributeError.""" + import strands.experimental.tools as tools_module + + with 
pytest.raises(AttributeError, match="has no attribute"): + _ = tools_module.NonExistentClass + + +def test_no_warning_on_production_import(captured_warnings): + """Verify that importing from strands.tools does not emit deprecation warning.""" + # Clear any existing warnings + captured_warnings.clear() + + # Import from production - should NOT trigger warning + from strands.tools import ToolProvider as _ # noqa: F401 + + # Filter for ToolProvider-related deprecation warnings + tool_provider_warnings = [ + w for w in captured_warnings if "ToolProvider" in str(w.message) and issubclass(w.category, DeprecationWarning) + ] + + assert len(tool_provider_warnings) == 0 diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index d44936f3e..ed96f2b6a 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -7,8 +7,7 @@ import pytest import strands -from strands.experimental.tools import ToolProvider -from strands.tools import PythonAgentTool +from strands.tools import PythonAgentTool, ToolProvider from strands.tools.decorator import DecoratedFunctionTool, tool from strands.tools.mcp import MCPClient from strands.tools.registry import ToolRegistry diff --git a/tests/strands/tools/test_registry_tool_provider.py b/tests/strands/tools/test_registry_tool_provider.py index fdf4abb0a..25a4edacb 100644 --- a/tests/strands/tools/test_registry_tool_provider.py +++ b/tests/strands/tools/test_registry_tool_provider.py @@ -4,7 +4,7 @@ import pytest -from strands.experimental.tools.tool_provider import ToolProvider +from strands.tools import ToolProvider from strands.tools.registry import ToolRegistry from tests.fixtures.mock_agent_tool import MockAgentTool From 62cc949e3a94cdf93b033a61ce1299cd39e4f5fe Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Tue, 27 Jan 2026 18:00:28 -0500 Subject: [PATCH 093/279] [FIX] models - gemini - start and stop reasoningContent (#1557) 
--- src/strands/models/gemini.py | 13 +++-- tests/strands/models/test_gemini.py | 68 ++++++++++++++++++++++--- tests_integ/models/test_model_gemini.py | 17 +++++++ 3 files changed, 89 insertions(+), 9 deletions(-) diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 855e1ef5c..192a363d3 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -443,8 +443,8 @@ async def stream( response = await client.models.generate_content_stream(**request) yield self._format_chunk({"chunk_type": "message_start"}) - yield self._format_chunk({"chunk_type": "content_start", "data_type": "text"}) + data_type: str | None = None tool_used = False candidate = None event = None @@ -462,15 +462,22 @@ async def stream( tool_used = True if part.text: + new_data_type = "reasoning_content" if part.thought else "text" + if new_data_type != data_type: + if data_type is not None: + yield self._format_chunk({"chunk_type": "content_stop", "data_type": data_type}) + yield self._format_chunk({"chunk_type": "content_start", "data_type": new_data_type}) + data_type = new_data_type yield self._format_chunk( { "chunk_type": "content_delta", - "data_type": "reasoning_content" if part.thought else "text", + "data_type": data_type, "data": part, }, ) - yield self._format_chunk({"chunk_type": "content_stop", "data_type": "text"}) + if data_type is not None: + yield self._format_chunk({"chunk_type": "content_stop", "data_type": data_type}) yield self._format_chunk( { "chunk_type": "message_stop", diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index 70f5032d8..86ab2fea5 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -523,11 +523,9 @@ async def test_stream_response_tool_use(gemini_client, model, messages, agenerat tru_chunks = await alist(model.stream(messages)) exp_chunks = [ {"messageStart": {"role": "assistant"}}, - {"contentBlockStart": {"start": {}}}, 
{"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}}, {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, {"contentBlockStop": {}}, - {"contentBlockStop": {}}, {"messageStop": {"stopReason": "tool_use"}}, {"metadata": {"usage": {"inputTokens": 1, "outputTokens": 2, "totalTokens": 3}, "metrics": {"latencyMs": 0}}}, ] @@ -573,6 +571,68 @@ async def test_stream_response_reasoning(gemini_client, model, messages, agenera assert tru_chunks == exp_chunks +@pytest.mark.asyncio +async def test_stream_response_reasoning_and_text(gemini_client, model, messages, agenerator, alist): + """Test that both reasoning and text content are captured in separate blocks.""" + gemini_client.aio.models.generate_content_stream.return_value = agenerator( + [ + genai.types.GenerateContentResponse( + candidates=[ + genai.types.Candidate( + content=genai.types.Content( + parts=[ + genai.types.Part( + text="thinking about math", + thought=True, + thought_signature=b"sig1", + ), + ], + ), + finish_reason="STOP", + ), + ], + usage_metadata=genai.types.GenerateContentResponseUsageMetadata( + prompt_token_count=1, + total_token_count=3, + ), + ), + genai.types.GenerateContentResponse( + candidates=[ + genai.types.Candidate( + content=genai.types.Content( + parts=[ + genai.types.Part( + text="2 + 2 = 4", + thought=False, + ), + ], + ), + finish_reason="STOP", + ), + ], + usage_metadata=genai.types.GenerateContentResponseUsageMetadata( + prompt_token_count=1, + total_token_count=5, + ), + ), + ] + ) + + tru_chunks = await alist(model.stream(messages)) + exp_chunks = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "sig1", "text": "thinking about math"}}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "2 + 2 = 4"}}}, + {"contentBlockStop": {}}, + {"messageStop": 
{"stopReason": "end_turn"}}, + {"metadata": {"usage": {"inputTokens": 1, "outputTokens": 4, "totalTokens": 5}, "metrics": {"latencyMs": 0}}}, + ] + assert tru_chunks == exp_chunks + + @pytest.mark.asyncio async def test_stream_response_max_tokens(gemini_client, model, messages, agenerator, alist): gemini_client.aio.models.generate_content_stream.return_value = agenerator( @@ -623,8 +683,6 @@ async def test_stream_response_none_candidates(gemini_client, model, messages, a tru_chunks = await alist(model.stream(messages)) exp_chunks = [ {"messageStart": {"role": "assistant"}}, - {"contentBlockStart": {"start": {}}}, - {"contentBlockStop": {}}, {"messageStop": {"stopReason": "end_turn"}}, {"metadata": {"usage": {"inputTokens": 1, "outputTokens": 2, "totalTokens": 3}, "metrics": {"latencyMs": 0}}}, ] @@ -643,8 +701,6 @@ async def test_stream_response_empty_stream(gemini_client, model, messages, agen tru_chunks = await alist(model.stream(messages)) exp_chunks = [ {"messageStart": {"role": "assistant"}}, - {"contentBlockStart": {"start": {}}}, - {"contentBlockStop": {}}, {"messageStop": {"stopReason": "end_turn"}}, ] assert tru_chunks == exp_chunks diff --git a/tests_integ/models/test_model_gemini.py b/tests_integ/models/test_model_gemini.py index 5643d159e..4c01c0b71 100644 --- a/tests_integ/models/test_model_gemini.py +++ b/tests_integ/models/test_model_gemini.py @@ -202,3 +202,20 @@ def test_agent_with_gemini_code_execution_tool(gemini_tool_model): result_turn2 = agent("Summarize that into a single number") assert "5117" in str(result_turn2) + + +def test_agent_with_reasoning_content(model, assistant_agent): + """Test that reasoning content is captured in message history.""" + + model.update_config( + params={ + "thinking_config": { + "thinking_budget": 1024, + "include_thoughts": True, + }, + }, + ) + + result = assistant_agent("Think about what 2+2 is") + assert "reasoningContent" in result.message["content"][0] + assert 
result.message["content"][0]["reasoningContent"]["reasoningText"]["text"] From 694c4a7e135f8258f1138bf94efcc436d3db4896 Mon Sep 17 00:00:00 2001 From: afarntrog <47332252+afarntrog@users.noreply.github.com> Date: Wed, 28 Jan 2026 09:25:12 -0500 Subject: [PATCH 094/279] feat(agent): update AgentResult __str__ priority order (#1553) --- src/strands/agent/agent_result.py | 21 ++--- tests/strands/agent/test_agent_result.py | 98 ++++++++++++++++++++++-- 2 files changed, 105 insertions(+), 14 deletions(-) diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index 8f9241a67..63b7a0d4a 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -36,17 +36,23 @@ class AgentResult: structured_output: BaseModel | None = None def __str__(self) -> str: - """Get the agent's last message as a string. + """Return a string representation of the agent result. - This method extracts and concatenates all text content from the final message, ignoring any non-text content - like images or structured data. If there's no text content but structured output is present, it serializes - the structured output instead. + Priority order: + 1. Interrupts (if present) → stringified list of interrupt dicts + 2. Structured output (if present) → JSON string + 3. Text content from message → concatenated text blocks Returns: - The agent's last message as a string. + String representation based on the priority order above. 
""" - content_array = self.message.get("content", []) + if self.interrupts: + return str([interrupt.to_dict() for interrupt in self.interrupts]) + + if self.structured_output: + return self.structured_output.model_dump_json() + content_array = self.message.get("content", []) result = "" for item in content_array: if isinstance(item, dict): @@ -59,9 +65,6 @@ def __str__(self) -> str: if isinstance(content, dict) and "text" in content: result += content.get("text", "") + "\n" - if not result and self.structured_output: - result = self.structured_output.model_dump_json() - return result @classmethod diff --git a/tests/strands/agent/test_agent_result.py b/tests/strands/agent/test_agent_result.py index 6e4c2c91a..fa9ec4ad9 100644 --- a/tests/strands/agent/test_agent_result.py +++ b/tests/strands/agent/test_agent_result.py @@ -5,6 +5,7 @@ from pydantic import BaseModel from strands.agent.agent_result import AgentResult +from strands.interrupt import Interrupt from strands.telemetry.metrics import EventLoopMetrics from strands.types.content import Message from strands.types.streaming import StopReason @@ -185,7 +186,7 @@ def test__init__structured_output_defaults_to_none(mock_metrics, simple_message: def test__str__with_structured_output(mock_metrics, simple_message: Message): - """Test that str() is not affected by structured_output.""" + """Test that str() returns structured output JSON when structured_output is present.""" structured_output = StructuredOutputModel(name="test", value=42) result = AgentResult( @@ -196,11 +197,11 @@ def test__str__with_structured_output(mock_metrics, simple_message: Message): structured_output=structured_output, ) - # The string representation should only include the message text, not structured output + # When structured_output is present, it takes priority over message text message_string = str(result) - assert message_string == "Hello world!\n" - assert "test" not in message_string - assert "42" not in message_string + assert 
message_string == structured_output.model_dump_json() + assert "test" in message_string + assert "42" in message_string def test__str__empty_message_with_structured_output(mock_metrics, empty_message: Message): @@ -283,3 +284,90 @@ def test__str__mixed_text_and_citations_content(mock_metrics, mixed_text_and_cit message_string = str(result) assert message_string == "Introduction paragraph\nCited content here.\nConclusion paragraph\n" + + +def test__str__with_interrupts(mock_metrics, simple_message: Message): + """Test that str() returns stringified interrupts when present.""" + interrupts = [ + Interrupt(id="int-1", name="approval", reason="Need user approval"), + Interrupt(id="int-2", name="input", reason="Need more info"), + ] + + result = AgentResult( + stop_reason="end_turn", + message=simple_message, + metrics=mock_metrics, + state={}, + interrupts=interrupts, + ) + + message_string = str(result) + + # Should contain stringified interrupt dicts + assert "int-1" in message_string + assert "approval" in message_string + assert "Need user approval" in message_string + assert "int-2" in message_string + assert "input" in message_string + assert "Need more info" in message_string + + +def test__str__interrupts_priority_over_structured_output(mock_metrics, simple_message: Message): + """Test that interrupts take priority over structured_output in str().""" + interrupts = [Interrupt(id="int-1", name="approval", reason="Needs approval")] + structured_output = StructuredOutputModel(name="test", value=42) + + result = AgentResult( + stop_reason="end_turn", + message=simple_message, + metrics=mock_metrics, + state={}, + interrupts=interrupts, + structured_output=structured_output, + ) + + message_string = str(result) + + # Should return interrupts, not structured output + assert "int-1" in message_string + assert "approval" in message_string + # Should NOT contain structured output + assert "test" not in message_string or "approval" in message_string # "test" might appear 
but not from structured + assert '"value": 42' not in message_string + + +def test__str__interrupts_priority_over_text_content(mock_metrics, simple_message: Message): + """Test that interrupts take priority over message text content in str().""" + interrupts = [Interrupt(id="int-1", name="confirm", reason="Please confirm")] + + result = AgentResult( + stop_reason="end_turn", + message=simple_message, + metrics=mock_metrics, + state={}, + interrupts=interrupts, + ) + + message_string = str(result) + + # Should return interrupts, not message text + assert "int-1" in message_string + assert "confirm" in message_string + assert "Hello world!" not in message_string + + +def test__str__empty_interrupts_returns_agent_message(mock_metrics, simple_message: Message): + """Test that empty interrupts list falls through to other content.""" + result = AgentResult( + stop_reason="end_turn", + message=simple_message, + metrics=mock_metrics, + state={}, + interrupts=[], + ) + + message_string = str(result) + + # Empty list is falsy, should fall through to text content + assert message_string == "Hello world!\n" + From e8fc991ae8a8f1c541d46f26238535183f39fd29 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 28 Jan 2026 09:34:36 -0500 Subject: [PATCH 095/279] callback handler - fix reporting of tool when missing delta (#1573) --- src/strands/handlers/callback_handler.py | 17 ++-- .../strands/handlers/test_callback_handler.py | 96 +++++-------------- 2 files changed, 31 insertions(+), 82 deletions(-) diff --git a/src/strands/handlers/callback_handler.py b/src/strands/handlers/callback_handler.py index d449f76da..45b7efda8 100644 --- a/src/strands/handlers/callback_handler.py +++ b/src/strands/handlers/callback_handler.py @@ -14,7 +14,6 @@ def __init__(self, verbose_tool_use: bool = True) -> None: verbose_tool_use: Print out verbose information about tool calls. 
""" self.tool_count = 0 - self.previous_tool_use = None self._verbose_tool_use = verbose_tool_use def __call__(self, **kwargs: Any) -> None: @@ -25,12 +24,12 @@ def __call__(self, **kwargs: Any) -> None: - reasoningText (Optional[str]): Reasoning text to print if provided. - data (str): Text content to stream. - complete (bool): Whether this is the final chunk of a response. - - current_tool_use (dict): Information about the current tool being used. + - event (dict): ModelStreamChunkEvent. """ reasoningText = kwargs.get("reasoningText", False) data = kwargs.get("data", "") complete = kwargs.get("complete", False) - current_tool_use = kwargs.get("current_tool_use", {}) + tool_use = kwargs.get("event", {}).get("contentBlockStart", {}).get("start", {}).get("toolUse") if reasoningText: print(reasoningText, end="") @@ -38,13 +37,11 @@ def __call__(self, **kwargs: Any) -> None: if data: print(data, end="" if not complete else "\n") - if current_tool_use and current_tool_use.get("name"): - if self.previous_tool_use != current_tool_use: - self.previous_tool_use = current_tool_use - self.tool_count += 1 - if self._verbose_tool_use: - tool_name = current_tool_use.get("name", "Unknown tool") - print(f"\nTool #{self.tool_count}: {tool_name}") + if tool_use: + self.tool_count += 1 + if self._verbose_tool_use: + tool_name = tool_use["name"] + print(f"\nTool #{self.tool_count}: {tool_name}") if complete and data: print("\n") diff --git a/tests/strands/handlers/test_callback_handler.py b/tests/strands/handlers/test_callback_handler.py index 224823ef7..0d72c8563 100644 --- a/tests/strands/handlers/test_callback_handler.py +++ b/tests/strands/handlers/test_callback_handler.py @@ -72,56 +72,21 @@ def test_call_with_data_complete(handler, mock_print): mock_print.assert_any_call("\n") -def test_call_with_current_tool_use_new(handler, mock_print): - """Test calling the handler with a new tool use.""" - current_tool_use = {"name": "test_tool", "input": {"param": "value"}} - - 
handler(current_tool_use=current_tool_use) - - # Should print tool information - mock_print.assert_called_once_with("\nTool #1: test_tool") - - # Should update the handler state - assert handler.tool_count == 1 - assert handler.previous_tool_use == current_tool_use - - -def test_call_with_current_tool_use_same(handler, mock_print): - """Test calling the handler with the same tool use twice.""" - current_tool_use = {"name": "test_tool", "input": {"param": "value"}} - - # First call - handler(current_tool_use=current_tool_use) - mock_print.reset_mock() - - # Second call with same tool use - handler(current_tool_use=current_tool_use) - - # Should not print tool information again - mock_print.assert_not_called() - - # Tool count should not increase - assert handler.tool_count == 1 - - -def test_call_with_current_tool_use_different(handler, mock_print): +def test_call_with_tool_uses(handler, mock_print): """Test calling the handler with different tool uses.""" - first_tool_use = {"name": "first_tool", "input": {"param": "value1"}} - second_tool_use = {"name": "second_tool", "input": {"param": "value2"}} - - # First call - handler(current_tool_use=first_tool_use) - mock_print.reset_mock() + first_event = {"contentBlockStart": {"start": {"toolUse": {"name": "first_tool"}}}} + second_event = {"contentBlockStart": {"start": {"toolUse": {"name": "second_tool"}}}} - # Second call with different tool use - handler(current_tool_use=second_tool_use) + handler(event=first_event) + handler(event=second_event) - # Should print info for the new tool - mock_print.assert_called_once_with("\nTool #2: second_tool") + assert mock_print.call_args_list == [ + unittest.mock.call("\nTool #1: first_tool"), + unittest.mock.call("\nTool #2: second_tool"), + ] # Tool count should increase assert handler.tool_count == 2 - assert handler.previous_tool_use == second_tool_use def test_call_with_data_and_complete_extra_newline(handler, mock_print): @@ -146,42 +111,30 @@ def 
test_call_with_message_no_effect(handler, mock_print): def test_call_with_multiple_parameters(handler, mock_print): """Test calling handler with multiple parameters.""" - current_tool_use = {"name": "test_tool", "input": {"param": "value"}} + event = {"contentBlockStart": {"start": {"toolUse": {"name": "test_tool"}}}} - handler(data="Test output", complete=True, current_tool_use=current_tool_use) + handler(data="Test output", complete=True, event=event) - # Should print data with newline, an extra newline for completion, and tool information - assert mock_print.call_count == 3 - mock_print.assert_any_call("Test output", end="\n") - mock_print.assert_any_call("\n") - mock_print.assert_any_call("\nTool #1: test_tool") - - -def test_unknown_tool_name_handling(handler, mock_print): - """Test handling of a tool use without a name.""" - # The SDK implementation doesn't have a fallback for tool uses without a name field - # It checks for both presence of current_tool_use and current_tool_use.get("name") - current_tool_use = {"input": {"param": "value"}, "name": "Unknown tool"} - - handler(current_tool_use=current_tool_use) - - # Should print the tool information - mock_print.assert_called_once_with("\nTool #1: Unknown tool") + # Should print data with newline, tool information, and an extra newline for completion + assert mock_print.call_args_list == [ + unittest.mock.call("Test output", end="\n"), + unittest.mock.call("\nTool #1: test_tool"), + unittest.mock.call("\n"), + ] def test_tool_use_empty_object(handler, mock_print): - """Test handling of an empty tool use object.""" + """Test handling of an empty tool use object in event.""" # Tool use is an empty dict - current_tool_use = {} + event = {"contentBlockStart": {"start": {"toolUse": {}}}} - handler(current_tool_use=current_tool_use) + handler(event=event) # Should not print anything mock_print.assert_not_called() # Should not update state assert handler.tool_count == 0 - assert handler.previous_tool_use is None def 
test_composite_handler_forwards_to_all_handlers(): @@ -193,7 +146,7 @@ def test_composite_handler_forwards_to_all_handlers(): kwargs = { "data": "Test output", "complete": True, - "current_tool_use": {"name": "test_tool", "input": {"param": "value"}}, + "event": {"contentBlockStart": {"start": {"toolUse": {"name": "test_tool"}}}}, } # Call the composite handler @@ -215,12 +168,11 @@ def test_verbose_tool_use_disabled(mock_print): handler = PrintingCallbackHandler(verbose_tool_use=False) assert handler._verbose_tool_use is False - current_tool_use = {"name": "test_tool", "input": {"param": "value"}} - handler(current_tool_use=current_tool_use) + event = {"contentBlockStart": {"start": {"toolUse": {"name": "test_tool"}}}} + handler(event=event) # Should not print tool information when verbose_tool_use is False mock_print.assert_not_called() - # Should still update tool count and previous_tool_use + # Should still update tool count assert handler.tool_count == 1 - assert handler.previous_tool_use == current_tool_use From f814458a6a9ca57dae4d98bc544b6f266a119fd0 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 28 Jan 2026 10:36:20 -0500 Subject: [PATCH 096/279] feat(hooks): Add invocation state (#1550) --- src/strands/agent/agent.py | 10 +++-- src/strands/event_loop/event_loop.py | 3 ++ src/strands/hooks/events.py | 21 ++++++++++- tests/strands/agent/hooks/test_events.py | 37 +++++++++++++++++++ tests/strands/agent/test_agent_hooks.py | 31 +++++++++------- .../agent/test_conversation_manager.py | 2 +- tests/strands/event_loop/test_event_loop.py | 15 ++++---- .../experimental/hooks/test_hook_aliases.py | 4 +- 8 files changed, 94 insertions(+), 29 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index e2ac3aa71..05c3af191 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -474,7 +474,7 @@ async def structured_output_async(self, output_model: type[T], prompt: AgentInpu category=DeprecationWarning, 
stacklevel=2, ) - await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self)) + await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self, invocation_state={})) with self.tracer.tracer.start_as_current_span( "execute_structured_output", kind=trace_api.SpanKind.CLIENT ) as structured_output_span: @@ -515,7 +515,7 @@ async def structured_output_async(self, output_model: type[T], prompt: AgentInpu return event["output"] finally: - await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self)) + await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self, invocation_state={})) def cleanup(self) -> None: """Clean up resources used by the agent. @@ -657,7 +657,7 @@ async def _run_loop( Events from the event loop cycle. """ before_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async( - BeforeInvocationEvent(agent=self, messages=messages) + BeforeInvocationEvent(agent=self, invocation_state=invocation_state, messages=messages) ) messages = before_invocation_event.messages if before_invocation_event.messages is not None else messages @@ -695,7 +695,9 @@ async def _run_loop( finally: self.conversation_manager.apply_management(self) - await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self, result=agent_result)) + await self.hooks.invoke_callbacks_async( + AfterInvocationEvent(agent=self, invocation_state=invocation_state, result=agent_result) + ) async def _execute_event_loop_cycle( self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 41122efc5..9fe645f80 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -318,6 +318,7 @@ async def _handle_model_execution( await agent.hooks.invoke_callbacks_async( BeforeModelCallEvent( agent=agent, + invocation_state=invocation_state, ) ) @@ -343,6 +344,7 @@ async 
def _handle_model_execution( after_model_call_event = AfterModelCallEvent( agent=agent, + invocation_state=invocation_state, stop_response=AfterModelCallEvent.ModelStopResponse( stop_reason=stop_reason, message=message, @@ -370,6 +372,7 @@ async def _handle_model_execution( # Exception is automatically recorded by use_span with end_on_exit=True after_model_call_event = AfterModelCallEvent( agent=agent, + invocation_state=invocation_state, exception=e, ) await agent.hooks.invoke_callbacks_async(after_model_call_event) diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index ad40dfd7f..8d3e5d280 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -4,7 +4,7 @@ """ import uuid -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any from typing_extensions import override @@ -48,10 +48,14 @@ class BeforeInvocationEvent(HookEvent): - Agent.structured_output Attributes: + invocation_state: State and configuration passed through the agent invocation. + This can include shared context for multi-agent coordination, request tracking, + and dynamic configuration. messages: The input messages for this invocation. Can be modified by hooks to redact or transform content before processing. """ + invocation_state: dict[str, Any] = field(default_factory=dict) messages: Messages | None = None def _can_write(self, name: str) -> bool: @@ -75,11 +79,15 @@ class AfterInvocationEvent(HookEvent): - Agent.structured_output Attributes: + invocation_state: State and configuration passed through the agent invocation. + This can include shared context for multi-agent coordination, request tracking, + and dynamic configuration. result: The result of the agent invocation, if available. This will be None when invoked from structured_output methods, as those return typed output directly rather than AgentResult. 
""" + invocation_state: dict[str, Any] = field(default_factory=dict) result: "AgentResult | None" = None @property @@ -208,9 +216,14 @@ class BeforeModelCallEvent(HookEvent): that will be sent to the model. Note: This event is not fired for invocations to structured_output. + + Attributes: + invocation_state: State and configuration passed through the agent invocation. + This can include shared context for multi-agent coordination, request tracking, + and dynamic configuration. """ - pass + invocation_state: dict[str, Any] = field(default_factory=dict) @dataclass @@ -239,6 +252,9 @@ class AfterModelCallEvent(HookEvent): conversation history Attributes: + invocation_state: State and configuration passed through the agent invocation. + This can include shared context for multi-agent coordination, request tracking, + and dynamic configuration. stop_response: The model response data if invocation was successful, None if failed. exception: Exception if the model invocation failed, None if successful. retry: Whether to retry the model invocation. 
Can be set by hook callbacks @@ -258,6 +274,7 @@ class ModelStopResponse: message: Message stop_reason: StopReason + invocation_state: dict[str, Any] = field(default_factory=dict) stop_response: ModelStopResponse | None = None exception: Exception | None = None retry: bool = False diff --git a/tests/strands/agent/hooks/test_events.py b/tests/strands/agent/hooks/test_events.py index 83cb1af24..762b77452 100644 --- a/tests/strands/agent/hooks/test_events.py +++ b/tests/strands/agent/hooks/test_events.py @@ -5,9 +5,11 @@ from strands.agent.agent_result import AgentResult from strands.hooks import ( AfterInvocationEvent, + AfterModelCallEvent, AfterToolCallEvent, AgentInitializedEvent, BeforeInvocationEvent, + BeforeModelCallEvent, BeforeToolCallEvent, MessageAddedEvent, ) @@ -170,6 +172,41 @@ def test_after_invocation_event_properties_not_writable(agent): with pytest.raises(AttributeError, match="Property agent is not writable"): event.agent = Mock() + with pytest.raises(AttributeError, match="Property invocation_state is not writable"): + event.invocation_state = {} + + +def test_invocation_state_is_available_in_invocation_events(agent): + """Test that invocation_state is accessible in BeforeInvocationEvent and AfterInvocationEvent.""" + invocation_state = {"session_id": "test-123", "request_id": "req-456"} + + before_event = BeforeInvocationEvent(agent=agent, invocation_state=invocation_state) + assert before_event.invocation_state == invocation_state + assert before_event.invocation_state["session_id"] == "test-123" + assert before_event.invocation_state["request_id"] == "req-456" + + after_event = AfterInvocationEvent(agent=agent, invocation_state=invocation_state, result=None) + assert after_event.invocation_state == invocation_state + assert after_event.invocation_state["session_id"] == "test-123" + assert after_event.invocation_state["request_id"] == "req-456" + + +def test_invocation_state_is_available_in_model_call_events(agent): + """Test that 
invocation_state is accessible in BeforeModelCallEvent and AfterModelCallEvent.""" + invocation_state = {"session_id": "test-123", "request_id": "req-456"} + + before_event = BeforeModelCallEvent(agent=agent, invocation_state=invocation_state) + assert before_event.invocation_state == invocation_state + assert before_event.invocation_state["session_id"] == "test-123" + assert before_event.invocation_state["request_id"] == "req-456" + + after_event = AfterModelCallEvent(agent=agent, invocation_state=invocation_state) + assert after_event.invocation_state == invocation_state + assert after_event.invocation_state["session_id"] == "test-123" + assert after_event.invocation_state["request_id"] == "req-456" + + + def test_before_invocation_event_messages_default_none(agent): """Test that BeforeInvocationEvent.messages defaults to None for backward compatibility.""" diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index e8b7e5077..8ff81295a 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -160,14 +160,15 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u assert length == 12 - assert next(events) == BeforeInvocationEvent(agent=agent, messages=agent.messages[0:1]) + assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY, messages=agent.messages[0:1]) assert next(events) == MessageAddedEvent( agent=agent, message=agent.messages[0], ) - assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) assert next(events) == AfterModelCallEvent( agent=agent, + invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( message={ "content": [{"toolUse": tool_use}], @@ -193,9 +194,10 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u result={"content": [{"text": "!loot a dekovni I"}], "status": "success", 
"toolUseId": "123"}, ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2]) - assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) assert next(events) == AfterModelCallEvent( agent=agent, + invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( message=mock_model.agent_responses[1], stop_reason="end_turn", @@ -204,7 +206,7 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3]) - assert next(events) == AfterInvocationEvent(agent=agent, result=result) + assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY, result=result) assert len(agent.messages) == 4 @@ -215,8 +217,9 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m iterator = agent.stream_async("test message") await anext(iterator) - # Verify first event is BeforeInvocationEvent with messages + # Verify first event is BeforeInvocationEvent with invocation_state and messages assert len(hook_provider.events_received) == 1 + assert hook_provider.events_received[0].invocation_state is not None assert hook_provider.events_received[0].messages is not None assert hook_provider.events_received[0].messages[0]["role"] == "user" @@ -230,14 +233,15 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m assert length == 12 - assert next(events) == BeforeInvocationEvent(agent=agent, messages=agent.messages[0:1]) + assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY, messages=agent.messages[0:1]) assert next(events) == MessageAddedEvent( agent=agent, message=agent.messages[0], ) - assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) assert next(events) == AfterModelCallEvent( agent=agent, + 
invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( message={ "content": [{"toolUse": tool_use}], @@ -263,9 +267,10 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"}, ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2]) - assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) assert next(events) == AfterModelCallEvent( agent=agent, + invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( message=mock_model.agent_responses[1], stop_reason="end_turn", @@ -274,7 +279,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[3]) - assert next(events) == AfterInvocationEvent(agent=agent, result=result) + assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY, result=result) assert len(agent.messages) == 4 @@ -289,8 +294,8 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator): assert length == 2 - assert next(events) == BeforeInvocationEvent(agent=agent) - assert next(events) == AfterInvocationEvent(agent=agent) + assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY) + assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY) assert len(agent.messages) == 0 # no new messages added @@ -306,8 +311,8 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a assert length == 2 - assert next(events) == BeforeInvocationEvent(agent=agent) - assert next(events) == AfterInvocationEvent(agent=agent) + assert next(events) == BeforeInvocationEvent(agent=agent, invocation_state=ANY) + assert next(events) == AfterInvocationEvent(agent=agent, invocation_state=ANY) assert 
len(agent.messages) == 0 # no new messages added diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index ae18a9131..46876d8e5 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -362,7 +362,7 @@ def test_per_turn_dynamic_change(): mock_agent = MagicMock() mock_agent.messages = [] - event = BeforeModelCallEvent(agent=mock_agent) + event = BeforeModelCallEvent(agent=mock_agent, invocation_state={}) # Initially disabled with patch.object(manager, "apply_management") as mock_apply: diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index a76a5b6b5..8c6155e20 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -855,27 +855,28 @@ async def test_event_loop_cycle_exception_model_hooks(mock_sleep, agent, model, assert count == 9 # 1st call - throttled - assert next(events) == BeforeModelCallEvent(agent=agent) - expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) + expected_after = AfterModelCallEvent(agent=agent, invocation_state=ANY, stop_response=None, exception=exception) expected_after.retry = True assert next(events) == expected_after # 2nd call - throttled - assert next(events) == BeforeModelCallEvent(agent=agent) - expected_after = AfterModelCallEvent(agent=agent, stop_response=None, exception=exception) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) + expected_after = AfterModelCallEvent(agent=agent, invocation_state=ANY, stop_response=None, exception=exception) expected_after.retry = True assert next(events) == expected_after # 3rd call - throttled - assert next(events) == BeforeModelCallEvent(agent=agent) - expected_after = AfterModelCallEvent(agent=agent, stop_response=None, 
exception=exception) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) + expected_after = AfterModelCallEvent(agent=agent, invocation_state=ANY, stop_response=None, exception=exception) expected_after.retry = True assert next(events) == expected_after # 4th call - successful - assert next(events) == BeforeModelCallEvent(agent=agent) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) assert next(events) == AfterModelCallEvent( agent=agent, + invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( message={"content": [{"text": "test text"}], "role": "assistant"}, stop_reason="end_turn" ), diff --git a/tests/strands/experimental/hooks/test_hook_aliases.py b/tests/strands/experimental/hooks/test_hook_aliases.py index 2da8a6f90..b229c1c2d 100644 --- a/tests/strands/experimental/hooks/test_hook_aliases.py +++ b/tests/strands/experimental/hooks/test_hook_aliases.py @@ -68,7 +68,7 @@ def test_after_tool_call_event_type_equality(): def test_before_model_call_event_type_equality(): """Verify that BeforeModelInvocationEvent alias has the same type identity.""" - before_model_event = BeforeModelCallEvent(agent=Mock()) + before_model_event = BeforeModelCallEvent(agent=Mock(), invocation_state={}) assert isinstance(before_model_event, BeforeModelInvocationEvent) assert isinstance(before_model_event, BeforeModelCallEvent) @@ -76,7 +76,7 @@ def test_before_model_call_event_type_equality(): def test_after_model_call_event_type_equality(): """Verify that AfterModelInvocationEvent alias has the same type identity.""" - after_model_event = AfterModelCallEvent(agent=Mock()) + after_model_event = AfterModelCallEvent(agent=Mock(), invocation_state={}) assert isinstance(after_model_event, AfterModelInvocationEvent) assert isinstance(after_model_event, AfterModelCallEvent) From 4e4534e79bc6735df727d6207fc0c760736a0499 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 28 Jan 2026 14:51:13 -0500 Subject: 
[PATCH 097/279] test(steering): Fix failing integ tests (#1580) --- src/strands/experimental/steering/core/handler.py | 4 +--- .../steering/handlers/llm/llm_handler.py | 8 ++++++-- tests/strands/agent/test_agent_result.py | 1 - .../experimental/steering/core/test_handler.py | 2 +- tests_integ/steering/test_tool_steering.py | 12 ++++++++++-- 5 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/strands/experimental/steering/core/handler.py b/src/strands/experimental/steering/core/handler.py index fd00a27fc..403a73414 100644 --- a/src/strands/experimental/steering/core/handler.py +++ b/src/strands/experimental/steering/core/handler.py @@ -115,9 +115,7 @@ def _handle_tool_steering_action( logger.debug("tool_name=<%s> | tool call proceeding", tool_name) elif isinstance(action, Guide): logger.debug("tool_name=<%s> | tool call guided: %s", tool_name, action.reason) - event.cancel_tool = ( - f"Tool call cancelled given new guidance. {action.reason}. Consider this approach and continue" - ) + event.cancel_tool = f"Tool call cancelled. {action.reason} You MUST follow this guidance immediately." 
elif isinstance(action, Interrupt): logger.debug("tool_name=<%s> | tool call requires human input: %s", tool_name, action.reason) can_proceed: bool = event.interrupt(name=f"steering_input_{tool_name}", reason={"message": action.reason}) diff --git a/src/strands/experimental/steering/handlers/llm/llm_handler.py b/src/strands/experimental/steering/handlers/llm/llm_handler.py index 379dc684a..6d0a31eeb 100644 --- a/src/strands/experimental/steering/handlers/llm/llm_handler.py +++ b/src/strands/experimental/steering/handlers/llm/llm_handler.py @@ -50,9 +50,13 @@ def __init__( system_prompt: System prompt defining steering guidance rules prompt_mapper: Custom prompt mapper for evaluation prompts model: Optional model override for steering evaluation - context_providers: List of context providers for populating steering context + context_providers: List of context providers for populating steering context. + Defaults to [LedgerProvider()] if None. Pass an empty list to disable + context providers. 
""" - providers = context_providers or [LedgerProvider()] + providers: list[SteeringContextProvider] = ( + [LedgerProvider()] if context_providers is None else context_providers + ) super().__init__(context_providers=providers) self.system_prompt = system_prompt self.prompt_mapper = prompt_mapper or DefaultPromptMapper() diff --git a/tests/strands/agent/test_agent_result.py b/tests/strands/agent/test_agent_result.py index fa9ec4ad9..a4478c3ca 100644 --- a/tests/strands/agent/test_agent_result.py +++ b/tests/strands/agent/test_agent_result.py @@ -370,4 +370,3 @@ def test__str__empty_interrupts_returns_agent_message(mock_metrics, simple_messa # Empty list is falsy, should fall through to text content assert message_string == "Hello world!\n" - diff --git a/tests/strands/experimental/steering/core/test_handler.py b/tests/strands/experimental/steering/core/test_handler.py index cbe2b3783..04d3a56c1 100644 --- a/tests/strands/experimental/steering/core/test_handler.py +++ b/tests/strands/experimental/steering/core/test_handler.py @@ -95,7 +95,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): await handler._provide_tool_steering_guidance(event) # Should set cancel_tool with guidance message - expected_message = "Tool call cancelled given new guidance. Test guidance. Consider this approach and continue" + expected_message = "Tool call cancelled. Test guidance You MUST follow this guidance immediately." 
assert event.cancel_tool == expected_message diff --git a/tests_integ/steering/test_tool_steering.py b/tests_integ/steering/test_tool_steering.py index 75073c648..5036c759c 100644 --- a/tests_integ/steering/test_tool_steering.py +++ b/tests_integ/steering/test_tool_steering.py @@ -75,8 +75,16 @@ def test_agent_with_tool_steering_e2e(): """End-to-end test of agent with steering handler guiding tool choice.""" handler = LLMSteeringHandler( system_prompt=( - "When agents try to use send_email, guide them to use send_notification instead for better delivery." - ) + "CRITICAL INSTRUCTION - READ CAREFULLY:\n\n" + "You are a steering agent. Your ONLY job is to decide based on the tool name.\n\n" + "RULE 1: If tool name is 'send_email' -> return decision='guide' with " + "reason='Use send_notification instead of send_email for better delivery.'\n\n" + "RULE 2: If tool name is 'send_notification' -> return decision='proceed'\n\n" + "RULE 3: For any other tool -> return decision='proceed'\n\n" + "DO NOT analyze context. DO NOT consider arguments. ONLY look at the tool name.\n" + "The tool name in this request is the ONLY thing that matters." 
+ ), + context_providers=[], # Disable ledger to avoid confusing context ) agent = Agent(tools=[send_email, send_notification], hooks=[handler]) From 53147210abcc278d7cf7da24f8c38985667c89e4 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Thu, 29 Jan 2026 12:16:38 -0500 Subject: [PATCH 098/279] Increase pytest timeout to 45 seconds (#1586) --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index a16132881..7f816880d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ dev = [ "pytest>=9.0.0,<10.0.0", "pytest-cov>=7.0.0,<8.0.0", "pytest-asyncio>=1.0.0,<1.4.0", + "pytest-timeout>=2.0.0,<3.0.0", "pytest-xdist>=3.0.0,<4.0.0", "ruff>=0.13.0,<0.15.0", "tenacity>=9.0.0,<10.0.0", @@ -146,6 +147,7 @@ dependencies = [ "pytest>=9.0.0,<10.0.0", "pytest-cov>=7.0.0,<8.0.0", "pytest-asyncio>=1.0.0,<1.4.0", + "pytest-timeout>=2.0.0,<3.0.0", "pytest-xdist>=3.0.0,<4.0.0", "moto>=5.1.0,<6.0.0", ] @@ -239,6 +241,7 @@ convention = "google" testpaths = ["tests"] asyncio_default_fixture_loop_scope = "function" addopts = "--ignore=tests/strands/experimental/bidi --ignore=tests_integ/bidi" +timeout = 45 [tool.coverage.run] From c48045e89a288da695adc40698b4a34afb36da89 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Thu, 29 Jan 2026 15:07:17 -0500 Subject: [PATCH 099/279] Publish integ tests results to cloudwatch (#1587) --- .github/scripts/upload-integ-test-metrics.py | 147 +++++++++++++++++++ .github/workflows/integration-test.yml | 7 + pyproject.toml | 5 +- 3 files changed, 157 insertions(+), 2 deletions(-) create mode 100644 .github/scripts/upload-integ-test-metrics.py diff --git a/.github/scripts/upload-integ-test-metrics.py b/.github/scripts/upload-integ-test-metrics.py new file mode 100644 index 000000000..28595d647 --- /dev/null +++ b/.github/scripts/upload-integ-test-metrics.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 +import sys +import xml.etree.ElementTree as ET +from datetime import datetime +from dataclasses 
import dataclass +from typing import Any, Literal, TypedDict +import os +import boto3 + +STRANDS_METRIC_NAMESPACE = 'Strands/Tests' + + + +class Dimension(TypedDict): + Name: str + Value: str + + +class MetricDatum(TypedDict): + MetricName: str + Dimensions: list[Dimension] + Value: float + Unit: str + Timestamp: datetime + + +@dataclass +class TestResult: + name: str + classname: str + duration: float + outcome: Literal['failed', 'skipped', 'passed'] + + +def parse_junit_xml(xml_file_path: str) -> list[TestResult]: + try: + tree = ET.parse(xml_file_path) + except FileNotFoundError: + print(f"Warning: XML file not found: {xml_file_path}") + return [] + except ET.ParseError as e: + print(f"Warning: Failed to parse XML: {e}") + return [] + + results = [] + root = tree.getroot() + + for testcase in root.iter('testcase'): + name = testcase.get('name') + classname = testcase.get('classname') + duration = float(testcase.get('time', 0.0)) + + if not name or not classname: + continue + + if testcase.find('failure') is not None or testcase.find('error') is not None: + outcome = 'failed' + elif testcase.find('skipped') is not None: + outcome = 'skipped' + else: + outcome = 'passed' + + results.append(TestResult(name, classname, duration, outcome)) + + return results + + +def build_metric_data(test_results: list[TestResult], repository: str) -> list[MetricDatum]: + metrics: list[MetricDatum] = [] + timestamp = datetime.utcnow() + + for test in test_results: + test_name = f"{test.classname}.{test.name}" + dimensions: list[Dimension] = [ + Dimension(Name='TestName', Value=test_name), + Dimension(Name='Repository', Value=repository) + ] + + metrics.append(MetricDatum( + MetricName='TestPassed', + Dimensions=dimensions, + Value=1.0 if test.outcome == 'passed' else 0.0, + Unit='Count', + Timestamp=timestamp + )) + + metrics.append(MetricDatum( + MetricName='TestFailed', + Dimensions=dimensions, + Value=1.0 if test.outcome == 'failed' else 0.0, + Unit='Count', + Timestamp=timestamp 
+ )) + + metrics.append(MetricDatum( + MetricName='TestSkipped', + Dimensions=dimensions, + Value=1.0 if test.outcome == 'skipped' else 0.0, + Unit='Count', + Timestamp=timestamp + )) + + metrics.append(MetricDatum( + MetricName='TestDuration', + Dimensions=dimensions, + Value=test.duration, + Unit='Seconds', + Timestamp=timestamp + )) + + return metrics + + +def publish_metrics(metric_data: list[dict[str, Any]], region: str): + cloudwatch = boto3.client('cloudwatch', region_name=region) + + batch_size = 1000 + for i in range(0, len(metric_data), batch_size): + batch = metric_data[i:i + batch_size] + try: + cloudwatch.put_metric_data(Namespace=STRANDS_METRIC_NAMESPACE, MetricData=batch) + print(f"Published {len(batch)} metrics to CloudWatch") + except Exception as e: + print(f"Warning: Failed to publish metrics batch: {e}") + + +def main(): + if len(sys.argv) != 3: + print("Usage: python upload-integ-test-metrics.py ") + sys.exit(0) + + xml_file = sys.argv[1] + repository = sys.argv[2] + region = os.environ.get('AWS_REGION', 'us-east-1') + + test_results = parse_junit_xml(xml_file) + if not test_results: + print("No test results found") + sys.exit(1) + + print(f"Found {len(test_results)} test results") + metric_data = build_metric_data(test_results, repository) + publish_metrics(metric_data, region) + + +if __name__ == '__main__': + main() diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 65c785f30..bbcdfde25 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -37,6 +37,7 @@ jobs: role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }} aws-region: us-east-1 mask-aws-account-id: true + - name: Checkout head commit uses: actions/checkout@v6 with: @@ -57,3 +58,9 @@ jobs: id: tests run: | hatch test tests_integ + + - name: Publish test metrics to CloudWatch + if: always() + run: | + pip install --no-cache-dir boto3 + python .github/scripts/upload-integ-test-metrics.py 
./build/test-results.xml ${{ github.event.repository.name }} diff --git a/pyproject.toml b/pyproject.toml index 7f816880d..ba635cc48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,6 +149,7 @@ dependencies = [ "pytest-asyncio>=1.0.0,<1.4.0", "pytest-timeout>=2.0.0,<3.0.0", "pytest-xdist>=3.0.0,<4.0.0", + "pytest-timeout>=2.0.0,<3.0.0", "moto>=5.1.0,<6.0.0", ] @@ -240,7 +241,7 @@ convention = "google" [tool.pytest.ini_options] testpaths = ["tests"] asyncio_default_fixture_loop_scope = "function" -addopts = "--ignore=tests/strands/experimental/bidi --ignore=tests_integ/bidi" +addopts = "--ignore=tests/strands/experimental/bidi --ignore=tests_integ/bidi --junit-xml=build/test-results.xml" timeout = 45 @@ -298,7 +299,7 @@ prepare = [ "hatch run bidi-test:test-cov", ] -[tools.hatch.envs.bidi-lint] +[tool.hatch.envs.bidi-lint] template = "bidi" [tool.hatch.envs.bidi-lint.scripts] From b091f67527ccae60c114fa7bb45bf026fed0e082 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 29 Jan 2026 16:30:09 -0500 Subject: [PATCH 100/279] feat(a2a): add A2AAgent class (#1441) Co-authored-by: Arron Bailiss --- AGENTS.md | 7 +- src/strands/agent/__init__.py | 11 + src/strands/agent/a2a_agent.py | 262 +++++++++++ src/strands/multiagent/a2a/_converters.py | 130 ++++++ src/strands/types/a2a.py | 38 ++ tests/strands/agent/hooks/test_events.py | 2 - tests/strands/agent/test_a2a_agent.py | 414 ++++++++++++++++++ .../strands/multiagent/a2a/test_converters.py | 205 +++++++++ tests_integ/a2a/__init__.py | 0 tests_integ/a2a/a2a_server.py | 15 + tests_integ/a2a/test_multiagent_a2a.py | 72 +++ 11 files changed, 1153 insertions(+), 3 deletions(-) create mode 100644 src/strands/agent/a2a_agent.py create mode 100644 src/strands/multiagent/a2a/_converters.py create mode 100644 src/strands/types/a2a.py create mode 100644 tests/strands/agent/test_a2a_agent.py create mode 100644 tests/strands/multiagent/a2a/test_converters.py create mode 100644 tests_integ/a2a/__init__.py create mode 
100644 tests_integ/a2a/a2a_server.py create mode 100644 tests_integ/a2a/test_multiagent_a2a.py diff --git a/AGENTS.md b/AGENTS.md index 71e83835d..a57286941 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -25,6 +25,8 @@ strands-agents/ │ ├── agent/ # Core agent implementation │ │ ├── agent.py # Main Agent class │ │ ├── agent_result.py # Agent execution results +│ │ ├── base.py # AgentBase protocol (agent interface) +│ │ ├── a2a_agent.py # A2AAgent client for remote A2A agents │ │ ├── state.py # Agent state management │ │ └── conversation_manager/ # Message history strategies │ │ ├── conversation_manager.py # Base conversation manager @@ -82,7 +84,8 @@ strands-agents/ │ │ ├── swarm.py # Swarm pattern │ │ ├── a2a/ # Agent-to-agent protocol │ │ │ ├── executor.py # A2A executor -│ │ │ └── server.py # A2A server +│ │ │ ├── server.py # A2A server +│ │ │ └── converters.py # Strands/A2A type converters │ │ └── nodes/ # Graph node implementations │ │ │ ├── types/ # Type definitions @@ -102,6 +105,7 @@ strands-agents/ │ │ ├── json_dict.py # JSON dict utilities │ │ ├── collections.py # Collection types │ │ ├── _events.py # Internal event types +│ │ ├── a2a.py # A2A protocol types │ │ └── models/ # Model-specific types │ │ │ ├── session/ # Session management @@ -188,6 +192,7 @@ strands-agents/ │ ├── interrupts/ # Interrupt tests │ ├── steering/ # Steering tests │ ├── bidi/ # Bidirectional streaming tests +│ ├── a2a/ # A2A agent integration tests │ ├── test_multiagent_graph.py │ ├── test_multiagent_swarm.py │ ├── test_stream_agent.py diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index 2e40866a9..c901e800f 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -7,6 +7,8 @@ - Retry Strategies: Configurable retry behavior for model calls """ +from typing import Any + from ..event_loop._retry import ModelRetryStrategy from .agent import Agent from .agent_result import AgentResult @@ -28,3 +30,12 @@ "SummarizingConversationManager", 
"ModelRetryStrategy", ] + + +def __getattr__(name: str) -> Any: + """Lazy load A2AAgent to defer import of optional a2a dependency.""" + if name == "A2AAgent": + from .a2a_agent import A2AAgent + + return A2AAgent + raise AttributeError(f"cannot import name '{name}' from '{__name__}' ({__file__})") diff --git a/src/strands/agent/a2a_agent.py b/src/strands/agent/a2a_agent.py new file mode 100644 index 000000000..e18da2f4a --- /dev/null +++ b/src/strands/agent/a2a_agent.py @@ -0,0 +1,262 @@ +"""A2A Agent client for Strands Agents. + +This module provides the A2AAgent class, which acts as a client wrapper for remote A2A agents, +allowing them to be used standalone or as part of multi-agent patterns. + +A2AAgent can be used to get the Agent Card and interact with the agent. +""" + +import logging +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + +import httpx +from a2a.client import A2ACardResolver, ClientConfig, ClientFactory +from a2a.types import AgentCard, Message, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent + +from .._async import run_async +from ..multiagent.a2a._converters import convert_input_to_message, convert_response_to_agent_result +from ..types._events import AgentResultEvent +from ..types.a2a import A2AResponse, A2AStreamEvent +from ..types.agent import AgentInput +from .agent_result import AgentResult +from .base import AgentBase + +logger = logging.getLogger(__name__) + +_DEFAULT_TIMEOUT = 300 + + +class A2AAgent(AgentBase): + """Client wrapper for remote A2A agents.""" + + def __init__( + self, + endpoint: str, + *, + name: str | None = None, + description: str | None = None, + timeout: int = _DEFAULT_TIMEOUT, + a2a_client_factory: ClientFactory | None = None, + ): + """Initialize A2A agent. + + Args: + endpoint: The base URL of the remote A2A agent. + name: Agent name. If not provided, will be populated from agent card. + description: Agent description. 
If not provided, will be populated from agent card. + timeout: Timeout for HTTP operations in seconds (defaults to 300). + a2a_client_factory: Optional pre-configured A2A ClientFactory. If provided, + it will be used to create the A2A client after discovering the agent card. + Note: When providing a custom factory, you are responsible for managing + the lifecycle of any httpx client it uses. + """ + self.endpoint = endpoint + self.name = name + self.description = description + self.timeout = timeout + self._agent_card: AgentCard | None = None + self._a2a_client_factory: ClientFactory | None = a2a_client_factory + + def __call__( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AgentResult: + """Synchronously invoke the remote A2A agent. + + Args: + prompt: Input to the agent (string, message list, or content blocks). + **kwargs: Additional arguments (ignored). + + Returns: + AgentResult containing the agent's response. + + Raises: + ValueError: If prompt is None. + RuntimeError: If no response received from agent. + """ + return run_async(lambda: self.invoke_async(prompt, **kwargs)) + + async def invoke_async( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AgentResult: + """Asynchronously invoke the remote A2A agent. + + Args: + prompt: Input to the agent (string, message list, or content blocks). + **kwargs: Additional arguments (ignored). + + Returns: + AgentResult containing the agent's response. + + Raises: + ValueError: If prompt is None. + RuntimeError: If no response received from agent. + """ + result: AgentResult | None = None + async for event in self.stream_async(prompt, **kwargs): + if "result" in event: + result = event["result"] + + if result is None: + raise RuntimeError("No response received from A2A agent") + + return result + + async def stream_async( + self, + prompt: AgentInput = None, + **kwargs: Any, + ) -> AsyncIterator[Any]: + """Stream remote agent execution asynchronously. 
+ + This method provides an asynchronous interface for streaming A2A protocol events. + Unlike Agent.stream_async() which yields text deltas and tool events, this method + yields raw A2A protocol events wrapped in A2AStreamEvent dictionaries. + + Args: + prompt: Input to the agent (string, message list, or content blocks). + **kwargs: Additional arguments (ignored). + + Yields: + An async iterator that yields events. Each event is a dictionary: + - A2AStreamEvent: {"type": "a2a_stream", "event": <A2A object>} + where the A2A object can be a Message, or a tuple of + (Task, TaskStatusUpdateEvent) or (Task, TaskArtifactUpdateEvent). + - AgentResultEvent: {"result": AgentResult} - always emitted last. + + Raises: + ValueError: If prompt is None. + + Example: + ```python + async for event in a2a_agent.stream_async("Hello"): + if event.get("type") == "a2a_stream": + print(f"A2A event: {event['event']}") + elif "result" in event: + print(f"Final result: {event['result'].message}") + ``` + """ + last_event = None + last_complete_event = None + + async for event in self._send_message(prompt): + last_event = event + if self._is_complete_event(event): + last_complete_event = event + yield A2AStreamEvent(event) + + # Use the last complete event if available, otherwise fall back to last event + final_event = last_complete_event or last_event + + if final_event is not None: + result = convert_response_to_agent_result(final_event) + yield AgentResultEvent(result) + + async def get_agent_card(self) -> AgentCard: + """Fetch and return the remote agent's card. + + This method eagerly fetches the agent card from the remote endpoint, + populating name and description if not already set. The card is cached + after the first fetch. + + Returns: + The remote agent's AgentCard containing name, description, capabilities, skills, etc. 
+ """ + if self._agent_card is not None: + return self._agent_card + + async with httpx.AsyncClient(timeout=self.timeout) as client: + resolver = A2ACardResolver(httpx_client=client, base_url=self.endpoint) + self._agent_card = await resolver.get_agent_card() + + # Populate name from card if not set + if self.name is None and self._agent_card.name: + self.name = self._agent_card.name + + # Populate description from card if not set + if self.description is None and self._agent_card.description: + self.description = self._agent_card.description + + logger.debug("agent=<%s>, endpoint=<%s> | discovered agent card", self.name, self.endpoint) + return self._agent_card + + @asynccontextmanager + async def _get_a2a_client(self) -> AsyncIterator[Any]: + """Get A2A client for sending messages. + + If a custom factory was provided, uses that (caller manages httpx lifecycle). + Otherwise creates a per-call httpx client with proper cleanup. + + Yields: + Configured A2A client instance. + """ + agent_card = await self.get_agent_card() + + if self._a2a_client_factory is not None: + yield self._a2a_client_factory.create(agent_card) + return + + async with httpx.AsyncClient(timeout=self.timeout) as httpx_client: + config = ClientConfig(httpx_client=httpx_client, streaming=True) + yield ClientFactory(config).create(agent_card) + + async def _send_message(self, prompt: AgentInput) -> AsyncIterator[A2AResponse]: + """Send message to A2A agent. + + Args: + prompt: Input to send to the agent. + + Yields: + A2A response events. + + Raises: + ValueError: If prompt is None. 
+ """ + if prompt is None: + raise ValueError("prompt is required for A2AAgent") + + message = convert_input_to_message(prompt) + logger.debug("agent=<%s>, endpoint=<%s> | sending message", self.name, self.endpoint) + + async with self._get_a2a_client() as client: + async for event in client.send_message(message): + yield event + + def _is_complete_event(self, event: A2AResponse) -> bool: + """Check if an A2A event represents a complete response. + + Args: + event: A2A event. + + Returns: + True if the event represents a complete response. + """ + # Direct Message is always complete + if isinstance(event, Message): + return True + + # Handle tuple responses (Task, UpdateEvent | None) + if isinstance(event, tuple) and len(event) == 2: + task, update_event = event + + # Initial task response (no update event) + if update_event is None: + return True + + # Artifact update with last_chunk flag + if isinstance(update_event, TaskArtifactUpdateEvent): + if hasattr(update_event, "last_chunk") and update_event.last_chunk is not None: + return update_event.last_chunk + return False + + # Status update with completed state + if isinstance(update_event, TaskStatusUpdateEvent): + if update_event.status and hasattr(update_event.status, "state"): + return update_event.status.state == TaskState.completed + + return False diff --git a/src/strands/multiagent/a2a/_converters.py b/src/strands/multiagent/a2a/_converters.py new file mode 100644 index 000000000..b818c824b --- /dev/null +++ b/src/strands/multiagent/a2a/_converters.py @@ -0,0 +1,130 @@ +"""Conversion functions between Strands and A2A types.""" + +from typing import cast +from uuid import uuid4 + +from a2a.types import Message as A2AMessage +from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart + +from ...agent.agent_result import AgentResult +from ...telemetry.metrics import EventLoopMetrics +from ...types.a2a import A2AResponse +from ...types.agent import AgentInput +from 
...types.content import ContentBlock, Message + + +def convert_input_to_message(prompt: AgentInput) -> A2AMessage: + """Convert AgentInput to A2A Message. + + Args: + prompt: Input in various formats (string, message list, or content blocks). + + Returns: + A2AMessage ready to send to the remote agent. + + Raises: + ValueError: If prompt format is unsupported. + """ + message_id = uuid4().hex + + if isinstance(prompt, str): + return A2AMessage( + kind="message", + role=Role.user, + parts=[Part(TextPart(kind="text", text=prompt))], + message_id=message_id, + ) + + if isinstance(prompt, list) and prompt and (isinstance(prompt[0], dict)): + # Check for interrupt responses - not supported in A2A + if "interruptResponse" in prompt[0]: + raise ValueError("InterruptResponseContent is not supported for A2AAgent") + + if "role" in prompt[0]: + for msg in reversed(prompt): + if msg.get("role") == "user": + content = cast(list[ContentBlock], msg.get("content", [])) + parts = convert_content_blocks_to_parts(content) + return A2AMessage( + kind="message", + role=Role.user, + parts=parts, + message_id=message_id, + ) + else: + parts = convert_content_blocks_to_parts(cast(list[ContentBlock], prompt)) + return A2AMessage( + kind="message", + role=Role.user, + parts=parts, + message_id=message_id, + ) + + raise ValueError(f"Unsupported input type: {type(prompt)}") + + +def convert_content_blocks_to_parts(content_blocks: list[ContentBlock]) -> list[Part]: + """Convert Strands ContentBlocks to A2A Parts. + + Args: + content_blocks: List of Strands content blocks. + + Returns: + List of A2A Part objects. + """ + parts = [] + for block in content_blocks: + if "text" in block: + parts.append(Part(TextPart(kind="text", text=block["text"]))) + return parts + + +def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: + """Convert A2A response to AgentResult. + + Args: + response: A2A response (either A2AMessage or tuple of task and update event). 
+ + Returns: + AgentResult with extracted content and metadata. + """ + content: list[ContentBlock] = [] + + if isinstance(response, tuple) and len(response) == 2: + task, update_event = response + + # Handle artifact updates + if isinstance(update_event, TaskArtifactUpdateEvent): + if update_event.artifact and hasattr(update_event.artifact, "parts"): + for part in update_event.artifact.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + # Handle status updates with messages + elif isinstance(update_event, TaskStatusUpdateEvent): + if update_event.status and hasattr(update_event.status, "message") and update_event.status.message: + for part in update_event.status.message.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + # Handle initial task or task without update event + elif update_event is None and task and hasattr(task, "artifacts") and task.artifacts is not None: + for artifact in task.artifacts: + if hasattr(artifact, "parts"): + for part in artifact.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + elif isinstance(response, A2AMessage): + for part in response.parts: + if hasattr(part, "root") and hasattr(part.root, "text"): + content.append({"text": part.root.text}) + + message: Message = { + "role": "assistant", + "content": content, + } + + return AgentResult( + stop_reason="end_turn", + message=message, + metrics=EventLoopMetrics(), + state={}, + ) diff --git a/src/strands/types/a2a.py b/src/strands/types/a2a.py new file mode 100644 index 000000000..2ca444cb0 --- /dev/null +++ b/src/strands/types/a2a.py @@ -0,0 +1,38 @@ +"""Additional A2A types.""" + +from typing import Any, TypeAlias + +from a2a.types import Message, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent + +from ._events import TypedEvent + +A2AResponse: TypeAlias = tuple[Task, TaskStatusUpdateEvent | 
TaskArtifactUpdateEvent | None] | Message | Any + + +class A2AStreamEvent(TypedEvent): + """Event emitted for every update received from the remote A2A server. + + This event wraps all A2A response types during streaming, including: + - Partial task updates (TaskArtifactUpdateEvent) + - Status updates (TaskStatusUpdateEvent) + - Complete messages (Message) + - Final task completions + + The event is emitted for EVERY update from the server, regardless of whether + it represents a complete or partial response. When streaming completes, an + AgentResultEvent containing the final AgentResult is also emitted after all + A2AStreamEvents. + """ + + def __init__(self, a2a_event: A2AResponse) -> None: + """Initialize with A2A event. + + Args: + a2a_event: The original A2A event (Task tuple or Message) + """ + super().__init__( + { + "type": "a2a_stream", + "event": a2a_event, # Nest A2A event to avoid field conflicts + } + ) diff --git a/tests/strands/agent/hooks/test_events.py b/tests/strands/agent/hooks/test_events.py index 762b77452..de551d137 100644 --- a/tests/strands/agent/hooks/test_events.py +++ b/tests/strands/agent/hooks/test_events.py @@ -206,8 +206,6 @@ def test_invocation_state_is_available_in_model_call_events(agent): assert after_event.invocation_state["request_id"] == "req-456" - - def test_before_invocation_event_messages_default_none(agent): """Test that BeforeInvocationEvent.messages defaults to None for backward compatibility.""" event = BeforeInvocationEvent(agent=agent) diff --git a/tests/strands/agent/test_a2a_agent.py b/tests/strands/agent/test_a2a_agent.py new file mode 100644 index 000000000..26a34476d --- /dev/null +++ b/tests/strands/agent/test_a2a_agent.py @@ -0,0 +1,414 @@ +"""Tests for A2AAgent class.""" + +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest +from a2a.types import AgentCard, Message, Part, Role, TextPart + +from strands.agent.a2a_agent 
import A2AAgent +from strands.agent.agent_result import AgentResult + + +@pytest.fixture +def mock_agent_card(): + """Mock AgentCard for testing.""" + return AgentCard( + name="test-agent", + description="Test agent", + url="http://localhost:8000", + version="1.0.0", + capabilities={}, + default_input_modes=["text/plain"], + default_output_modes=["text/plain"], + skills=[], + ) + + +@pytest.fixture +def a2a_agent(): + """Create A2AAgent instance for testing.""" + return A2AAgent(endpoint="http://localhost:8000") + + +@pytest.fixture +def mock_httpx_client(): + """Create a mock httpx.AsyncClient that works as async context manager.""" + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + return mock_client + + +@asynccontextmanager +async def mock_a2a_client_context(send_message_func): + """Helper to create mock A2A client setup for _send_message tests.""" + mock_client = MagicMock() + mock_client.send_message = send_message_func + with patch("strands.agent.a2a_agent.httpx.AsyncClient") as mock_httpx_class: + mock_httpx = AsyncMock() + mock_httpx.__aenter__.return_value = mock_httpx + mock_httpx.__aexit__.return_value = None + mock_httpx_class.return_value = mock_httpx + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory = MagicMock() + mock_factory.create.return_value = mock_client + mock_factory_class.return_value = mock_factory + yield mock_httpx_class, mock_factory_class + + +def test_init_with_defaults(): + """Test initialization with default parameters.""" + agent = A2AAgent(endpoint="http://localhost:8000") + assert agent.endpoint == "http://localhost:8000" + assert agent.timeout == 300 + assert agent._agent_card is None + assert agent.name is None + assert agent.description is None + + +def test_init_with_name_and_description(): + """Test initialization with custom name and description.""" + agent = A2AAgent(endpoint="http://localhost:8000", 
name="my-agent", description="My custom agent") + assert agent.name == "my-agent" + assert agent.description == "My custom agent" + + +def test_init_with_custom_timeout(): + """Test initialization with custom timeout.""" + agent = A2AAgent(endpoint="http://localhost:8000", timeout=600) + assert agent.timeout == 600 + + +def test_init_with_external_a2a_client_factory(): + """Test initialization with external A2A client factory.""" + external_factory = MagicMock() + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=external_factory) + assert agent._a2a_client_factory is external_factory + + +@pytest.mark.asyncio +async def test_get_agent_card(a2a_agent, mock_agent_card, mock_httpx_client): + """Test agent card discovery.""" + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client): + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_agent_card) + mock_resolver_class.return_value = mock_resolver + + card = await a2a_agent.get_agent_card() + + assert card == mock_agent_card + assert a2a_agent._agent_card == mock_agent_card + + +@pytest.mark.asyncio +async def test_get_agent_card_cached(a2a_agent, mock_agent_card): + """Test that agent card is cached after first discovery.""" + a2a_agent._agent_card = mock_agent_card + + card = await a2a_agent.get_agent_card() + + assert card == mock_agent_card + + +@pytest.mark.asyncio +async def test_get_agent_card_populates_name_and_description(mock_agent_card, mock_httpx_client): + """Test that agent card populates name and description if not set.""" + agent = A2AAgent(endpoint="http://localhost:8000") + + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client): + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = 
AsyncMock(return_value=mock_agent_card) + mock_resolver_class.return_value = mock_resolver + + await agent.get_agent_card() + + assert agent.name == mock_agent_card.name + assert agent.description == mock_agent_card.description + + +@pytest.mark.asyncio +async def test_get_agent_card_preserves_custom_name_and_description(mock_agent_card, mock_httpx_client): + """Test that custom name and description are not overridden by agent card.""" + agent = A2AAgent(endpoint="http://localhost:8000", name="custom-name", description="Custom description") + + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client): + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_agent_card) + mock_resolver_class.return_value = mock_resolver + + await agent.get_agent_card() + + assert agent.name == "custom-name" + assert agent.description == "Custom description" + + +@pytest.mark.asyncio +async def test_invoke_async_success(a2a_agent, mock_agent_card): + """Test successful async invocation.""" + mock_response = Message( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + async def mock_send_message(*args, **kwargs): + yield mock_response + + with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message): + result = await a2a_agent.invoke_async("Hello") + + assert isinstance(result, AgentResult) + assert result.message["content"][0]["text"] == "Response" + + +@pytest.mark.asyncio +async def test_invoke_async_no_prompt(a2a_agent): + """Test that invoke_async raises ValueError when prompt is None.""" + with pytest.raises(ValueError, match="prompt is required"): + await a2a_agent.invoke_async(None) + + +@pytest.mark.asyncio +async def test_invoke_async_no_response(a2a_agent, mock_agent_card): + """Test that invoke_async 
raises RuntimeError when no response received.""" + + async def mock_send_message(*args, **kwargs): + return + yield # Make it an async generator + + with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message): + with pytest.raises(RuntimeError, match="No response received"): + await a2a_agent.invoke_async("Hello") + + +def test_call_sync(a2a_agent): + """Test synchronous call method.""" + mock_result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=MagicMock(), + state={}, + ) + + with patch("strands.agent.a2a_agent.run_async") as mock_run_async: + mock_run_async.return_value = mock_result + + result = a2a_agent("Hello") + + assert result == mock_result + mock_run_async.assert_called_once() + + +@pytest.mark.asyncio +async def test_stream_async_success(a2a_agent, mock_agent_card): + """Test successful async streaming.""" + mock_response = Message( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + async def mock_send_message(*args, **kwargs): + yield mock_response + + with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message): + events = [] + async for event in a2a_agent.stream_async("Hello"): + events.append(event) + + assert len(events) == 2 + # First event is A2A stream event + assert events[0]["type"] == "a2a_stream" + assert events[0]["event"] == mock_response + # Final event is AgentResult + assert "result" in events[1] + assert isinstance(events[1]["result"], AgentResult) + assert events[1]["result"].message["content"][0]["text"] == "Response" + + +@pytest.mark.asyncio +async def test_stream_async_no_prompt(a2a_agent): + """Test that stream_async raises ValueError when prompt is None.""" + with pytest.raises(ValueError, match="prompt is required"): + async for _ in 
a2a_agent.stream_async(None): + pass + + +@pytest.mark.asyncio +async def test_send_message_uses_provided_factory(mock_agent_card): + """Test _send_message uses provided factory instead of creating per-call client.""" + external_factory = MagicMock() + mock_a2a_client = MagicMock() + + async def mock_send_message(*args, **kwargs): + yield MagicMock() + + mock_a2a_client.send_message = mock_send_message + external_factory.create.return_value = mock_a2a_client + + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=external_factory) + + with patch.object(agent, "get_agent_card", return_value=mock_agent_card): + # Consume the async iterator + async for _ in agent._send_message("Hello"): + pass + + external_factory.create.assert_called_once_with(mock_agent_card) + + +@pytest.mark.asyncio +async def test_send_message_creates_per_call_client(a2a_agent, mock_agent_card): + """Test _send_message creates a fresh httpx client for each call when no factory provided.""" + mock_response = Message( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + async def mock_send_message(*args, **kwargs): + yield mock_response + + with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message) as (mock_httpx_class, _): + # Consume the async iterator + async for _ in a2a_agent._send_message("Hello"): + pass + + # Verify httpx client was created with timeout + mock_httpx_class.assert_called_once_with(timeout=300) + + +def test_is_complete_event_message(a2a_agent): + """Test _is_complete_event returns True for Message.""" + mock_message = MagicMock(spec=Message) + + assert a2a_agent._is_complete_event(mock_message) is True + + +def test_is_complete_event_tuple_with_none_update(a2a_agent): + """Test _is_complete_event returns True for tuple with None update event.""" + mock_task = MagicMock() + + assert a2a_agent._is_complete_event((mock_task, None)) 
is True + + +def test_is_complete_event_artifact_last_chunk(a2a_agent): + """Test _is_complete_event handles TaskArtifactUpdateEvent last_chunk flag.""" + from a2a.types import TaskArtifactUpdateEvent + + mock_task = MagicMock() + + # last_chunk=True -> complete + event_complete = MagicMock(spec=TaskArtifactUpdateEvent) + event_complete.last_chunk = True + assert a2a_agent._is_complete_event((mock_task, event_complete)) is True + + # last_chunk=False -> not complete + event_incomplete = MagicMock(spec=TaskArtifactUpdateEvent) + event_incomplete.last_chunk = False + assert a2a_agent._is_complete_event((mock_task, event_incomplete)) is False + + # last_chunk=None -> not complete + event_none = MagicMock(spec=TaskArtifactUpdateEvent) + event_none.last_chunk = None + assert a2a_agent._is_complete_event((mock_task, event_none)) is False + + +def test_is_complete_event_status_update(a2a_agent): + """Test _is_complete_event handles TaskStatusUpdateEvent state.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + mock_task = MagicMock() + + # completed state -> complete + event_completed = MagicMock(spec=TaskStatusUpdateEvent) + event_completed.status = MagicMock() + event_completed.status.state = TaskState.completed + assert a2a_agent._is_complete_event((mock_task, event_completed)) is True + + # working state -> not complete + event_working = MagicMock(spec=TaskStatusUpdateEvent) + event_working.status = MagicMock() + event_working.status.state = TaskState.working + assert a2a_agent._is_complete_event((mock_task, event_working)) is False + + # no status -> not complete + event_no_status = MagicMock(spec=TaskStatusUpdateEvent) + event_no_status.status = None + assert a2a_agent._is_complete_event((mock_task, event_no_status)) is False + + +def test_is_complete_event_unknown_type(a2a_agent): + """Test _is_complete_event returns False for unknown event types.""" + assert a2a_agent._is_complete_event("unknown") is False + + +@pytest.mark.asyncio +async def 
test_stream_async_tracks_complete_events(a2a_agent, mock_agent_card): + """Test stream_async uses last complete event for final result.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + mock_task = MagicMock() + mock_task.artifacts = None + + # First event: incomplete + incomplete_event = MagicMock(spec=TaskStatusUpdateEvent) + incomplete_event.status = MagicMock() + incomplete_event.status.state = TaskState.working + incomplete_event.status.message = None + + # Second event: complete + complete_event = MagicMock(spec=TaskStatusUpdateEvent) + complete_event.status = MagicMock() + complete_event.status.state = TaskState.completed + complete_event.status.message = MagicMock() + complete_event.status.message.parts = [] + + async def mock_send_message(*args, **kwargs): + yield (mock_task, incomplete_event) + yield (mock_task, complete_event) + + with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message): + events = [] + async for event in a2a_agent.stream_async("Hello"): + events.append(event) + + # Should have 2 stream events + 1 result event + assert len(events) == 3 + assert "result" in events[2] + + +@pytest.mark.asyncio +async def test_stream_async_falls_back_to_last_event(a2a_agent, mock_agent_card): + """Test stream_async falls back to last event when no complete event.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + mock_task = MagicMock() + mock_task.artifacts = None + + incomplete_event = MagicMock(spec=TaskStatusUpdateEvent) + incomplete_event.status = MagicMock() + incomplete_event.status.state = TaskState.working + incomplete_event.status.message = None + + async def mock_send_message(*args, **kwargs): + yield (mock_task, incomplete_event) + + with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message): + events = [] + async for event in a2a_agent.stream_async("Hello"): + 
events.append(event) + + # Should have 1 stream event + 1 result event (falls back to last) + assert len(events) == 2 + assert "result" in events[1] diff --git a/tests/strands/multiagent/a2a/test_converters.py b/tests/strands/multiagent/a2a/test_converters.py new file mode 100644 index 000000000..002ebf6a6 --- /dev/null +++ b/tests/strands/multiagent/a2a/test_converters.py @@ -0,0 +1,205 @@ +"""Tests for A2A converter functions.""" + +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest +from a2a.types import Message as A2AMessage +from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart + +from strands.agent.agent_result import AgentResult +from strands.multiagent.a2a._converters import ( + convert_content_blocks_to_parts, + convert_input_to_message, + convert_response_to_agent_result, +) + + +def test_convert_string_input(): + """Test converting string input to A2A message.""" + message = convert_input_to_message("Hello") + + assert isinstance(message, A2AMessage) + assert message.role == Role.user + assert len(message.parts) == 1 + assert message.parts[0].root.text == "Hello" + + +def test_convert_message_list_input(): + """Test converting message list input to A2A message.""" + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + message = convert_input_to_message(messages) + + assert isinstance(message, A2AMessage) + assert message.role == Role.user + assert len(message.parts) == 1 + + +def test_convert_content_blocks_input(): + """Test converting content blocks input to A2A message.""" + content_blocks = [{"text": "Hello"}, {"text": "World"}] + + message = convert_input_to_message(content_blocks) + + assert isinstance(message, A2AMessage) + assert len(message.parts) == 2 + + +def test_convert_unsupported_input(): + """Test that unsupported input types raise ValueError.""" + with pytest.raises(ValueError, match="Unsupported input type"): + convert_input_to_message(123) + + +def 
test_convert_interrupt_response_raises_error(): + """Test that InterruptResponseContent raises explicit error.""" + interrupt_responses = [{"interruptResponse": {"interruptId": "123", "response": "A"}}] + + with pytest.raises(ValueError, match="InterruptResponseContent is not supported for A2AAgent"): + convert_input_to_message(interrupt_responses) + + +def test_convert_content_blocks_to_parts(): + """Test converting content blocks to A2A parts.""" + content_blocks = [{"text": "Hello"}, {"text": "World"}] + + parts = convert_content_blocks_to_parts(content_blocks) + + assert len(parts) == 2 + assert parts[0].root.text == "Hello" + assert parts[1].root.text == "World" + + +def test_convert_a2a_message_response(): + """Test converting A2A message response to AgentResult.""" + a2a_message = A2AMessage( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + result = convert_response_to_agent_result(a2a_message) + + assert isinstance(result, AgentResult) + assert result.message["role"] == "assistant" + assert len(result.message["content"]) == 1 + assert result.message["content"][0]["text"] == "Response" + + +def test_convert_task_response(): + """Test converting task response to AgentResult.""" + mock_task = MagicMock() + mock_artifact = MagicMock() + mock_part = MagicMock() + mock_part.root.text = "Task response" + mock_artifact.parts = [mock_part] + mock_task.artifacts = [mock_artifact] + + result = convert_response_to_agent_result((mock_task, None)) + + assert isinstance(result, AgentResult) + assert len(result.message["content"]) == 1 + assert result.message["content"][0]["text"] == "Task response" + + +def test_convert_multiple_parts_response(): + """Test converting response with multiple parts to separate content blocks.""" + a2a_message = A2AMessage( + message_id=uuid4().hex, + role=Role.agent, + parts=[ + Part(TextPart(kind="text", text="First")), + Part(TextPart(kind="text", text="Second")), + ], + ) + + 
result = convert_response_to_agent_result(a2a_message) + + assert len(result.message["content"]) == 2 + assert result.message["content"][0]["text"] == "First" + assert result.message["content"][1]["text"] == "Second" + + +# --- New tests for coverage --- + + +def test_convert_message_list_finds_last_user_message(): + """Test that message list conversion finds the last user message.""" + messages = [ + {"role": "user", "content": [{"text": "First"}]}, + {"role": "assistant", "content": [{"text": "Response"}]}, + {"role": "user", "content": [{"text": "Second"}]}, + ] + + message = convert_input_to_message(messages) + + assert message.parts[0].root.text == "Second" + + +def test_convert_content_blocks_skips_non_text(): + """Test that non-text content blocks are skipped.""" + content_blocks = [{"text": "Hello"}, {"image": "data"}, {"text": "World"}] + + parts = convert_content_blocks_to_parts(content_blocks) + + assert len(parts) == 2 + + +def test_convert_task_artifact_update_event(): + """Test converting TaskArtifactUpdateEvent to AgentResult.""" + mock_task = MagicMock() + mock_part = MagicMock() + mock_part.root.text = "Streamed artifact" + mock_artifact = MagicMock() + mock_artifact.parts = [mock_part] + + mock_event = MagicMock(spec=TaskArtifactUpdateEvent) + mock_event.artifact = mock_artifact + + result = convert_response_to_agent_result((mock_task, mock_event)) + + assert result.message["content"][0]["text"] == "Streamed artifact" + + +def test_convert_task_status_update_event(): + """Test converting TaskStatusUpdateEvent to AgentResult.""" + mock_task = MagicMock() + mock_part = MagicMock() + mock_part.root.text = "Status message" + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_status = MagicMock() + mock_status.message = mock_message + + mock_event = MagicMock(spec=TaskStatusUpdateEvent) + mock_event.status = mock_status + + result = convert_response_to_agent_result((mock_task, mock_event)) + + assert 
result.message["content"][0]["text"] == "Status message" + + +def test_convert_response_handles_missing_data(): + """Test that response conversion handles missing/malformed data gracefully.""" + # TaskArtifactUpdateEvent with no artifact + mock_event = MagicMock(spec=TaskArtifactUpdateEvent) + mock_event.artifact = None + result = convert_response_to_agent_result((MagicMock(), mock_event)) + assert len(result.message["content"]) == 0 + + # TaskStatusUpdateEvent with no status + mock_event = MagicMock(spec=TaskStatusUpdateEvent) + mock_event.status = None + result = convert_response_to_agent_result((MagicMock(), mock_event)) + assert len(result.message["content"]) == 0 + + # Task artifact without parts attribute + mock_task = MagicMock() + mock_artifact = MagicMock(spec=[]) + del mock_artifact.parts + mock_task.artifacts = [mock_artifact] + result = convert_response_to_agent_result((mock_task, None)) + assert len(result.message["content"]) == 0 diff --git a/tests_integ/a2a/__init__.py b/tests_integ/a2a/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_integ/a2a/a2a_server.py b/tests_integ/a2a/a2a_server.py new file mode 100644 index 000000000..047edc3ba --- /dev/null +++ b/tests_integ/a2a/a2a_server.py @@ -0,0 +1,15 @@ +from strands import Agent +from strands.multiagent.a2a import A2AServer + +# Create an agent and serve it over A2A +agent = Agent( + name="Test agent", + description="Test description here", + callback_handler=None, +) +a2a_server = A2AServer( + agent=agent, + host="localhost", + port=9000, +) +a2a_server.serve() diff --git a/tests_integ/a2a/test_multiagent_a2a.py b/tests_integ/a2a/test_multiagent_a2a.py new file mode 100644 index 000000000..60cbc9ce5 --- /dev/null +++ b/tests_integ/a2a/test_multiagent_a2a.py @@ -0,0 +1,72 @@ +import os +import subprocess +import time + +import httpx +import pytest +from a2a.client import ClientConfig, ClientFactory + +from strands.agent.a2a_agent import A2AAgent + + +@pytest.fixture +def 
a2a_server(): + """Start A2A server as subprocess fixture.""" + server_path = os.path.join(os.path.dirname(__file__), "a2a_server.py") + process = subprocess.Popen(["python", server_path]) + time.sleep(5) # Wait for A2A server to start + + yield "http://localhost:9000" + + # Cleanup + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + + +def test_a2a_agent_invoke_sync(a2a_server): + """Test synchronous invocation via __call__.""" + a2a_agent = A2AAgent(endpoint=a2a_server) + result = a2a_agent("Hello there!") + assert result.stop_reason == "end_turn" + + +@pytest.mark.asyncio +async def test_a2a_agent_invoke_async(a2a_server): + """Test async invocation.""" + a2a_agent = A2AAgent(endpoint=a2a_server) + result = await a2a_agent.invoke_async("Hello there!") + assert result.stop_reason == "end_turn" + + +@pytest.mark.asyncio +async def test_a2a_agent_stream_async(a2a_server): + """Test async streaming.""" + a2a_agent = A2AAgent(endpoint=a2a_server) + + events = [] + async for event in a2a_agent.stream_async("Hello there!"): + events.append(event) + + # Should have at least one A2A stream event and one final result event + assert len(events) >= 2 + assert events[0]["type"] == "a2a_stream" + assert "result" in events[-1] + assert events[-1]["result"].stop_reason == "end_turn" + + +@pytest.mark.asyncio +async def test_a2a_agent_with_non_streaming_client_config(a2a_server): + """Test with streaming=False client configuration (non-default).""" + httpx_client = httpx.AsyncClient(timeout=300) + config = ClientConfig(httpx_client=httpx_client, streaming=False) + factory = ClientFactory(config) + + try: + a2a_agent = A2AAgent(endpoint=a2a_server, a2a_client_factory=factory) + result = await a2a_agent.invoke_async("Hello there!") + assert result.stop_reason == "end_turn" + finally: + await httpx_client.aclose() From 53db63eb00e722447cb9fd297cdd43b9ba6e1437 Mon Sep 17 00:00:00 2001 From: Charles Duffy Date: Thu, 29 Jan 
2026 15:40:44 -0600 Subject: [PATCH 101/279] fix(tools): preserve nullable semantics for required Union[T, None] params (#1584) Co-authored-by: Dean Schmigelski --- src/strands/tools/decorator.py | 11 +++- tests/strands/tools/test_decorator.py | 78 +++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 2 deletions(-) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index f72a8ccf1..04c14e452 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -326,13 +326,20 @@ def _clean_pydantic_schema(self, schema: dict[str, Any]) -> None: del schema[key] # Process properties to clean up anyOf and similar structures + required_fields = schema.get("required", []) if "properties" in schema: - for _prop_name, prop_schema in schema["properties"].items(): + for prop_name, prop_schema in schema["properties"].items(): # Handle anyOf constructs (common for Optional types) if "anyOf" in prop_schema: any_of = prop_schema["anyOf"] # Handle Optional[Type] case (represented as anyOf[Type, null]) - if len(any_of) == 2 and any(item.get("type") == "null" for item in any_of): + # Only simplify when the field is not required; required nullable + # fields need anyOf preserved so the model can pass null. 
+ if ( + prop_name not in required_fields + and len(any_of) == 2 + and any(item.get("type") == "null" for item in any_of) + ): # Find the non-null type for item in any_of: if item.get("type") != "null": diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 4757e5587..42213fcb8 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -1823,3 +1823,81 @@ def test_tool_decorator_annotated_field_with_inner_default(): @strands.tool def inner_default_tool(name: str, level: Annotated[int, Field(description="A level value", default=10)]) -> str: return f"{name} is at level {level}" + + +def test_tool_nullable_required_field_preserves_anyof(): + """Test that a required nullable field preserves anyOf so the model can pass null. + + Regression test for https://github.com/strands-agents/sdk-python/issues/1525 + """ + from enum import Enum + + class Priority(str, Enum): + HIGH = "high" + MEDIUM = "medium" + LOW = "low" + + @strands.tool + def prioritized_task(description: str, priority: Priority | None) -> str: + """Create a task with optional priority. 
+ + Args: + description: Task description + priority: Optional priority level + """ + return f"{description}: {priority}" + + spec = prioritized_task.tool_spec + schema = spec["inputSchema"]["json"] + + expected_schema = { + "$defs": { + "Priority": { + "enum": ["high", "medium", "low"], + "title": "Priority", + "type": "string", + }, + }, + "type": "object", + "properties": { + "description": { + "type": "string", + "description": "Task description", + }, + "priority": { + "anyOf": [ + {"$ref": "#/$defs/Priority"}, + {"type": "null"}, + ], + "description": "Optional priority level", + }, + }, + "required": ["description", "priority"], + } + + assert schema == expected_schema + + +def test_tool_nullable_optional_field_simplifies_anyof(): + """Test that a non-required nullable field still gets anyOf simplified.""" + + @strands.tool + def my_tool(name: str, tag: str | None = None) -> str: + """A tool. + + Args: + name: The name + tag: An optional tag + """ + return f"{name}: {tag}" + + spec = my_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # tag has a default, so it should NOT be required + assert "name" in schema["required"] + assert "tag" not in schema["required"] + + # Since tag is not required, anyOf should be simplified away + assert "anyOf" not in schema["properties"]["tag"] + assert schema["properties"]["tag"]["type"] == "string" From 40c4ebb74aa7110de405c93ca735970fa7c5affe Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Thu, 29 Jan 2026 18:47:14 -0500 Subject: [PATCH 102/279] Feature: Allow s3Location as Document, Image, and Video location source (#1572) --- src/strands/models/bedrock.py | 61 +++++-- src/strands/types/media.py | 54 ++++++- tests/strands/models/test_bedrock.py | 151 +++++++++++++++++- tests/strands/tools/mcp/test_mcp_client.py | 2 +- tests/strands/types/test_media.py | 99 ++++++++++++ tests_integ/conftest.py | 11 +- tests_integ/mcp/echo_server.py | 2 +- tests_integ/mcp/test_mcp_client.py | 2 +- tests_integ/resources/blue.mp4 | 
Bin 0 -> 5200 bytes tests_integ/{ => resources}/letter.pdf | Bin tests_integ/{ => resources}/yellow.png | Bin tests_integ/test_a2a_executor.py | 4 +- tests_integ/test_bedrock_s3_location.py | 177 +++++++++++++++++++++ 13 files changed, 531 insertions(+), 32 deletions(-) create mode 100644 tests/strands/types/test_media.py create mode 100644 tests_integ/resources/blue.mp4 rename tests_integ/{ => resources}/letter.pdf (100%) rename tests_integ/{ => resources}/yellow.png (100%) create mode 100644 tests_integ/test_bedrock_s3_location.py diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index a3cea7cfe..b053b70fb 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -17,6 +17,8 @@ from pydantic import BaseModel from typing_extensions import TypedDict, Unpack, override +from strands.types.media import S3Location, SourceLocation + from .._exception_notes import add_exception_note from ..event_loop import streaming from ..tools import convert_pydantic_to_tool_spec @@ -407,6 +409,8 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: # Format content blocks for Bedrock API compatibility formatted_content = self._format_request_message_content(content_block) + if formatted_content is None: + continue # Wrap text or image content in guardrailContent if this is the last user message if ( @@ -459,7 +463,19 @@ def _should_include_tool_result_status(self) -> bool: else: # "auto" return any(model in self.config["model_id"] for model in _MODELS_INCLUDE_STATUS) - def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]: + def _handle_location(self, location: SourceLocation) -> dict[str, Any] | None: + """Convert location content block to Bedrock format if its an S3Location.""" + if location["type"] == "s3": + s3_location = cast(S3Location, location) + formatted_document_s3: dict[str, Any] = {"uri": s3_location["uri"]} + if "bucketOwner" in s3_location: + 
formatted_document_s3["bucketOwner"] = s3_location["bucketOwner"] + return {"s3Location": formatted_document_s3} + else: + logger.warning("Non s3 location sources are not supported by Bedrock, skipping content block") + return None + + def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any] | None: """Format a Bedrock content block. Bedrock strictly validates content blocks and throws exceptions for unknown fields. @@ -489,9 +505,17 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An if "format" in document: result["format"] = document["format"] - # Handle source + # Handle source - supports bytes or location if "source" in document: - result["source"] = {"bytes": document["source"]["bytes"]} + source = document["source"] + formatted_document_source: dict[str, Any] | None + if "location" in source: + formatted_document_source = self._handle_location(source["location"]) + if formatted_document_source is None: + return None + elif "bytes" in source: + formatted_document_source = {"bytes": source["bytes"]} + result["source"] = formatted_document_source # Handle optional fields if "citations" in document and document["citations"] is not None: @@ -512,10 +536,14 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An if "image" in content: image = content["image"] source = image["source"] - formatted_source = {} - if "bytes" in source: - formatted_source = {"bytes": source["bytes"]} - result = {"format": image["format"], "source": formatted_source} + formatted_image_source: dict[str, Any] | None + if "location" in source: + formatted_image_source = self._handle_location(source["location"]) + if formatted_image_source is None: + return None + elif "bytes" in source: + formatted_image_source = {"bytes": source["bytes"]} + result = {"format": image["format"], "source": formatted_image_source} return {"image": result} # 
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ReasoningContentBlock.html @@ -550,9 +578,12 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An # Handle json field since not in ContentBlock but valid in ToolResultContent formatted_content.append({"json": tool_result_content["json"]}) else: - formatted_content.append( - self._format_request_message_content(cast(ContentBlock, tool_result_content)) + formatted_message_content = self._format_request_message_content( + cast(ContentBlock, tool_result_content) ) + if formatted_message_content is None: + continue + formatted_content.append(formatted_message_content) result = { "content": formatted_content, @@ -577,10 +608,14 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An if "video" in content: video = content["video"] source = video["source"] - formatted_source = {} - if "bytes" in source: - formatted_source = {"bytes": source["bytes"]} - result = {"format": video["format"], "source": formatted_source} + formatted_video_source: dict[str, Any] | None + if "location" in source: + formatted_video_source = self._handle_location(source["location"]) + if formatted_video_source is None: + return None + elif "bytes" in source: + formatted_video_source = {"bytes": source["bytes"]} + result = {"format": video["format"], "source": formatted_video_source} return {"video": result} # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html diff --git a/src/strands/types/media.py b/src/strands/types/media.py index 462d8af34..b1240dffb 100644 --- a/src/strands/types/media.py +++ b/src/strands/types/media.py @@ -5,9 +5,9 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Literal +from typing import Literal, TypeAlias -from typing_extensions import TypedDict +from typing_extensions import Required, TypedDict from .citations import 
CitationsConfig @@ -15,14 +15,50 @@ """Supported document formats.""" -class DocumentSource(TypedDict): +class Location(TypedDict, total=False): + """A location for a document. + + This type is a generic location for a document. Its usage is determined by the underlying model provider. + """ + + type: Required[str] + + +class S3Location(Location, total=False): + """A storage location in an Amazon S3 bucket. + + Used by Bedrock to reference media files stored in S3 instead of passing raw bytes. + + - Docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_S3Location.html + + Attributes: + type: s3 + uri: An object URI starting with `s3://`. Required. + bucketOwner: If the bucket belongs to another AWS account, specify that account's ID. Optional. + """ + + # mypy doesn't like overriding this field since its a subclass, but since its just a literal string, this is fine. + + type: Literal["s3"] # type: ignore[misc] + uri: Required[str] + bucketOwner: str + + +SourceLocation: TypeAlias = Location | S3Location + + +class DocumentSource(TypedDict, total=False): """Contains the content of a document. + Only one of `bytes` or `s3Location` should be specified. + Attributes: bytes: The binary content of the document. + location: Location of the document. """ bytes: bytes + location: SourceLocation class DocumentContent(TypedDict, total=False): @@ -45,14 +81,18 @@ class DocumentContent(TypedDict, total=False): """Supported image formats.""" -class ImageSource(TypedDict): +class ImageSource(TypedDict, total=False): """Contains the content of an image. + Only one of `bytes` or `s3Location` should be specified. + Attributes: bytes: The binary content of the image. + location: Location of the image. 
""" bytes: bytes + location: SourceLocation class ImageContent(TypedDict): @@ -71,14 +111,18 @@ class ImageContent(TypedDict): """Supported video formats.""" -class VideoSource(TypedDict): +class VideoSource(TypedDict, total=False): """Contains the content of a video. + Only one of `bytes` or `s3Location` should be specified. + Attributes: bytes: The binary content of the video. + location: Location of the video. """ bytes: bytes + location: SourceLocation class VideoContent(TypedDict): diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index e92018f35..761434258 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1,3 +1,5 @@ +import copy +import logging import os import sys import traceback @@ -1519,7 +1521,6 @@ async def test_add_note_on_validation_exception_throughput(bedrock_client, model @pytest.mark.asyncio async def test_stream_logging(bedrock_client, model, messages, caplog, alist): """Test that stream method logs debug messages at the expected stages.""" - import logging # Set the logger to debug level to capture debug messages caplog.set_level(logging.DEBUG, logger="strands.models.bedrock") @@ -1787,8 +1788,8 @@ def test_format_request_filters_image_content_blocks(model, model_id): assert "metadata" not in image_block -def test_format_request_filters_nested_image_s3_fields(model, model_id): - """Test that s3Location is filtered out and only bytes source is preserved.""" +def test_format_request_image_s3_location_only(model, model_id): + """Test that image with only s3Location is properly formatted.""" messages = [ { "role": "user", @@ -1797,8 +1798,7 @@ def test_format_request_filters_nested_image_s3_fields(model, model_id): "image": { "format": "png", "source": { - "bytes": b"image_data", - "s3Location": {"bucket": "my-bucket", "key": "image.png", "extraField": "filtered"}, + "location": {"type": "s3", "uri": "s3://my-bucket/image.png"}, }, } } @@ -1809,8 +1809,146 @@ def 
test_format_request_filters_nested_image_s3_fields(model, model_id): formatted_request = model._format_request(messages) image_source = formatted_request["messages"][0]["content"][0]["image"]["source"] + assert image_source == {"s3Location": {"uri": "s3://my-bucket/image.png"}} + + +def test_format_request_image_bytes_only(model, model_id): + """Test that image with only bytes source is properly formatted.""" + messages = [ + { + "role": "user", + "content": [ + { + "image": { + "format": "png", + "source": {"bytes": b"image_data"}, + } + } + ], + } + ] + + formatted_request = model._format_request(messages) + image_source = formatted_request["messages"][0]["content"][0]["image"]["source"] + assert image_source == {"bytes": b"image_data"} - assert "s3Location" not in image_source + + +def test_format_request_document_s3_location(model, model_id): + """Test that document with s3Location is properly formatted.""" + messages = [ + { + "role": "user", + "content": [ + { + "document": { + "name": "report.pdf", + "format": "pdf", + "source": { + "location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}, + }, + } + }, + { + "document": { + "name": "report.pdf", + "format": "pdf", + "source": { + "location": { + "type": "s3", + "uri": "s3://my-bucket/report.pdf", + "bucketOwner": "123456789012", + }, + }, + } + }, + ], + } + ] + + formatted_request = model._format_request(messages) + document = formatted_request["messages"][0]["content"][0]["document"] + document_with_bucket_owner = formatted_request["messages"][0]["content"][1]["document"] + + assert document["source"] == {"s3Location": {"uri": "s3://my-bucket/report.pdf"}} + + assert document_with_bucket_owner["source"] == { + "s3Location": {"uri": "s3://my-bucket/report.pdf", "bucketOwner": "123456789012"} + } + + +def test_format_request_unsupported_location(model, caplog): + """Test that document with s3Location is properly formatted.""" + + caplog.set_level(logging.WARNING, logger="strands.models.bedrock") + + 
messages = [ + { + "role": "user", + "content": [ + {"text": "Hello!"}, + { + "document": { + "name": "report.pdf", + "format": "pdf", + "source": { + "location": { + "type": "other", + }, + }, + } + }, + { + "video": { + "format": "mp4", + "source": { + "location": { + "type": "other", + }, + }, + } + }, + { + "image": { + "format": "png", + "source": { + "location": { + "type": "other", + }, + }, + } + }, + ], + } + ] + + formatted_request = model._format_request(messages) + assert len(formatted_request["messages"][0]["content"]) == 1 + assert "Non s3 location sources are not supported by Bedrock, skipping content block" in caplog.text + + +def test_format_request_video_s3_location(model, model_id): + """Test that video with s3Location is properly formatted.""" + messages = [ + { + "role": "user", + "content": [ + { + "video": { + "format": "mp4", + "source": { + "location": {"type": "s3", "uri": "s3://my-bucket/video.mp4"}, + }, + } + }, + ], + } + ] + + formatted_request = model._format_request(messages) + video_source = formatted_request["messages"][0]["content"][0]["video"]["source"] + + assert video_source == {"s3Location": {"uri": "s3://my-bucket/video.mp4"}} def test_format_request_filters_document_content_blocks(model, model_id): @@ -2310,7 +2448,6 @@ def test_inject_cache_point_skipped_for_non_claude(bedrock_client): def test_format_bedrock_messages_does_not_mutate_original(bedrock_client): """Test that _format_bedrock_messages does not mutate original messages.""" - import copy model = BedrockModel( model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index f784da414..a2ef369ea 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -632,7 +632,7 @@ def test_call_tool_sync_embedded_nested_base64_textual_mime(mock_transport, mock def 
test_call_tool_sync_embedded_image_blob(mock_transport, mock_session): """EmbeddedResource.resource (blob with image MIME) should map to image content.""" # Read yellow.png file - with open("tests_integ/yellow.png", "rb") as image_file: + with open("tests_integ/resources/yellow.png", "rb") as image_file: png_data = image_file.read() payload = base64.b64encode(png_data).decode() diff --git a/tests/strands/types/test_media.py b/tests/strands/types/test_media.py new file mode 100644 index 000000000..2fa8c3621 --- /dev/null +++ b/tests/strands/types/test_media.py @@ -0,0 +1,99 @@ +"""Tests for media type definitions.""" + +from strands.types.media import ( + DocumentSource, + ImageSource, + S3Location, + VideoSource, +) + + +class TestS3Location: + """Tests for S3Location TypedDict.""" + + def test_s3_location_with_uri_only(self): + """Test S3Location with only uri field.""" + s3_loc: S3Location = {"uri": "s3://my-bucket/path/to/file.pdf"} + + assert s3_loc["uri"] == "s3://my-bucket/path/to/file.pdf" + assert "bucketOwner" not in s3_loc + + def test_s3_location_with_bucket_owner(self): + """Test S3Location with both uri and bucketOwner fields.""" + s3_loc: S3Location = { + "uri": "s3://my-bucket/path/to/file.pdf", + "bucketOwner": "123456789012", + } + + assert s3_loc["uri"] == "s3://my-bucket/path/to/file.pdf" + assert s3_loc["bucketOwner"] == "123456789012" + + +class TestDocumentSource: + """Tests for DocumentSource TypedDict.""" + + def test_document_source_with_bytes(self): + """Test DocumentSource with bytes content.""" + doc_source: DocumentSource = {"bytes": b"document content"} + + assert doc_source["bytes"] == b"document content" + assert "s3Location" not in doc_source + + def test_document_source_with_s3_location(self): + """Test DocumentSource with s3Location.""" + doc_source: DocumentSource = { + "s3Location": { + "uri": "s3://my-bucket/docs/report.pdf", + "bucketOwner": "123456789012", + } + } + + assert "bytes" not in doc_source + assert 
doc_source["s3Location"]["uri"] == "s3://my-bucket/docs/report.pdf" + assert doc_source["s3Location"]["bucketOwner"] == "123456789012" + + +class TestImageSource: + """Tests for ImageSource TypedDict.""" + + def test_image_source_with_bytes(self): + """Test ImageSource with bytes content.""" + img_source: ImageSource = {"bytes": b"image content"} + + assert img_source["bytes"] == b"image content" + assert "s3Location" not in img_source + + def test_image_source_with_s3_location(self): + """Test ImageSource with s3Location.""" + img_source: ImageSource = { + "s3Location": { + "uri": "s3://my-bucket/images/photo.png", + } + } + + assert "bytes" not in img_source + assert img_source["s3Location"]["uri"] == "s3://my-bucket/images/photo.png" + + +class TestVideoSource: + """Tests for VideoSource TypedDict.""" + + def test_video_source_with_bytes(self): + """Test VideoSource with bytes content.""" + vid_source: VideoSource = {"bytes": b"video content"} + + assert vid_source["bytes"] == b"video content" + assert "s3Location" not in vid_source + + def test_video_source_with_s3_location(self): + """Test VideoSource with s3Location.""" + vid_source: VideoSource = { + "s3Location": { + "uri": "s3://my-bucket/videos/clip.mp4", + "bucketOwner": "987654321098", + } + } + + assert "bytes" not in vid_source + assert vid_source["s3Location"]["uri"] == "s3://my-bucket/videos/clip.mp4" + assert vid_source["s3Location"]["bucketOwner"] == "987654321098" diff --git a/tests_integ/conftest.py b/tests_integ/conftest.py index 9de00089b..dbe25d685 100644 --- a/tests_integ/conftest.py +++ b/tests_integ/conftest.py @@ -133,14 +133,21 @@ def pytest_sessionstart(session): @pytest.fixture def yellow_img(pytestconfig): - path = pytestconfig.rootdir / "tests_integ/yellow.png" + path = pytestconfig.rootdir / "tests_integ/resources/yellow.png" with open(path, "rb") as fp: return fp.read() @pytest.fixture def letter_pdf(pytestconfig): - path = pytestconfig.rootdir / "tests_integ/letter.pdf" + path = 
pytestconfig.rootdir / "tests_integ/resources/letter.pdf" + with open(path, "rb") as fp: + return fp.read() + + +@pytest.fixture +def blue_video(pytestconfig): + path = pytestconfig.rootdir / "tests_integ/resources/blue.mp4" with open(path, "rb") as fp: return fp.read() diff --git a/tests_integ/mcp/echo_server.py b/tests_integ/mcp/echo_server.py index 8fa1fb2b2..363c588ee 100644 --- a/tests_integ/mcp/echo_server.py +++ b/tests_integ/mcp/echo_server.py @@ -90,7 +90,7 @@ def get_weather(location: Literal["New York", "London", "Tokyo"] = "New York"): ] elif location.lower() == "tokyo": # Read yellow.png file for weather icon - with open("tests_integ/yellow.png", "rb") as image_file: + with open("tests_integ/resources/yellow.png", "rb") as image_file: png_data = image_file.read() return [ EmbeddedResource( diff --git a/tests_integ/mcp/test_mcp_client.py b/tests_integ/mcp/test_mcp_client.py index 298272df5..4e192c935 100644 --- a/tests_integ/mcp/test_mcp_client.py +++ b/tests_integ/mcp/test_mcp_client.py @@ -43,7 +43,7 @@ def calculator(x: int, y: int) -> int: @mcp.tool(description="Generates a custom image") def generate_custom_image() -> MCPImageContent: try: - with open("tests_integ/yellow.png", "rb") as image_file: + with open("tests_integ/resources/yellow.png", "rb") as image_file: encoded_image = base64.b64encode(image_file.read()) return MCPImageContent(type="image", data=encoded_image, mimeType="image/png") except Exception as e: diff --git a/tests_integ/resources/blue.mp4 b/tests_integ/resources/blue.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..5989bb4b02d85ad96d9985acdfafd0125096acb5 GIT binary patch literal 5200 zcmeHL|7#pY6rXFfNo!luCao2MjG7-5lFRO1a;?$DTtaGuh*h*j5Vmu>b9ZZYZ#TP> zJFIi~ccr>QIVoEM2+u^@D5wuCHU~UYGeUP>chg{=D#_Mc%#&GPd2m2oiui`W(Rx z2(2IHg^9ry_3v8OZ~B5C>LE!1-5k-Dj3U|tETA2GxE`JL3D*I)M`wTBA>YRUoCSK2 zu^?x`c@Ta78+yQYII;E3-sbczx|FCl@3$g@i>*UW7NwiibC_5 zv8*)4z%Y{rhmoiEPd_<4N^=LMz|-J57^WPzYVm@giX>%*6-gNbWl0Ekd}L&4X(^3& 
z498;SwBr>=aFldO*cSLWt}valKTdU)XSym=xJRfNYVf?}=yR$(E{#i+m6=ubxhhpM z<5ESIGt}m4iC3tj9#-xbV;Nk}&^>tq6`hrkLB@EMJxTGHUOVHiZwHwn#yQizV zSD-fBpEynn1XanTB|49jQKfViSQmi<$|`F1QBe4TyXq)4T}Tpa2*@E|v3bZpW|JI+ z9h~DQkCDfkjZ1H@_g~omyqNv`;l}az^=H2Af8F}CZ_mKHjqW33hsY3H{_Bx1W?$R) zU22Gs9c$m5kXBZ|eEizAgU}0MCRF2nY~o0m6zw ztc4I?P48@jxZD=Sl{Sd035YO?`nGn6`fxmo`bZq&^k@PijH3Qr0%ATMMcr?Ms3ahw zD3)6gM`2={Q}qwpqBs{q;L7ynPJf($h@$wu1raV_{hziduD3z_lz<4MsNLU!2&1T} zafsRzafp?{1Vk7`ZL$RsM)AM1^8yOd9e-Xp z&BmEyg#3sfHzf78>&TAW=~f*nHXF5G(p3x*ZhKn*LaU4%Y&P~zkgfQK5yWtNRpdXN CcM9hK literal 0 HcmV?d00001 diff --git a/tests_integ/letter.pdf b/tests_integ/resources/letter.pdf similarity index 100% rename from tests_integ/letter.pdf rename to tests_integ/resources/letter.pdf diff --git a/tests_integ/yellow.png b/tests_integ/resources/yellow.png similarity index 100% rename from tests_integ/yellow.png rename to tests_integ/resources/yellow.png diff --git a/tests_integ/test_a2a_executor.py b/tests_integ/test_a2a_executor.py index ddca0bfa6..43a6026bf 100644 --- a/tests_integ/test_a2a_executor.py +++ b/tests_integ/test_a2a_executor.py @@ -17,7 +17,7 @@ async def test_a2a_executor_with_real_image(): """Test A2A server processes a real image file correctly via HTTP.""" # Read the test image file - test_image_path = os.path.join(os.path.dirname(__file__), "yellow.png") + test_image_path = os.path.join(os.path.dirname(__file__), "resources/yellow.png") with open(test_image_path, "rb") as f: original_image_bytes = f.read() @@ -80,7 +80,7 @@ async def test_a2a_executor_with_real_image(): def test_a2a_executor_image_roundtrip(): """Test that image data survives the A2A base64 encoding/decoding roundtrip.""" # Read the test image - test_image_path = os.path.join(os.path.dirname(__file__), "yellow.png") + test_image_path = os.path.join(os.path.dirname(__file__), "resources/yellow.png") with open(test_image_path, "rb") as f: original_bytes = f.read() diff --git a/tests_integ/test_bedrock_s3_location.py 
b/tests_integ/test_bedrock_s3_location.py new file mode 100644 index 000000000..9b28e88be --- /dev/null +++ b/tests_integ/test_bedrock_s3_location.py @@ -0,0 +1,177 @@ +"""Integration tests for S3 location support in media content types.""" + +import time + +import boto3 +import pytest + +from strands import Agent +from strands.models.bedrock import BedrockModel + + +@pytest.fixture +def boto_session(): + """Create a boto3 session for testing.""" + return boto3.Session(region_name="us-west-2") + + +@pytest.fixture +def account_id(boto_session): + """Get the current AWS account ID.""" + sts_client = boto_session.client("sts") + return sts_client.get_caller_identity()["Account"] + + +@pytest.fixture +def s3_client(boto_session): + """Create an S3 client.""" + return boto_session.client("s3") + + +@pytest.fixture +def test_bucket(s3_client, account_id): + """Create a test S3 bucket for the tests. + + Creates a bucket with account-specific name and cleans it up after tests. + """ + bucket_name = f"strands-integ-tests-resources-{account_id}" + + # Create the bucket if it doesn't exist + try: + s3_client.head_bucket(Bucket=bucket_name) + print(f"Bucket {bucket_name} already exists") + except s3_client.exceptions.ClientError: + try: + s3_client.create_bucket( + Bucket=bucket_name, + CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + ) + print(f"Created test bucket: {bucket_name}") + # Wait for bucket to be available + time.sleep(2) + except s3_client.exceptions.BucketAlreadyOwnedByYou: + print(f"Bucket {bucket_name} already exists") + + yield bucket_name + + # Note: We don't delete the bucket to allow reuse across test runs + # Objects will be overwritten on subsequent runs + + +@pytest.fixture +def s3_document(s3_client, test_bucket, letter_pdf): + """Upload a test document to S3 and return its URI.""" + document_key = "test-documents/letter.pdf" + + # Upload the document using existing letter_pdf fixture + s3_client.put_object( + Bucket=test_bucket, + 
Key=document_key, + Body=letter_pdf, + ContentType="application/pdf", + ) + print(f"Uploaded test document to s3://{test_bucket}/{document_key}") + + return f"s3://{test_bucket}/{document_key}" + + +@pytest.fixture +def s3_image(s3_client, test_bucket, yellow_img): + """Upload a test image to S3 and return its URI.""" + image_key = "test-images/yellow.png" + + # Upload the image using existing yellow_img fixture + s3_client.put_object( + Bucket=test_bucket, + Key=image_key, + Body=yellow_img, + ContentType="image/png", + ) + print(f"Uploaded test image to s3://{test_bucket}/{image_key}") + + return f"s3://{test_bucket}/{image_key}" + + +@pytest.fixture +def s3_video(s3_client, test_bucket, blue_video): + """Upload a test video to S3 and return its URI.""" + video_key = "test-videos/blue.mp4" + + # Upload the video using existing blue_video fixture + s3_client.put_object( + Bucket=test_bucket, + Key=video_key, + Body=blue_video, + ContentType="video/mp4", + ) + print(f"Uploaded test video to s3://{test_bucket}/{video_key}") + + return f"s3://{test_bucket}/{video_key}" + + +def test_document_s3_location(s3_document, account_id): + """Test that Bedrock correctly formats a document with S3 location.""" + messages = [ + { + "role": "user", + "content": [ + {"text": "Please tell me about this document?"}, + { + "document": { + "format": "pdf", + "name": "letter", + "source": {"location": {"type": "s3", "uri": s3_document, "bucketOwner": account_id}}, + }, + }, + ], + }, + ] + + agent = Agent(model=BedrockModel(model_id="us.amazon.nova-2-lite-v1:0", region_name="us-west-2")) + result = agent(messages) + + # The actual recognition capabilities of these models is not great, so just asserting that the call actually worked. 
+ assert len(str(result)) > 0 + + +def test_image_s3_location(s3_image): + """Test that Bedrock correctly formats an image with S3 location.""" + messages = [ + { + "role": "user", + "content": [ + {"text": "Please tell me about this image?"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": s3_image}}, + }, + }, + ], + }, + ] + + agent = Agent(model=BedrockModel(model_id="us.amazon.nova-2-lite-v1:0", region_name="us-west-2")) + result = agent(messages) + + # The actual recognition capabilities of these models is not great, so just asserting that the call actually worked. + assert len(str(result)) > 0 + + +def test_video_s3_location(s3_video): + """Test that Bedrock correctly formats a video with S3 location.""" + messages = [ + { + "role": "user", + "content": [ + {"text": "Describe the colors is in this video?"}, + {"video": {"format": "mp4", "source": {"location": {"type": "s3", "uri": s3_video}}}}, + ], + }, + ] + + agent = Agent(model=BedrockModel(model_id="us.amazon.nova-pro-v1:0", region_name="us-west-2")) + result = agent(messages) + + # The actual recognition capabilities of these models is not great, so just asserting that the call actually worked. 
+ assert len(str(result)) > 0 From 00a55d217108477406bff8988fd82e3801b246d8 Mon Sep 17 00:00:00 2001 From: Clare Liguori Date: Fri, 30 Jan 2026 06:02:42 -0800 Subject: [PATCH 103/279] fix: LedgerProvider handles parallel tool calls (#1559) --- .../context_providers/ledger_provider.py | 26 ++- .../steering/handlers/llm/mappers.py | 16 +- .../context_providers/test_ledger_provider.py | 191 +++++++++++++++++- 3 files changed, 221 insertions(+), 12 deletions(-) diff --git a/src/strands/experimental/steering/context_providers/ledger_provider.py b/src/strands/experimental/steering/context_providers/ledger_provider.py index 0e7bde529..43f56717a 100644 --- a/src/strands/experimental/steering/context_providers/ledger_provider.py +++ b/src/strands/experimental/steering/context_providers/ledger_provider.py @@ -46,6 +46,7 @@ def __call__(self, event: BeforeToolCallEvent, steering_context: SteeringContext tool_call_entry = { "timestamp": datetime.now().isoformat(), + "tool_use_id": event.tool_use.get("toolUseId"), "tool_name": event.tool_use.get("name"), "tool_args": event.tool_use.get("input", {}), "status": "pending", @@ -62,16 +63,21 @@ def __call__(self, event: AfterToolCallEvent, steering_context: SteeringContext, ledger = steering_context.data.get("ledger") or {} if ledger.get("tool_calls"): - last_call = ledger["tool_calls"][-1] - last_call.update( - { - "completion_timestamp": datetime.now().isoformat(), - "status": event.result["status"], - "result": event.result["content"], - "error": str(event.exception) if event.exception else None, - } - ) - steering_context.data.set("ledger", ledger) + tool_use_id = event.tool_use.get("toolUseId") + + # Search for the matching tool call in the ledger to update it + for call in reversed(ledger["tool_calls"]): + if call.get("tool_use_id") == tool_use_id and call.get("status") == "pending": + call.update( + { + "completion_timestamp": datetime.now().isoformat(), + "status": event.result["status"], + "result": event.result["content"], 
+ "error": str(event.exception) if event.exception else None, + } + ) + steering_context.data.set("ledger", ledger) + break class LedgerProvider(SteeringContextProvider): diff --git a/src/strands/experimental/steering/handlers/llm/mappers.py b/src/strands/experimental/steering/handlers/llm/mappers.py index 9901da7d4..ade018d32 100644 --- a/src/strands/experimental/steering/handlers/llm/mappers.py +++ b/src/strands/experimental/steering/handlers/llm/mappers.py @@ -23,7 +23,7 @@ **CRITICAL CONSTRAINTS:** - Base decisions ONLY on the context data provided below -- Do NOT use external knowledge about domains, URLs, or tool purposes +- Do NOT use external knowledge about domains, URLs, or tool purposes - Do NOT make assumptions about what tools "should" or "shouldn't" do - Focus ONLY on patterns in the context data @@ -31,6 +31,20 @@ {context_str} +### Understanding Ledger Tool States + +If the context includes a ledger with tool_calls, the "status" field indicates: + +- **"pending"**: The tool is CURRENTLY being evaluated by you (the steering agent). +This is NOT a duplicate call - it's the tool you're deciding whether to approve. +The tool has NOT started executing yet. +- **"success"**: The tool completed successfully in a previous turn +- **"error"**: The tool failed or was cancelled in a previous turn + +**IMPORTANT**: When you see a tool with status="pending" that matches the tool you're evaluating, +that IS the current tool being evaluated. +It is NOT already executing or a duplicate. 
+ ## Event to Evaluate {event_description} diff --git a/tests/strands/experimental/steering/context_providers/test_ledger_provider.py b/tests/strands/experimental/steering/context_providers/test_ledger_provider.py index 1d280f7c1..c3cde475b 100644 --- a/tests/strands/experimental/steering/context_providers/test_ledger_provider.py +++ b/tests/strands/experimental/steering/context_providers/test_ledger_provider.py @@ -87,11 +87,19 @@ def test_ledger_after_tool_call_success(mock_datetime): # Set up existing ledger with pending call existing_ledger = { - "tool_calls": [{"tool_name": "test_tool", "status": "pending", "timestamp": "2024-01-01T12:00:00"}] + "tool_calls": [ + { + "tool_use_id": "test-id", + "tool_name": "test_tool", + "status": "pending", + "timestamp": "2024-01-01T12:00:00", + } + ] } steering_context.data.set("ledger", existing_ledger) event = Mock(spec=AfterToolCallEvent) + event.tool_use = {"toolUseId": "test-id"} event.result = {"status": "success", "content": ["success_result"]} event.exception = None @@ -133,3 +141,184 @@ def test_session_start_persistence(): callback = LedgerBeforeToolCall() assert callback.session_start == "2024-01-01T10:00:00" + + +@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +def test_parallel_tool_calls_all_pending(mock_datetime): + """Test multiple tool calls added as pending before any execute.""" + mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" + + callback = LedgerBeforeToolCall() + steering_context = SteeringContext() + + # Add three tool calls in sequence (simulating parallel proposal) + for i, tool_name in enumerate(["tool_a", "tool_b", "tool_c"]): + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"toolUseId": f"id_{i}", "name": tool_name, "input": {}} + callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert len(ledger["tool_calls"]) == 3 + assert all(call["status"] == "pending" for call in 
ledger["tool_calls"]) + assert [call["tool_name"] for call in ledger["tool_calls"]] == ["tool_a", "tool_b", "tool_c"] + + +@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +def test_parallel_tool_calls_complete_by_id(mock_datetime): + """Test tool calls complete in any order by matching toolUseId.""" + # Need timestamps for: session_start + 3 tool calls + 1 completion + mock_datetime.now.return_value.isoformat.side_effect = [ + "2024-01-01T11:00:00", # session_start + "2024-01-01T12:00:00", # tool_a + "2024-01-01T12:01:00", # tool_b + "2024-01-01T12:02:00", # tool_c + "2024-01-01T12:03:00", # completion + ] + + before_callback = LedgerBeforeToolCall() + after_callback = LedgerAfterToolCall() + steering_context = SteeringContext() + + # Add three pending tool calls + for i, tool_name in enumerate(["tool_a", "tool_b", "tool_c"]): + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"toolUseId": f"id_{i}", "name": tool_name, "input": {}} + before_callback(event, steering_context) + + # Complete middle tool first (out of order) + event = Mock(spec=AfterToolCallEvent) + event.tool_use = {"toolUseId": "id_1"} + event.result = {"status": "success", "content": ["result_b"]} + event.exception = None + after_callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert ledger["tool_calls"][0]["status"] == "pending" + assert ledger["tool_calls"][1]["status"] == "success" + assert ledger["tool_calls"][1]["result"] == ["result_b"] + assert ledger["tool_calls"][2]["status"] == "pending" + + +@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +def test_parallel_tool_calls_complete_all_out_of_order(mock_datetime): + """Test all parallel tool calls complete in reverse order.""" + # Need timestamps for: session_start + 3 tool calls + 3 completions + mock_datetime.now.return_value.isoformat.side_effect = [ + "2024-01-01T11:00:00", # session_start + "2024-01-01T12:00:00", # tool_0 + 
"2024-01-01T12:01:00", # tool_1 + "2024-01-01T12:02:00", # tool_2 + "2024-01-01T12:03:00", # completion tool_2 + "2024-01-01T12:04:00", # completion tool_1 + "2024-01-01T12:05:00", # completion tool_0 + ] + + before_callback = LedgerBeforeToolCall() + after_callback = LedgerAfterToolCall() + steering_context = SteeringContext() + + # Add three pending tool calls + for i in range(3): + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"toolUseId": f"id_{i}", "name": f"tool_{i}", "input": {}} + before_callback(event, steering_context) + + # Complete in reverse order: 2, 1, 0 + for i in [2, 1, 0]: + event = Mock(spec=AfterToolCallEvent) + event.tool_use = {"toolUseId": f"id_{i}"} + event.result = {"status": "success", "content": [f"result_{i}"]} + event.exception = None + after_callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert all(call["status"] == "success" for call in ledger["tool_calls"]) + assert ledger["tool_calls"][0]["result"] == ["result_0"] + assert ledger["tool_calls"][1]["result"] == ["result_1"] + assert ledger["tool_calls"][2]["result"] == ["result_2"] + + +@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +def test_parallel_tool_calls_with_failure(mock_datetime): + """Test parallel tool calls where one fails.""" + # Need timestamps for: session_start + 2 tool calls + 2 completions + mock_datetime.now.return_value.isoformat.side_effect = [ + "2024-01-01T11:00:00", # session_start + "2024-01-01T12:00:00", # tool_0 + "2024-01-01T12:01:00", # tool_1 + "2024-01-01T12:02:00", # completion tool_0 + "2024-01-01T12:03:00", # completion tool_1 + ] + + before_callback = LedgerBeforeToolCall() + after_callback = LedgerAfterToolCall() + steering_context = SteeringContext() + + # Add two pending tool calls + for i in range(2): + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"toolUseId": f"id_{i}", "name": f"tool_{i}", "input": {}} + before_callback(event, steering_context) 
+ + # First succeeds + event = Mock(spec=AfterToolCallEvent) + event.tool_use = {"toolUseId": "id_0"} + event.result = {"status": "success", "content": ["result_0"]} + event.exception = None + after_callback(event, steering_context) + + # Second fails + event = Mock(spec=AfterToolCallEvent) + event.tool_use = {"toolUseId": "id_1"} + event.result = {"status": "error", "content": []} + event.exception = ValueError("test error") + after_callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert ledger["tool_calls"][0]["status"] == "success" + assert ledger["tool_calls"][0]["error"] is None + assert ledger["tool_calls"][1]["status"] == "error" + assert ledger["tool_calls"][1]["error"] == "test error" + + +@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +def test_after_tool_call_no_matching_id(mock_datetime): + """Test AfterToolCallEvent when tool_use_id doesn't match any pending call.""" + mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" + + before_callback = LedgerBeforeToolCall() + after_callback = LedgerAfterToolCall() + steering_context = SteeringContext() + + # Add a pending tool call + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"toolUseId": "id_1", "name": "tool_1", "input": {}} + before_callback(event, steering_context) + + # Try to complete a different tool_use_id that doesn't exist + event = Mock(spec=AfterToolCallEvent) + event.tool_use = {"toolUseId": "id_999"} + event.result = {"status": "success", "content": ["result"]} + event.exception = None + after_callback(event, steering_context) + + # Original tool should still be pending (no match found) + ledger = steering_context.data.get("ledger") + assert ledger["tool_calls"][0]["status"] == "pending" + assert "completion_timestamp" not in ledger["tool_calls"][0] + + +@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +def 
test_tool_use_id_stored_in_ledger(mock_datetime): + """Test that toolUseId is stored in ledger entries.""" + mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" + + callback = LedgerBeforeToolCall() + steering_context = SteeringContext() + + event = Mock(spec=BeforeToolCallEvent) + event.tool_use = {"toolUseId": "test-id-123", "name": "test_tool", "input": {}} + callback(event, steering_context) + + ledger = steering_context.data.get("ledger") + assert ledger["tool_calls"][0]["tool_use_id"] == "test-id-123" From f2c35a593ddabc0bec73768566d32e4744d15453 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Fri, 30 Jan 2026 10:22:46 -0500 Subject: [PATCH 104/279] Clone main metrics upload script for integ tests (#1600) --- .github/workflows/integration-test.yml | 38 ++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index bbcdfde25..00fda1262 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -59,8 +59,42 @@ jobs: run: | hatch test tests_integ - - name: Publish test metrics to CloudWatch + - name: Upload test results if: always() + uses: actions/upload-artifact@v4 + with: + name: test-results + path: ./build/test-results.xml + + upload-metrics: + runs-on: ubuntu-latest + needs: check-access-and-checkout + if: always() + permissions: + id-token: write + contents: read + steps: + - name: Configure Credentials + uses: aws-actions/configure-aws-credentials@v5 + with: + role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }} + aws-region: us-east-1 + mask-aws-account-id: true + + - name: Checkout main + uses: actions/checkout@v6 + with: + ref: main + sparse-checkout: | + .github/scripts + persist-credentials: false + + - name: Download test results + uses: actions/download-artifact@v4 + with: + name: test-results + + - name: Publish test metrics to CloudWatch run: | pip install 
--no-cache-dir boto3 - python .github/scripts/upload-integ-test-metrics.py ./build/test-results.xml ${{ github.event.repository.name }} + python .github/scripts/upload-integ-test-metrics.py test-results.xml ${{ github.event.repository.name }} From ab51706c2a4131de9ece1964729081512b6847f6 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Fri, 30 Jan 2026 11:40:42 -0500 Subject: [PATCH 105/279] Skip location for non bedrock model providers (#1602) --- src/strands/models/_validation.py | 21 ++++++++ src/strands/models/anthropic.py | 7 ++- src/strands/models/bedrock.py | 2 +- src/strands/models/gemini.py | 27 ++++++---- src/strands/models/llamaapi.py | 18 ++++--- src/strands/models/llamacpp.py | 18 ++++--- src/strands/models/mistral.py | 7 ++- src/strands/models/ollama.py | 18 ++++--- src/strands/models/openai.py | 18 ++++--- src/strands/models/writer.py | 16 ++++-- tests/strands/models/test__validation.py | 67 +++++++++++++++++++++++ tests/strands/models/test_anthropic.py | 67 +++++++++++++++++++++++ tests/strands/models/test_bedrock.py | 2 +- tests/strands/models/test_gemini.py | 64 ++++++++++++++++++++++ tests/strands/models/test_llamaapi.py | 67 +++++++++++++++++++++++ tests/strands/models/test_llamacpp.py | 69 ++++++++++++++++++++++++ tests/strands/models/test_mistral.py | 63 ++++++++++++++++++++++ tests/strands/models/test_ollama.py | 66 +++++++++++++++++++++++ tests/strands/models/test_openai.py | 65 ++++++++++++++++++++++ tests/strands/models/test_writer.py | 67 +++++++++++++++++++++++ 20 files changed, 708 insertions(+), 41 deletions(-) create mode 100644 tests/strands/models/test__validation.py diff --git a/src/strands/models/_validation.py b/src/strands/models/_validation.py index 1e82bca73..9d4d8b178 100644 --- a/src/strands/models/_validation.py +++ b/src/strands/models/_validation.py @@ -6,6 +6,7 @@ from typing_extensions import get_type_hints +from ..types.content import ContentBlock from ..types.tools import ToolChoice @@ -41,3 +42,23 @@ def 
warn_on_tool_choice_not_supported(tool_choice: ToolChoice | None) -> None: "A ToolChoice was provided to this provider but is not supported and will be ignored", stacklevel=4, ) + + +def _has_location_source(content: ContentBlock) -> bool: + """Check if a content block contains a location source. + + Providers need to explicitly define an implementation to support content locations. + + Args: + content: Content block to check. + + Returns: + True if the content block contains an location source, False otherwise. + """ + if "image" in content: + return "location" in content["image"].get("source", {}) + if "document" in content: + return "location" in content["document"].get("source", {}) + if "video" in content: + return "location" in content["video"].get("source", {}) + return False diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 535c820ee..b5f6fcf91 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -20,7 +20,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec -from ._validation import validate_config_keys +from ._validation import _has_location_source, validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -189,6 +189,11 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]: formatted_contents[-1]["cache_control"] = {"type": "ephemeral"} continue + # Check for location sources in image, document, or video content + if _has_location_source(content): + logger.warning("Location sources are not supported by Anthropic | skipping content block") + continue + formatted_contents.append(self._format_request_message_content(content)) if formatted_contents: diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index b053b70fb..596936e6f 100644 --- a/src/strands/models/bedrock.py 
+++ b/src/strands/models/bedrock.py @@ -472,7 +472,7 @@ def _handle_location(self, location: SourceLocation) -> dict[str, Any] | None: formatted_document_s3["bucketOwner"] = s3_location["bucketOwner"] return {"s3Location": formatted_document_s3} else: - logger.warning("Non s3 location sources are not supported by Bedrock, skipping content block") + logger.warning("Non s3 location sources are not supported by Bedrock | skipping content block") return None def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any] | None: diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 192a363d3..6a6535999 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -18,7 +18,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec -from ._validation import validate_config_keys +from ._validation import _has_location_source, validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -229,15 +229,24 @@ def _format_request_content(self, messages: Messages) -> list[genai.types.Conten # available in tool result blocks, hence the mapping. 
tool_use_id_to_name: dict[str, str] = {} - return [ - genai.types.Content( - parts=[ - self._format_request_content_part(content, tool_use_id_to_name) for content in message["content"] - ], - role="user" if message["role"] == "user" else "model", + contents = [] + for message in messages: + parts = [] + for content in message["content"]: + # Check for location sources and skip with warning + if _has_location_source(content): + logger.warning("Location sources are not supported by Gemini | skipping content block") + continue + parts.append(self._format_request_content_part(content, tool_use_id_to_name)) + + contents.append( + genai.types.Content( + parts=parts, + role="user" if message["role"] == "user" else "model", + ) ) - for message in messages - ] + + return contents def _format_request_tools(self, tool_specs: list[ToolSpec] | None) -> list[genai.types.Tool | Any]: """Format tool specs into Gemini tools. diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index ce0367bf5..b1ed4563a 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -20,7 +20,7 @@ from ..types.exceptions import ModelThrottledException from ..types.streaming import StreamEvent, Usage from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse -from ._validation import validate_config_keys, warn_on_tool_choice_not_supported +from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported from .model import Model logger = logging.getLogger(__name__) @@ -176,12 +176,18 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None for message in messages: contents = message["content"] + # Filter out location sources and unsupported block types + filtered_contents = [] + for content in contents: + if any(block_type in content for block_type in ["toolResult", "toolUse"]): + continue + if _has_location_source(content): + logger.warning("Location sources are not supported by 
LlamaAPI | skipping content block") + continue + filtered_contents.append(content) + formatted_contents: list[dict[str, Any]] | dict[str, Any] | str = "" - formatted_contents = [ - self._format_request_message_content(content) - for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse"]) - ] + formatted_contents = [self._format_request_message_content(content) for content in filtered_contents] formatted_tool_calls = [ self._format_request_message_tool_call(content["toolUse"]) for content in contents diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index ca838f3d7..c52509816 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -30,7 +30,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec -from ._validation import validate_config_keys, warn_on_tool_choice_not_supported +from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported from .model import Model logger = logging.getLogger(__name__) @@ -299,11 +299,17 @@ def _format_messages(self, messages: Messages, system_prompt: str | None = None) for message in messages: contents = message["content"] - formatted_contents = [ - self._format_message_content(content) - for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse"]) - ] + # Filter out location sources and unsupported block types + filtered_contents = [] + for content in contents: + if any(block_type in content for block_type in ["toolResult", "toolUse"]): + continue + if _has_location_source(content): + logger.warning("Location sources are not supported by llama.cpp | skipping content block") + continue + filtered_contents.append(content) + + formatted_contents = [self._format_message_content(content) for content in filtered_contents] 
formatted_tool_calls = [ self._format_tool_call( { diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 4ec77ccfe..504e81c92 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -17,7 +17,7 @@ from ..types.exceptions import ModelThrottledException from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse -from ._validation import validate_config_keys, warn_on_tool_choice_not_supported +from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported from .model import Model logger = logging.getLogger(__name__) @@ -212,6 +212,11 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None tool_messages: list[dict[str, Any]] = [] for content in contents: + # Check for location sources and skip with warning + if _has_location_source(content): + logger.warning("Location sources are not supported by Mistral | skipping content block") + continue + if "text" in content: formatted_content = self._format_request_message_content(content) if isinstance(formatted_content, str): diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 8d72aa534..68aba59d4 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -15,7 +15,7 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolChoice, ToolSpec -from ._validation import validate_config_keys, warn_on_tool_choice_not_supported +from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported from .model import Model logger = logging.getLogger(__name__) @@ -160,12 +160,16 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None """ system_message = [{"role": "system", "content": system_prompt}] if system_prompt else [] - return system_message + [ - 
formatted_message - for message in messages - for content in message["content"] - for formatted_message in self._format_request_message_contents(message["role"], content) - ] + formatted_messages = [] + for message in messages: + for content in message["content"]: + # Check for location sources and skip with warning + if _has_location_source(content): + logger.warning("Location sources are not supported by Ollama | skipping content block") + continue + formatted_messages.extend(self._format_request_message_contents(message["role"], content)) + + return system_message + formatted_messages def format_request( self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index d9266212b..51e98c8c2 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -20,7 +20,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse -from ._validation import validate_config_keys +from ._validation import _has_location_source, validate_config_keys from .model import Model logger = logging.getLogger(__name__) @@ -338,11 +338,17 @@ def _format_regular_messages(cls, messages: Messages, **kwargs: Any) -> list[dic "reasoningContent is not supported in multi-turn conversations with the Chat Completions API." 
) - formatted_contents = [ - cls.format_request_message_content(content) - for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse", "reasoningContent"]) - ] + # Filter out content blocks that shouldn't be formatted + filtered_contents = [] + for content in contents: + if any(block_type in content for block_type in ["toolResult", "toolUse", "reasoningContent"]): + continue + if _has_location_source(content): + logger.warning("Location sources are not supported by OpenAI | skipping content block") + continue + filtered_contents.append(content) + + formatted_contents = [cls.format_request_message_content(content) for content in filtered_contents] formatted_tool_calls = [ cls.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content ] diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py index f306d649b..94774b363 100644 --- a/src/strands/models/writer.py +++ b/src/strands/models/writer.py @@ -18,7 +18,7 @@ from ..types.exceptions import ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse -from ._validation import validate_config_keys, warn_on_tool_choice_not_supported +from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported from .model import Model logger = logging.getLogger(__name__) @@ -218,11 +218,21 @@ def _format_request_messages(self, messages: Messages, system_prompt: str | None for message in messages: contents = message["content"] + # Filter out location sources + filtered_contents = [] + for content in contents: + if _has_location_source(content): + logger.warning("Location sources are not supported by Writer | skipping content block") + continue + filtered_contents.append(content) + # Only palmyra V5 support multiple content. 
Other models support only '{"content": "text_content"}' if self.get_config().get("model_id", "") == "palmyra-x5": - formatted_contents: str | list[dict[str, Any]] = self._format_request_message_contents_vision(contents) + formatted_contents: str | list[dict[str, Any]] = self._format_request_message_contents_vision( + filtered_contents + ) else: - formatted_contents = self._format_request_message_contents(contents) + formatted_contents = self._format_request_message_contents(filtered_contents) formatted_tool_calls = [ self._format_request_message_tool_call(content["toolUse"]) diff --git a/tests/strands/models/test__validation.py b/tests/strands/models/test__validation.py new file mode 100644 index 000000000..e8a451494 --- /dev/null +++ b/tests/strands/models/test__validation.py @@ -0,0 +1,67 @@ +"""Tests for model validation helper functions.""" + +from strands.models._validation import _has_location_source + + +class TestHasLocationSource: + """Tests for _has_location_source helper function.""" + + def test_image_with_location_source(self): + """Test detection of location source in image content.""" + content = {"image": {"source": {"location": {"type": "s3", "uri": "s3://bucket/key"}}}} + assert _has_location_source(content) + + def test_image_with_bytes_source(self): + """Test that bytes source is not detected as location.""" + content = {"image": {"source": {"bytes": b"data"}}} + assert not _has_location_source(content) + + def test_document_with_location_source(self): + """Test detection of location source in document content.""" + content = {"document": {"source": {"location": {"type": "s3", "uri": "s3://bucket/key"}}}} + assert _has_location_source(content) + + def test_document_with_bytes_source(self): + """Test that bytes source is not detected as location.""" + content = {"document": {"source": {"bytes": b"data"}}} + assert not _has_location_source(content) + + def test_video_with_location_source(self): + """Test detection of location source in video 
content.""" + content = {"video": {"source": {"location": {"type": "s3", "uri": "s3://bucket/key"}}}} + assert _has_location_source(content) + + def test_video_with_bytes_source(self): + """Test that bytes source is not detected as location.""" + content = {"video": {"source": {"bytes": b"data"}}} + assert not _has_location_source(content) + + def test_text_content(self): + """Test that text content is not detected as location source.""" + content = {"text": "hello"} + assert not _has_location_source(content) + + def test_tool_use_content(self): + """Test that toolUse content is not detected as location source.""" + content = {"toolUse": {"name": "test", "input": {}, "toolUseId": "123"}} + assert not _has_location_source(content) + + def test_tool_result_content(self): + """Test that toolResult content is not detected as location source.""" + content = {"toolResult": {"toolUseId": "123", "content": [{"text": "result"}]}} + assert not _has_location_source(content) + + def test_image_without_source(self): + """Test that image without source is not detected as location.""" + content = {"image": {"format": "png"}} + assert not _has_location_source(content) + + def test_document_without_source(self): + """Test that document without source is not detected as location.""" + content = {"document": {"format": "pdf", "name": "test.pdf"}} + assert not _has_location_source(content) + + def test_video_without_source(self): + """Test that video without source is not detected as location.""" + content = {"video": {"format": "mp4"}} + assert not _has_location_source(content) diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 74bbb8d45..c5aff8062 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -1,3 +1,4 @@ +import logging import unittest.mock import anthropic @@ -866,3 +867,69 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings): model.format_request(messages, 
tool_choice=None) assert len(captured_warnings) == 0 + + +def test_format_request_filters_s3_source_image(model, model_id, max_tokens, caplog): + """Test that images with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.anthropic") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + tru_request = model.format_request(messages) + + # Image with S3 source should be filtered, text should remain + exp_messages = [ + {"role": "user", "content": [{"type": "text", "text": "look at this image"}]}, + ] + assert tru_request["messages"] == exp_messages + assert "Location sources are not supported by Anthropic" in caplog.text + + +def test_format_request_filters_location_source_document(model, model_id, max_tokens, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.anthropic") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + tru_request = model.format_request(messages) + + # Document with S3 source should be filtered, text should remain + exp_messages = [ + {"role": "user", "content": [{"type": "text", "text": "analyze this document"}]}, + ] + assert tru_request["messages"] == exp_messages + assert "Location sources are not supported by Anthropic" in caplog.text diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 761434258..aac791214 100644 --- 
a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1924,7 +1924,7 @@ def test_format_request_unsupported_location(model, caplog): formatted_request = model._format_request(messages) assert len(formatted_request["messages"][0]["content"]) == 1 - assert "Non s3 location sources are not supported by Bedrock, skipping content block" in caplog.text + assert "Non s3 location sources are not supported by Bedrock | skipping content block" in caplog.text def test_format_request_video_s3_location(model, model_id): diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index 86ab2fea5..d62c5a7c8 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -934,3 +934,67 @@ def test_init_with_both_client_and_client_args_raises_error(): with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"): GeminiModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model") + + +def test_format_request_filters_s3_source_image(model, caplog): + """Test that images with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.gemini") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + request = model._format_request(messages, None, None, None) + + # Image with S3 source should be filtered, text should remain + formatted_content = request["contents"][0]["parts"] + assert len(formatted_content) == 1 + assert "text" in formatted_content[0] + assert "Location sources are not supported by Gemini" in caplog.text + + +def test_format_request_filters_location_source_document(model, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, 
logger="strands.models.gemini") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + request = model._format_request(messages, None, None, None) + + # Document with S3 source should be filtered, text should remain + formatted_content = request["contents"][0]["parts"] + assert len(formatted_content) == 1 + assert "text" in formatted_content[0] + assert "Location sources are not supported by Gemini" in caplog.text diff --git a/tests/strands/models/test_llamaapi.py b/tests/strands/models/test_llamaapi.py index a6bbf5673..2bf12d055 100644 --- a/tests/strands/models/test_llamaapi.py +++ b/tests/strands/models/test_llamaapi.py @@ -1,4 +1,5 @@ # Copyright (c) Meta Platforms, Inc. 
and affiliates +import logging import unittest.mock import pytest @@ -414,3 +415,69 @@ async def test_tool_choice_none_no_warning(model, messages, captured_warnings, a await alist(response) assert len(captured_warnings) == 0 + + +def test_format_request_filters_s3_source_image(model, caplog): + """Test that images with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.llamaapi") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Image with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_content = formatted_messages[0]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Location sources are not supported by LlamaAPI" in caplog.text + + +def test_format_request_filters_location_source_document(model, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.llamaapi") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Document with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_content = formatted_messages[0]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Location sources are not supported by LlamaAPI" in 
caplog.text diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py index e5b2614c0..fa784de5c 100644 --- a/tests/strands/models/test_llamacpp.py +++ b/tests/strands/models/test_llamacpp.py @@ -2,6 +2,7 @@ import base64 import json +import logging from unittest.mock import AsyncMock, patch import httpx @@ -637,3 +638,71 @@ def test_format_messages_with_mixed_content() -> None: assert result[0]["content"][2]["type"] == "image_url" assert "image_url" in result[0]["content"][2] assert result[0]["content"][2]["image_url"]["url"].startswith("data:image/jpeg;base64,") + + +def test_format_request_filters_s3_source_image(caplog) -> None: + """Test that images with Location sources are filtered out with warning.""" + model = LlamaCppModel() + caplog.set_level(logging.WARNING, logger="strands.models.llamacpp") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + request = model._format_request(messages) + + # Image with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_content = formatted_messages[0]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Location sources are not supported by llama.cpp" in caplog.text + + +def test_format_request_filters_location_source_document(caplog) -> None: + """Test that documents with Location sources are filtered out with warning.""" + model = LlamaCppModel() + caplog.set_level(logging.WARNING, logger="strands.models.llamacpp") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": 
{"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + request = model._format_request(messages) + + # Document with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_content = formatted_messages[0]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Location sources are not supported by llama.cpp" in caplog.text diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py index 7808336f2..ad74bae89 100644 --- a/tests/strands/models/test_mistral.py +++ b/tests/strands/models/test_mistral.py @@ -1,3 +1,4 @@ +import logging import unittest.mock import pydantic @@ -592,3 +593,65 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +def test_format_request_filters_s3_source_image(model, caplog): + """Test that images with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.mistral") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + formatted_messages = model._format_request_messages(messages) + + # Image with S3 source should be filtered, text should remain + user_content = formatted_messages[0]["content"] + assert user_content == "look at this image" + assert "Location sources are not supported by Mistral" in caplog.text + + +def test_format_request_filters_location_source_document(model, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.mistral") + + messages = [ + { + "role": 
"user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + formatted_messages = model._format_request_messages(messages) + + # Document with S3 source should be filtered, text should remain + user_content = formatted_messages[0]["content"] + assert user_content == "analyze this document" + assert "Location sources are not supported by Mistral" in caplog.text diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py index 14db63a24..d17894028 100644 --- a/tests/strands/models/test_ollama.py +++ b/tests/strands/models/test_ollama.py @@ -1,4 +1,5 @@ import json +import logging import unittest.mock import pydantic @@ -559,3 +560,68 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +def test_format_request_filters_s3_source_image(model, caplog): + """Test that images with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.ollama") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Image with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_message = formatted_messages[0] + assert user_message["content"] == "look at this image" + assert "images" not in user_message or user_message.get("images") 
== [] + assert "Location sources are not supported by Ollama" in caplog.text + + +def test_format_request_filters_location_source_document(model, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.ollama") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Document with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_message = formatted_messages[0] + assert user_message["content"] == "analyze this document" + assert "Location sources are not supported by Ollama" in caplog.text diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 7c1d18998..6eeb477d9 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -1,3 +1,4 @@ +import logging import unittest.mock import openai @@ -1246,3 +1247,67 @@ def test_init_with_both_client_and_client_args_raises_error(): with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"): OpenAIModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model") + + +def test_format_request_filters_s3_source_image(model, caplog): + """Test that images with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.openai") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + 
}, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Image with S3 source should be filtered, text should remain + formatted_content = request["messages"][0]["content"] + assert len(formatted_content) == 1 + assert formatted_content[0]["type"] == "text" + assert "Location sources are not supported by OpenAI" in caplog.text + + +def test_format_request_filters_location_source_document(model, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.openai") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Document with S3 source should be filtered, text should remain + formatted_content = request["messages"][0]["content"] + assert len(formatted_content) == 1 + assert formatted_content[0]["type"] == "text" + assert "Location sources are not supported by OpenAI" in caplog.text diff --git a/tests/strands/models/test_writer.py b/tests/strands/models/test_writer.py index 963904002..81745f412 100644 --- a/tests/strands/models/test_writer.py +++ b/tests/strands/models/test_writer.py @@ -1,3 +1,4 @@ +import logging import unittest.mock from typing import Any @@ -435,3 +436,69 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +def test_format_request_filters_s3_source_image(model, caplog): + """Test that images with Location sources are filtered out with 
warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.writer") + + messages = [ + { + "role": "user", + "content": [ + {"text": "look at this image"}, + { + "image": { + "format": "png", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/image.png"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Image with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_content = formatted_messages[0]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Location sources are not supported by Writer" in caplog.text + + +def test_format_request_filters_location_source_document(model, caplog): + """Test that documents with Location sources are filtered out with warning.""" + caplog.set_level(logging.WARNING, logger="strands.models.writer") + + messages = [ + { + "role": "user", + "content": [ + {"text": "analyze this document"}, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + { + "document": { + "format": "pdf", + "name": "report.pdf", + "source": {"location": {"type": "s3", "uri": "s3://my-bucket/report.pdf"}}, + }, + }, + ], + }, + ] + + request = model.format_request(messages) + + # Document with S3 source should be filtered, text should remain + formatted_messages = request["messages"] + user_content = formatted_messages[0]["content"] + assert len(user_content) == 1 + assert user_content[0]["type"] == "text" + assert "Location sources are not supported by Writer" in caplog.text From e0171cf6f8696398b4de7bc97e5f946a86d42604 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Mon, 2 Feb 2026 12:41:18 -0500 Subject: [PATCH 106/279] Add conditional execution for finalize step (#1605) --- .github/workflows/strands-command.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/strands-command.yml 
b/.github/workflows/strands-command.yml index 6c3328192..6cd43c5c0 100644 --- a/.github/workflows/strands-command.yml +++ b/.github/workflows/strands-command.yml @@ -79,6 +79,7 @@ jobs: write_permission: 'false' finalize: + if: always() needs: [setup-and-process, execute-readonly-agent] permissions: contents: write From 005737906529f597225c74551f696e06ae368f6e Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 2 Feb 2026 16:51:18 -0500 Subject: [PATCH 107/279] interrupts - graph - multiagent nodes (#1606) --- src/strands/multiagent/graph.py | 70 +++++++------ tests/strands/multiagent/test_graph.py | 97 ++++++++++++++++++- .../{test_agent.py => test_node.py} | 12 ++- .../interrupts/multiagent/test_session.py | 28 ++++-- 4 files changed, 166 insertions(+), 41 deletions(-) rename tests_integ/interrupts/multiagent/{test_agent.py => test_node.py} (93%) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index d296753c0..6b135d1a7 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -603,17 +603,20 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: # Validate Agent-specific constraints for each node _validate_node_executor(node.executor) - def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> MultiAgentNodeInterruptEvent: + def _activate_interrupt( + self, node: GraphNode, interrupts: list[Interrupt], from_hook: bool = False + ) -> MultiAgentNodeInterruptEvent: """Activate the interrupt state. Args: node: The interrupted node. interrupts: The interrupts raised by the user. + from_hook: Whether the interrupt originated from a hook (e.g., BeforeNodeCallEvent). 
Returns: MultiAgentNodeInterruptEvent """ - logger.debug("node=<%s> | node interrupted", node.node_id) + logger.debug("node=<%s>, from_hook=<%s> | node interrupted", node.node_id, from_hook) node.execution_status = Status.INTERRUPTED @@ -622,13 +625,20 @@ def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> M self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts}) self._interrupt_state.activate() + + self._interrupt_state.context[node.node_id] = { + "from_hook": from_hook, + "interrupt_ids": [interrupt.id for interrupt in interrupts], + } + if isinstance(node.executor, Agent): - self._interrupt_state.context[node.node_id] = { - "activated": node.executor._interrupt_state.activated, - "interrupt_state": node.executor._interrupt_state.to_dict(), - "state": node.executor.state.get(), - "messages": node.executor.messages, - } + self._interrupt_state.context[node.node_id].update( + { + "interrupt_state": node.executor._interrupt_state.to_dict(), + "state": node.executor.state.get(), + "messages": node.executor.messages, + } + ) return MultiAgentNodeInterruptEvent(node.node_id, interrupts) @@ -866,7 +876,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) start_time = time.time() try: if interrupts: - yield self._activate_interrupt(node, interrupts) + yield self._activate_interrupt(node, interrupts, from_hook=True) return if before_event.cancel_node: @@ -896,20 +906,14 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) if multi_agent_result is None: raise ValueError(f"Node '{node.node_id}' did not produce a result event") - if multi_agent_result.status == Status.INTERRUPTED: - raise NotImplementedError( - f"node_id=<{node.node_id}>, " - "issue= " - "| user raised interrupt from a multi agent node" - ) - node_result = NodeResult( result=multi_agent_result, execution_time=multi_agent_result.execution_time, - status=Status.COMPLETED, + 
status=multi_agent_result.status, accumulated_usage=multi_agent_result.accumulated_usage, accumulated_metrics=multi_agent_result.accumulated_metrics, execution_count=multi_agent_result.execution_count, + interrupts=multi_agent_result.interrupts, ) elif isinstance(node.executor, Agent): @@ -1040,18 +1044,26 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: """ if self._interrupt_state.activated: context = self._interrupt_state.context - if node.node_id in context and context[node.node_id]["activated"]: - agent_context = context[node.node_id] - agent = cast(Agent, node.executor) - agent.messages = agent_context["messages"] - agent.state = AgentState(agent_context["state"]) - agent._interrupt_state = _InterruptState.from_dict(agent_context["interrupt_state"]) - - responses = context["responses"] - interrupts = agent._interrupt_state.interrupts - return [ - response for response in responses if response["interruptResponse"]["interruptId"] in interrupts - ] + if node.node_id in context: + node_context = context[node.node_id] + + # Only route responses if the interrupt originated from the node's execution + if not node_context["from_hook"]: + # Filter responses to only those for this node's interrupts + node_responses = [ + response + for response in context["responses"] + if response["interruptResponse"]["interruptId"] in node_context["interrupt_ids"] + ] + + if isinstance(node.executor, MultiAgentBase): + return node_responses + + agent = node.executor + agent.messages = node_context["messages"] + agent.state = AgentState(node_context["state"]) + agent._interrupt_state = _InterruptState.from_dict(node_context["interrupt_state"]) + return node_responses # Get satisfied dependencies dependency_results = {} diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index c511328d4..0fbb102a4 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -2228,7 +2228,8 @@ def 
test_graph_interrupt_on_agent(agenerator): ], ) graph._interrupt_state.context["test_agent"] = { - "activated": True, + "from_hook": False, + "interrupt_ids": [interrupt.id], "interrupt_state": { "activated": True, "context": {}, @@ -2259,3 +2260,97 @@ def test_graph_interrupt_on_agent(agenerator): assert len(multiagent_result.results) == 1 agent.stream_async.assert_called_once_with(responses, invocation_state={}) + + +def test_graph_interrupt_on_multiagent(agenerator): + exp_interrupts = [ + Interrupt( + id="test_id", + name="test_name", + reason="test_reason", + ) + ] + + multiagent = create_mock_multi_agent("test_multiagent", "Multi-agent completed") + multiagent.stream_async = Mock() + multiagent.stream_async.return_value = agenerator( + [ + { + "result": MultiAgentResult( + results={}, + status=Status.INTERRUPTED, + interrupts=exp_interrupts, + ), + }, + ], + ) + + builder = GraphBuilder() + builder.add_node(multiagent, "test_multiagent") + graph = builder.build() + + multiagent_result = graph("Test task") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_node_ids = [node.node_id for node in graph.state.interrupted_nodes] + exp_node_ids = ["test_multiagent"] + assert tru_node_ids == exp_node_ids + + tru_interrupts = multiagent_result.interrupts + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + multiagent.stream_async = Mock() + multiagent.stream_async.return_value = agenerator( + [ + { + "result": MultiAgentResult( + results={ + "inner_node": NodeResult( + result=AgentResult( + message={"role": "assistant", "content": [{"text": "Inner completed"}]}, + stop_reason="end_turn", + state={}, + metrics={}, + ) + ) + }, + status=Status.COMPLETED, + ), + }, + ], + ) + 
graph._interrupt_state.context["test_multiagent"] = { + "from_hook": False, + "interrupt_ids": [interrupt.id], + } + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "test_response", + }, + }, + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 1 + + multiagent.stream_async.assert_called_once_with(responses, {}) diff --git a/tests_integ/interrupts/multiagent/test_agent.py b/tests_integ/interrupts/multiagent/test_node.py similarity index 93% rename from tests_integ/interrupts/multiagent/test_agent.py rename to tests_integ/interrupts/multiagent/test_node.py index 1a6ad87c6..23e7a62bc 100644 --- a/tests_integ/interrupts/multiagent/test_agent.py +++ b/tests_integ/interrupts/multiagent/test_node.py @@ -65,13 +65,13 @@ def swarm(weather_agent): @pytest.fixture -def graph(info_agent, day_agent, time_agent, weather_agent): +def graph(info_agent, day_agent, time_agent, swarm): builder = GraphBuilder() builder.add_node(info_agent, "info") builder.add_node(day_agent, "day") builder.add_node(time_agent, "time") - builder.add_node(weather_agent, "weather") + builder.add_node(swarm, "weather") builder.add_edge("info", "day") builder.add_edge("info", "time") @@ -82,7 +82,7 @@ def graph(info_agent, day_agent, time_agent, weather_agent): return builder.build() -def test_swarm_interrupt_agent(swarm): +def test_swarm_interrupt_node(swarm): multiagent_result = swarm("What is the weather?") tru_status = multiagent_result.status @@ -122,7 +122,7 @@ def test_swarm_interrupt_agent(swarm): assert "sunny" in weather_message -def test_graph_interrupt_agent(graph): +def test_graph_interrupt_node(graph): multiagent_result = graph("What is the day, time, and 
weather?") tru_result_status = multiagent_result.status @@ -180,7 +180,9 @@ def test_graph_interrupt_agent(graph): day_message = json.dumps(multiagent_result.results["day"].result.message).lower() time_message = json.dumps(multiagent_result.results["time"].result.message).lower() - weather_message = json.dumps(multiagent_result.results["weather"].result.message).lower() assert "monday" in day_message assert "12:01" in time_message + + nested_multiagent_result = multiagent_result.results["weather"].result + weather_message = json.dumps(nested_multiagent_result.results["weather"].result.message).lower() assert "sunny" in weather_message diff --git a/tests_integ/interrupts/multiagent/test_session.py b/tests_integ/interrupts/multiagent/test_session.py index 96b9844bf..8a5979d63 100644 --- a/tests_integ/interrupts/multiagent/test_session.py +++ b/tests_integ/interrupts/multiagent/test_session.py @@ -72,15 +72,23 @@ def test_swarm_interrupt_session(weather_tool, tmpdir): def test_graph_interrupt_session(weather_tool, tmpdir): + parent_sm = FileSessionManager(session_id="parent-session", storage_dir=tmpdir / "parent") + child_sm = FileSessionManager(session_id="child-session", storage_dir=tmpdir / "child") + weather_agent = Agent(name="weather", tools=[weather_tool]) summarizer_agent = Agent(name="summarizer") - session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + + weather_builder = GraphBuilder() + weather_builder.add_node(weather_agent, "weather") + weather_builder.set_entry_point("weather") + weather_builder.set_session_manager(child_sm) + weather_graph = weather_builder.build() builder = GraphBuilder() - builder.add_node(weather_agent, "weather") + builder.add_node(weather_graph, "weather") builder.add_node(summarizer_agent, "summarizer") builder.add_edge("weather", "summarizer") - builder.set_session_manager(session_manager) + builder.set_session_manager(parent_sm) graph = builder.build() multiagent_result = graph("Can you 
check the weather and then summarize the results?") @@ -105,15 +113,23 @@ def test_graph_interrupt_session(weather_tool, tmpdir): interrupt = multiagent_result.interrupts[0] + parent_sm = FileSessionManager(session_id="parent-session", storage_dir=tmpdir / "parent") + child_sm = FileSessionManager(session_id="child-session", storage_dir=tmpdir / "child") + weather_agent = Agent(name="weather", tools=[weather_tool]) summarizer_agent = Agent(name="summarizer") - session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + + weather_builder = GraphBuilder() + weather_builder.add_node(weather_agent, "weather") + weather_builder.set_entry_point("weather") + weather_builder.set_session_manager(child_sm) + weather_graph = weather_builder.build() builder = GraphBuilder() - builder.add_node(weather_agent, "weather") + builder.add_node(weather_graph, "weather") builder.add_node(summarizer_agent, "summarizer") builder.add_edge("weather", "summarizer") - builder.set_session_manager(session_manager) + builder.set_session_manager(parent_sm) graph = builder.build() responses = [ From 51567e6698c0aa9d001b6800b8f3bb83179aa953 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Mon, 2 Feb 2026 16:58:30 -0500 Subject: [PATCH 108/279] fix various test warnings (#1613) --- .../strands/agent/hooks/test_agent_events.py | 7 +- tests/strands/agent/test_agent_hooks.py | 3 + tests/strands/event_loop/test_streaming.py | 1 + .../experimental/hooks/test_hook_aliases.py | 16 ++-- .../tools/test_tool_provider_alias.py | 6 +- tests/strands/models/test_bedrock.py | 87 +++++++++---------- tests/strands/models/test_llamacpp.py | 4 +- tests/strands/multiagent/a2a/test_executor.py | 3 + tests/strands/multiagent/test_base.py | 1 + tests/strands/tools/test_loader.py | 8 ++ tests/strands/tools/test_registry.py | 1 + 11 files changed, 81 insertions(+), 56 deletions(-) diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py 
index f511c7019..02c367ccc 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -84,7 +84,7 @@ async def test_stream_e2e_success(alist): mock_callback = unittest.mock.Mock() agent = Agent(model=mock_provider, tools=[async_tool, normal_tool, streaming_tool], callback_handler=mock_callback) - stream = agent.stream_async("Do the stuff", arg1=1013) + stream = agent.stream_async("Do the stuff", invocation_state={"arg1": 1013}) tool_config = { "toolChoice": {"auto": {}}, @@ -344,7 +344,7 @@ async def test_stream_e2e_throttle_and_redact(alist, mock_sleep): mock_callback = unittest.mock.Mock() agent = Agent(model=model, tools=[normal_tool], callback_handler=mock_callback) - stream = agent.stream_async("Do the stuff", arg1=1013) + stream = agent.stream_async("Do the stuff", invocation_state={"arg1": 1013}) # Base object with common properties throttle_props = { @@ -492,7 +492,7 @@ async def test_event_loop_cycle_text_response_throttling_early_end( # Because we're throwing an exception, we manually collect the items here tru_events = [] - stream = agent.stream_async("Do the stuff", arg1=1013) + stream = agent.stream_async("Do the stuff", invocation_state={"arg1": 1013}) async for event in stream: tru_events.append(event) @@ -525,6 +525,7 @@ async def test_event_loop_cycle_text_response_throttling_early_end( assert typed_events == [] +@pytest.mark.filterwarnings("ignore:Agent.structured_output_async method is deprecated:DeprecationWarning") @pytest.mark.asyncio async def test_structured_output(agenerator): # we use bedrock here as it uses the tool implementation diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 8ff81295a..4397b9628 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -284,6 +284,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m assert len(agent.messages) == 4 
+@pytest.mark.filterwarnings("ignore:Agent.structured_output method is deprecated:DeprecationWarning") def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator): """Verify that the correct hook events are emitted as part of structured_output.""" @@ -300,6 +301,7 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator): assert len(agent.messages) == 0 # no new messages added +@pytest.mark.filterwarnings("ignore:Agent.structured_output_async method is deprecated:DeprecationWarning") @pytest.mark.asyncio async def test_agent_structured_async_output_hooks(agent, hook_provider, user, agenerator): """Verify that the correct hook events are emitted as part of structured_output_async.""" @@ -667,6 +669,7 @@ async def overwrite_input_hook(event: BeforeInvocationEvent): assert agent.messages[0]["content"][0]["text"] == "GOODBYE" +@pytest.mark.filterwarnings("ignore:Agent.structured_output_async method is deprecated:DeprecationWarning") @pytest.mark.asyncio async def test_before_invocation_event_messages_none_in_structured_output(agenerator): """Test that BeforeInvocationEvent.messages is None when called from deprecated structured_output.""" diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index b2cc152cb..0fe04f4b2 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -48,6 +48,7 @@ def moto_autouse(moto_env, moto_mock_aws): ), ], ) +@pytest.mark.filterwarnings("ignore:remove_blank_messages_content_text is deprecated:DeprecationWarning") def test_remove_blank_messages_content_text(messages, exp_result): tru_result = strands.event_loop.streaming.remove_blank_messages_content_text(messages) diff --git a/tests/strands/experimental/hooks/test_hook_aliases.py b/tests/strands/experimental/hooks/test_hook_aliases.py index b229c1c2d..ed7adba8a 100644 --- a/tests/strands/experimental/hooks/test_hook_aliases.py +++ 
b/tests/strands/experimental/hooks/test_hook_aliases.py @@ -7,16 +7,20 @@ import importlib import sys +import warnings from unittest.mock import Mock import pytest -from strands.experimental.hooks import ( - AfterModelInvocationEvent, - AfterToolInvocationEvent, - BeforeModelInvocationEvent, - BeforeToolInvocationEvent, -) +# Suppress deprecation warnings from imports since we're testing the aliases themselves +with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + from strands.experimental.hooks import ( + AfterModelInvocationEvent, + AfterToolInvocationEvent, + BeforeModelInvocationEvent, + BeforeToolInvocationEvent, + ) from strands.hooks import ( AfterModelCallEvent, AfterToolCallEvent, diff --git a/tests/strands/experimental/tools/test_tool_provider_alias.py b/tests/strands/experimental/tools/test_tool_provider_alias.py index 58a2b9e20..3b3055bc6 100644 --- a/tests/strands/experimental/tools/test_tool_provider_alias.py +++ b/tests/strands/experimental/tools/test_tool_provider_alias.py @@ -6,6 +6,7 @@ """ import sys +import warnings import pytest @@ -14,7 +15,10 @@ def test_experimental_alias_is_same_type(): """Verify that experimental ToolProvider alias is identical to the actual type.""" - from strands.experimental.tools import ToolProvider as ExperimentalToolProvider + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + from strands.experimental.tools import ToolProvider as ExperimentalToolProvider assert ExperimentalToolProvider is ToolProvider diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index aac791214..1410e129b 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -31,7 +31,9 @@ def session_cls(): # Mock the creation of a Session so that we don't depend on environment variables or profiles with unittest.mock.patch.object(strands.models.bedrock.boto3, "Session") as mock_session_cls: - 
mock_session_cls.return_value.region_name = None + mock_session = unittest.mock.Mock() + mock_session.region_name = None + mock_session_cls.return_value = mock_session yield mock_session_cls @@ -216,66 +218,63 @@ def test__init__with_region_and_session_raises_value_error(): _ = BedrockModel(region_name="us-east-1", boto_session=boto3.Session(region_name="us-east-1")) -def test__init__default_user_agent(bedrock_client): +def test__init__default_user_agent(session_cls, bedrock_client): """Set user agent when no boto_client_config is provided.""" - with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: - mock_session = mock_session_cls.return_value - _ = BedrockModel() + _ = BedrockModel() - # Verify the client was created with the correct config - mock_session.client.assert_called_once() - args, kwargs = mock_session.client.call_args - assert kwargs["service_name"] == "bedrock-runtime" - assert isinstance(kwargs["config"], BotocoreConfig) - assert kwargs["config"].user_agent_extra == "strands-agents" - assert kwargs["config"].read_timeout == DEFAULT_READ_TIMEOUT + # Verify the client was created with the correct config + client = session_cls.return_value.client + client.assert_called_once() + args, kwargs = client.call_args + assert kwargs["service_name"] == "bedrock-runtime" + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].user_agent_extra == "strands-agents" + assert kwargs["config"].read_timeout == DEFAULT_READ_TIMEOUT -def test__init__default_read_timeout(bedrock_client): +def test__init__default_read_timeout(session_cls, bedrock_client): """Set default read timeout when no boto_client_config is provided.""" - with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: - mock_session = mock_session_cls.return_value - _ = BedrockModel() - # Verify the client was created with the correct read timeout - mock_session.client.assert_called_once() - args, kwargs = 
mock_session.client.call_args - assert isinstance(kwargs["config"], BotocoreConfig) - assert kwargs["config"].read_timeout == DEFAULT_READ_TIMEOUT + _ = BedrockModel() + # Verify the client was created with the correct read timeout + client = session_cls.return_value.client + client.assert_called_once() + args, kwargs = client.call_args + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].read_timeout == DEFAULT_READ_TIMEOUT -def test__init__with_custom_boto_client_config_no_user_agent(bedrock_client): + +def test__init__with_custom_boto_client_config_no_user_agent(session_cls, bedrock_client): """Set user agent when boto_client_config is provided without user_agent_extra.""" custom_config = BotocoreConfig(read_timeout=900) - with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: - mock_session = mock_session_cls.return_value - _ = BedrockModel(boto_client_config=custom_config) + _ = BedrockModel(boto_client_config=custom_config) - # Verify the client was created with the correct config - mock_session.client.assert_called_once() - args, kwargs = mock_session.client.call_args - assert kwargs["service_name"] == "bedrock-runtime" - assert isinstance(kwargs["config"], BotocoreConfig) - assert kwargs["config"].user_agent_extra == "strands-agents" - assert kwargs["config"].read_timeout == 900 + # Verify the client was created with the correct config + client = session_cls.return_value.client + client.assert_called_once() + args, kwargs = client.call_args + assert kwargs["service_name"] == "bedrock-runtime" + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].user_agent_extra == "strands-agents" + assert kwargs["config"].read_timeout == 900 -def test__init__with_custom_boto_client_config_with_user_agent(bedrock_client): +def test__init__with_custom_boto_client_config_with_user_agent(session_cls, bedrock_client): """Append to existing user agent when boto_client_config is provided with 
user_agent_extra.""" custom_config = BotocoreConfig(user_agent_extra="existing-agent", read_timeout=900) - with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: - mock_session = mock_session_cls.return_value - _ = BedrockModel(boto_client_config=custom_config) - - # Verify the client was created with the correct config - mock_session.client.assert_called_once() - args, kwargs = mock_session.client.call_args - assert kwargs["service_name"] == "bedrock-runtime" - assert isinstance(kwargs["config"], BotocoreConfig) - assert kwargs["config"].user_agent_extra == "existing-agent strands-agents" - assert kwargs["config"].read_timeout == 900 + _ = BedrockModel(boto_client_config=custom_config) + + # Verify the client was created with the correct config + client = session_cls.return_value.client + client.assert_called_once() + args, kwargs = client.call_args + assert kwargs["service_name"] == "bedrock-runtime" + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].user_agent_extra == "existing-agent strands-agents" + assert kwargs["config"].read_timeout == 900 def test__init__model_config(bedrock_client): diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py index fa784de5c..3e023dfce 100644 --- a/tests/strands/models/test_llamacpp.py +++ b/tests/strands/models/test_llamacpp.py @@ -3,7 +3,7 @@ import base64 import json import logging -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest @@ -248,7 +248,7 @@ async def mock_aiter_lines(): mock_response = AsyncMock() mock_response.aiter_lines = mock_aiter_lines - mock_response.raise_for_status = AsyncMock() + mock_response.raise_for_status = MagicMock() with patch.object(model.client, "post", return_value=mock_response): messages = [{"role": "user", "content": [{"text": "Hi"}]}] diff --git a/tests/strands/multiagent/a2a/test_executor.py 
b/tests/strands/multiagent/a2a/test_executor.py index 73ade574e..bb039bdce 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -11,6 +11,9 @@ from strands.multiagent.a2a.executor import StrandsA2AExecutor from strands.types.content import ContentBlock +# Suppress A2A compliance warnings for legacy streaming mode tests +pytestmark = pytest.mark.filterwarnings("ignore:The default A2A response stream.*:UserWarning") + # Test data constants VALID_PNG_BYTES = b"fake_png_data" VALID_MP4_BYTES = b"fake_mp4_data" diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index 4e8a5dd06..2fb2cc617 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -156,6 +156,7 @@ def deserialize_state(self, payload: dict) -> None: assert isinstance(agent, MultiAgentBase) +@pytest.mark.filterwarnings("ignore:`\\*\\*kwargs` parameter is deprecating:UserWarning") def test_multi_agent_base_call_method(): """Test that __call__ method properly delegates to invoke_async.""" diff --git a/tests/strands/tools/test_loader.py b/tests/strands/tools/test_loader.py index 1c665b42a..121ebed2d 100644 --- a/tests/strands/tools/test_loader.py +++ b/tests/strands/tools/test_loader.py @@ -10,6 +10,14 @@ from strands.tools.loader import _TOOL_MODULE_PREFIX, ToolLoader, load_tools_from_file_path from strands.tools.tools import PythonAgentTool +# Suppress deprecation warnings for deprecated ToolLoader methods being tested +pytestmark = pytest.mark.filterwarnings( + "ignore:ToolLoader.load_python_tool is deprecated:DeprecationWarning", + "ignore:ToolLoader.load_python_tools is deprecated:DeprecationWarning", + "ignore:ToolLoader.load_tool is deprecated:DeprecationWarning", + "ignore:ToolLoader.load_tools is deprecated:DeprecationWarning", +) + @pytest.fixture def tool_path(request, tmp_path, monkeypatch): diff --git a/tests/strands/tools/test_registry.py 
b/tests/strands/tools/test_registry.py index ed96f2b6a..73141beb6 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -13,6 +13,7 @@ from strands.tools.registry import ToolRegistry +@pytest.mark.filterwarnings("ignore:load_tool_from_filepath is deprecated:DeprecationWarning") def test_load_tool_from_filepath_failure(): """Test error handling when load_tool fails.""" tool_registry = ToolRegistry() From 6c468ae2dcf34b07a8ae0db81602d728836862af Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Tue, 3 Feb 2026 10:06:04 -0500 Subject: [PATCH 109/279] Fix bedrock file warnings (#1603) From ea1ea1c557b7a52e19e0ab224a946e5ccbf72a26 Mon Sep 17 00:00:00 2001 From: afarntrog <47332252+afarntrog@users.noreply.github.com> Date: Tue, 3 Feb 2026 14:52:08 -0500 Subject: [PATCH 110/279] increase test timeout (#1623) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ba635cc48..048fab88f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -242,7 +242,7 @@ convention = "google" testpaths = ["tests"] asyncio_default_fixture_loop_scope = "function" addopts = "--ignore=tests/strands/experimental/bidi --ignore=tests_integ/bidi --junit-xml=build/test-results.xml" -timeout = 45 +timeout = 90 [tool.coverage.run] From 7db79bbeb53847006c1b6caad84dc6862e836477 Mon Sep 17 00:00:00 2001 From: dinindunz <31271518+dinindunz@users.noreply.github.com> Date: Wed, 4 Feb 2026 09:19:52 +1300 Subject: [PATCH 111/279] fix(openai): Handles Bedrock-style context overflow errors for OpenAI-compatible endpoints (#1529) --- src/strands/models/openai.py | 25 +++++++++ tests/strands/models/test_openai.py | 86 +++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+) diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 51e98c8c2..ab421e6c7 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -27,6 +27,15 @@ T = TypeVar("T", bound=BaseModel) 
+# Alternative context overflow error messages +# These are commonly returned by OpenAI-compatible endpoints wrapping other providers +# (e.g., Databricks serving Bedrock models) +_CONTEXT_OVERFLOW_MESSAGES = [ + "Input is too long for requested model", + "input length and `max_tokens` exceed context limit", + "too many total text bytes", +] + class Client(Protocol): """Protocol defining the OpenAI-compatible interface for the underlying provider client.""" @@ -600,6 +609,14 @@ async def stream( # Rate limits (including TPM) require waiting/retrying, not context reduction logger.warning("OpenAI threw rate limit error") raise ModelThrottledException(str(e)) from e + except openai.APIError as e: + # Check for alternative context overflow error messages + error_message = str(e) + if any(overflow_msg in error_message for overflow_msg in _CONTEXT_OVERFLOW_MESSAGES): + logger.warning("context window overflow error detected") + raise ContextWindowOverflowException(error_message) from e + # Re-raise other APIError exceptions + raise logger.debug("got response from model") yield self.format_chunk({"chunk_type": "message_start"}) @@ -723,6 +740,14 @@ async def structured_output( # Rate limits (including TPM) require waiting/retrying, not context reduction logger.warning("OpenAI threw rate limit error") raise ModelThrottledException(str(e)) from e + except openai.APIError as e: + # Check for alternative context overflow error messages + error_message = str(e) + if any(overflow_msg in error_message for overflow_msg in _CONTEXT_OVERFLOW_MESSAGES): + logger.warning("context window overflow error detected") + raise ContextWindowOverflowException(error_message) from e + # Re-raise other APIError exceptions + raise parsed: T | None = None # Find the first choice with tool_calls diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 6eeb477d9..4f8652632 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ 
-1035,6 +1035,92 @@ async def test_stream_context_overflow_exception(openai_client, model, messages) assert exc_info.value.__cause__ == mock_error +@pytest.mark.asyncio +@pytest.mark.parametrize( + "error_message", + [ + "Input is too long for requested model", + "input length and `max_tokens` exceed context limit", + "too many total text bytes", + ], +) +async def test_stream_alternative_context_overflow_messages(openai_client, model, messages, error_message): + """Test that alternative context overflow messages in APIError are properly converted.""" + # Create a mock OpenAI APIError with alternative context overflow message + mock_error = openai.APIError( + message=error_message, + request=unittest.mock.MagicMock(), + body={"error": {"message": error_message}}, + ) + + # Configure the mock client to raise the APIError + openai_client.chat.completions.create.side_effect = mock_error + + # Test that the stream method converts the error properly + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.stream(messages): + pass + + # Verify the exception message contains the original error + assert error_message in str(exc_info.value) + assert exc_info.value.__cause__ == mock_error + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "error_message", + [ + "Input is too long for requested model", + "input length and `max_tokens` exceed context limit", + "too many total text bytes", + ], +) +async def test_structured_output_alternative_context_overflow_messages( + openai_client, model, messages, test_output_model_cls, error_message +): + """Test that alternative context overflow messages in APIError are properly converted in structured output.""" + # Create a mock OpenAI APIError with alternative context overflow message + mock_error = openai.APIError( + message=error_message, + request=unittest.mock.MagicMock(), + body={"error": {"message": error_message}}, + ) + + # Configure the mock client to raise the APIError + 
openai_client.beta.chat.completions.parse.side_effect = mock_error + + # Test that the structured_output method converts the error properly + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.structured_output(test_output_model_cls, messages): + pass + + # Verify the exception message contains the original error + assert error_message in str(exc_info.value) + assert exc_info.value.__cause__ == mock_error + + +@pytest.mark.asyncio +async def test_stream_api_error_passthrough(openai_client, model, messages): + """Test that APIError without overflow messages passes through unchanged.""" + # Create a mock OpenAI APIError without overflow message + mock_error = openai.APIError( + message="Some other API error", + request=unittest.mock.MagicMock(), + body={"error": {"message": "Some other API error"}}, + ) + + # Configure the mock client to raise the APIError + openai_client.chat.completions.create.side_effect = mock_error + + # Test that APIError without overflow messages passes through + with pytest.raises(openai.APIError) as exc_info: + async for _ in model.stream(messages): + pass + + # Verify the original exception is raised, not ContextWindowOverflowException + assert exc_info.value == mock_error + + @pytest.mark.asyncio async def test_stream_other_bad_request_errors_passthrough(openai_client, model, messages): """Test that other BadRequestError exceptions are not converted to ContextWindowOverflowException.""" From 570689b2ea4277e67e11e900dec51c31c3544f92 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 4 Feb 2026 10:28:27 -0500 Subject: [PATCH 112/279] feat: make structured output prompt message configurable (#1288) (#1627) --- src/strands/agent/agent.py | 32 +++- src/strands/event_loop/event_loop.py | 2 +- .../tools/structured_output/__init__.py | 3 +- .../_structured_output_context.py | 11 +- .../agent/test_agent_structured_output.py | 157 ++++++++++++++++++ .../test_event_loop_structured_output.py | 56 ++++++- 
.../test_structured_output_context.py | 31 +++- 7 files changed, 283 insertions(+), 9 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 05c3af191..a76017e75 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -118,6 +118,7 @@ def __init__( state: AgentState | dict | None = None, hooks: list[HookProvider] | None = None, session_manager: SessionManager | None = None, + structured_output_prompt: str | None = None, tool_executor: ToolExecutor | None = None, retry_strategy: ModelRetryStrategy | None = None, ): @@ -168,6 +169,11 @@ def __init__( Defaults to None. session_manager: Manager for handling agent sessions including conversation history and state. If provided, enables session-based persistence and state management. + structured_output_prompt: Custom prompt message used when forcing structured output. + When using structured output, if the model doesn't automatically use the output tool, + the agent sends a follow-up message to request structured formatting. This parameter + allows customizing that message. + Defaults to "You must format the previous response as structured output." tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). retry_strategy: Strategy for retrying model calls on throttling or other transient errors. Defaults to ModelRetryStrategy with max_attempts=6, initial_delay=4s, max_delay=240s. 
@@ -181,6 +187,7 @@ def __init__( # initializing self._system_prompt for backwards compatibility self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(system_prompt) self._default_structured_output_model = structured_output_model + self._structured_output_prompt = structured_output_prompt self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) self.name = name or _DEFAULT_AGENT_NAME self.description = description @@ -338,6 +345,7 @@ def __call__( *, invocation_state: dict[str, Any] | None = None, structured_output_model: type[BaseModel] | None = None, + structured_output_prompt: str | None = None, **kwargs: Any, ) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -356,6 +364,7 @@ def __call__( - None: Use existing conversation history invocation_state: Additional parameters to pass through the event loop. structured_output_model: Pydantic model type(s) for structured output (overrides agent default). + structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: @@ -369,7 +378,11 @@ def __call__( """ return run_async( lambda: self.invoke_async( - prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs + prompt, + invocation_state=invocation_state, + structured_output_model=structured_output_model, + structured_output_prompt=structured_output_prompt, + **kwargs, ) ) @@ -379,6 +392,7 @@ async def invoke_async( *, invocation_state: dict[str, Any] | None = None, structured_output_model: type[BaseModel] | None = None, + structured_output_prompt: str | None = None, **kwargs: Any, ) -> AgentResult: """Process a natural language prompt through the agent's event loop. 
@@ -397,6 +411,7 @@ async def invoke_async( - None: Use existing conversation history invocation_state: Additional parameters to pass through the event loop. structured_output_model: Pydantic model type(s) for structured output (overrides agent default). + structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: @@ -408,7 +423,11 @@ async def invoke_async( - state: The final state of the event loop """ events = self.stream_async( - prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs + prompt, + invocation_state=invocation_state, + structured_output_model=structured_output_model, + structured_output_prompt=structured_output_prompt, + **kwargs, ) async for event in events: _ = event @@ -542,6 +561,7 @@ async def stream_async( *, invocation_state: dict[str, Any] | None = None, structured_output_model: type[BaseModel] | None = None, + structured_output_prompt: str | None = None, **kwargs: Any, ) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. @@ -560,6 +580,7 @@ async def stream_async( - None: Use existing conversation history invocation_state: Additional parameters to pass through the event loop. structured_output_model: Pydantic model type(s) for structured output (overrides agent default). + structured_output_prompt: Custom prompt for forcing structured output (overrides agent default). 
**kwargs: Additional parameters to pass to the event loop.[Deprecating] Yields: @@ -617,7 +638,7 @@ async def stream_async( with trace_api.use_span(self.trace_span): try: - events = self._run_loop(messages, merged_state, structured_output_model) + events = self._run_loop(messages, merged_state, structured_output_model, structured_output_prompt) async for event in events: event.prepare(invocation_state=merged_state) @@ -645,6 +666,7 @@ async def _run_loop( messages: Messages, invocation_state: dict[str, Any], structured_output_model: type[BaseModel] | None = None, + structured_output_prompt: str | None = None, ) -> AsyncGenerator[TypedEvent, None]: """Execute the agent's event loop with the given message and parameters. @@ -652,6 +674,7 @@ async def _run_loop( messages: The input messages to add to the conversation. invocation_state: Additional parameters to pass to the event loop. structured_output_model: Optional Pydantic model type for structured output. + structured_output_prompt: Optional custom prompt for forcing structured output. Yields: Events from the event loop cycle. 
@@ -668,7 +691,8 @@ async def _run_loop( await self._append_messages(*messages) structured_output_context = StructuredOutputContext( - structured_output_model or self._default_structured_output_model + structured_output_model or self._default_structured_output_model, + structured_output_prompt=structured_output_prompt or self._structured_output_prompt, ) # Execute the event loop cycle with retry logic for context limits diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 9fe645f80..3113ddb79 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -220,7 +220,7 @@ async def event_loop_cycle( structured_output_context.set_forced_mode() logger.debug("Forcing structured output tool") await agent._append_messages( - {"role": "user", "content": [{"text": "You must format the previous response as structured output."}]} + {"role": "user", "content": [{"text": structured_output_context.structured_output_prompt}]} ) events = recurse_event_loop( diff --git a/src/strands/tools/structured_output/__init__.py b/src/strands/tools/structured_output/__init__.py index 777d5d846..a3a12d000 100644 --- a/src/strands/tools/structured_output/__init__.py +++ b/src/strands/tools/structured_output/__init__.py @@ -1,5 +1,6 @@ """Structured output tools for the Strands Agents framework.""" +from ._structured_output_context import DEFAULT_STRUCTURED_OUTPUT_PROMPT from .structured_output_utils import convert_pydantic_to_tool_spec -__all__ = ["convert_pydantic_to_tool_spec"] +__all__ = ["convert_pydantic_to_tool_spec", "DEFAULT_STRUCTURED_OUTPUT_PROMPT"] diff --git a/src/strands/tools/structured_output/_structured_output_context.py b/src/strands/tools/structured_output/_structured_output_context.py index 2f8dd8ca0..9a5190d9d 100644 --- a/src/strands/tools/structured_output/_structured_output_context.py +++ b/src/strands/tools/structured_output/_structured_output_context.py @@ -13,15 +13,23 @@ logger = 
logging.getLogger(__name__) +DEFAULT_STRUCTURED_OUTPUT_PROMPT = "You must format the previous response as structured output." + class StructuredOutputContext: """Per-invocation context for structured output execution.""" - def __init__(self, structured_output_model: type[BaseModel] | None = None): + def __init__( + self, + structured_output_model: type[BaseModel] | None = None, + structured_output_prompt: str | None = None, + ): """Initialize a new structured output context. Args: structured_output_model: Optional Pydantic model type for structured output. + structured_output_prompt: Optional custom prompt message to use when forcing structured output. + Defaults to "You must format the previous response as structured output." """ self.results: dict[str, BaseModel] = {} self.structured_output_model: type[BaseModel] | None = structured_output_model @@ -31,6 +39,7 @@ def __init__(self, structured_output_model: type[BaseModel] | None = None): self.tool_choice: ToolChoice | None = None self.stop_loop: bool = False self.expected_tool_name: str | None = None + self.structured_output_prompt: str = structured_output_prompt or DEFAULT_STRUCTURED_OUTPUT_PROMPT if structured_output_model: self.structured_output_tool = StructuredOutputTool(structured_output_model) diff --git a/tests/strands/agent/test_agent_structured_output.py b/tests/strands/agent/test_agent_structured_output.py index 7341c714e..6ab112048 100644 --- a/tests/strands/agent/test_agent_structured_output.py +++ b/tests/strands/agent/test_agent_structured_output.py @@ -411,3 +411,160 @@ async def mock_product_cycle(*args, **kwargs): mock_event_loop.side_effect = mock_product_cycle result2 = agent("Get product", structured_output_model=product_model) assert result2.structured_output is pm + + +class TestAgentStructuredOutputPrompt: + """Test Agent structured_output_prompt functionality.""" + + def test_agent_init_with_structured_output_prompt(self, user_model): + """Test that Agent can be initialized with a 
structured_output_prompt.""" + custom_prompt = "Please format your response using the schema." + agent = Agent(structured_output_model=user_model, structured_output_prompt=custom_prompt) + + assert agent._structured_output_prompt == custom_prompt + + def test_agent_init_without_structured_output_prompt(self): + """Test that Agent can be initialized without structured_output_prompt.""" + agent = Agent() + + assert agent._structured_output_prompt is None + + @patch("strands.agent.agent.event_loop_cycle") + def test_agent_call_with_default_structured_output_prompt( + self, mock_event_loop, user_model, mock_model, mock_metrics + ): + """Test Agent.__call__ uses default structured_output_prompt when not specified.""" + custom_prompt = "Use the output schema to format your response." + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + assert structured_output_context is not None + assert structured_output_context.structured_output_prompt == custom_prompt + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + ) + + mock_event_loop.side_effect = mock_cycle + + # Create agent with default structured_output_prompt + agent = Agent( + model=mock_model, + structured_output_model=user_model, + structured_output_prompt=custom_prompt, + ) + agent("Get user info") + + mock_event_loop.assert_called_once() + + @patch("strands.agent.agent.event_loop_cycle") + def test_agent_call_override_default_structured_output_prompt( + self, mock_event_loop, user_model, mock_model, mock_metrics + ): + """Test that invocation-level structured_output_prompt overrides default.""" + default_prompt = "Default prompt for structured output." + override_prompt = "Override prompt for this specific call." 
+ + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + # Should use override_prompt, not the default + assert structured_output_context.structured_output_prompt == override_prompt + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + ) + + mock_event_loop.side_effect = mock_cycle + + # Create agent with default prompt, but override at call time + agent = Agent( + model=mock_model, + structured_output_model=user_model, + structured_output_prompt=default_prompt, + ) + agent("Get user info", structured_output_prompt=override_prompt) + + mock_event_loop.assert_called_once() + + @patch("strands.agent.agent.event_loop_cycle") + def test_agent_call_with_invocation_prompt_no_default(self, mock_event_loop, user_model, mock_model, mock_metrics): + """Test that invocation-level prompt works when no default is set.""" + invocation_prompt = "Format as structured output now." 
+ + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + assert structured_output_context.structured_output_prompt == invocation_prompt + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + ) + + mock_event_loop.side_effect = mock_cycle + + # Create agent without default prompt + agent = Agent(model=mock_model, structured_output_model=user_model) + agent("Get user info", structured_output_prompt=invocation_prompt) + + mock_event_loop.assert_called_once() + + @pytest.mark.asyncio + @patch("strands.agent.agent.event_loop_cycle") + async def test_agent_invoke_async_with_structured_output_prompt( + self, mock_event_loop, user_model, mock_model, mock_metrics + ): + """Test Agent.invoke_async with structured_output_prompt.""" + custom_prompt = "Async prompt for structured output." + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + assert structured_output_context.structured_output_prompt == custom_prompt + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + ) + + mock_event_loop.side_effect = mock_cycle + + agent = Agent(model=mock_model, structured_output_model=user_model) + await agent.invoke_async("Get user", structured_output_prompt=custom_prompt) + + mock_event_loop.assert_called_once() + + @pytest.mark.asyncio + @patch("strands.agent.agent.event_loop_cycle") + async def test_agent_stream_async_with_structured_output_prompt( + self, mock_event_loop, user_model, mock_model, mock_metrics + ): + """Test Agent.stream_async with structured_output_prompt.""" + custom_prompt = "Stream async prompt for structured output." 
+ + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + assert structured_output_context.structured_output_prompt == custom_prompt + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + ) + + mock_event_loop.side_effect = mock_cycle + + agent = Agent(model=mock_model, structured_output_model=user_model) + async for _ in agent.stream_async("Get user", structured_output_prompt=custom_prompt): + pass + + mock_event_loop.assert_called_once() diff --git a/tests/strands/event_loop/test_event_loop_structured_output.py b/tests/strands/event_loop/test_event_loop_structured_output.py index 23b7f3433..6f75d6083 100644 --- a/tests/strands/event_loop/test_event_loop_structured_output.py +++ b/tests/strands/event_loop/test_event_loop_structured_output.py @@ -8,7 +8,10 @@ from strands.event_loop.event_loop import event_loop_cycle, recurse_event_loop from strands.telemetry.metrics import EventLoopMetrics from strands.tools.registry import ToolRegistry -from strands.tools.structured_output._structured_output_context import StructuredOutputContext +from strands.tools.structured_output._structured_output_context import ( + DEFAULT_STRUCTURED_OUTPUT_PROMPT, + StructuredOutputContext, +) from strands.types._events import EventLoopStopEvent, StructuredOutputEvent @@ -190,6 +193,8 @@ async def test_event_loop_forces_structured_output_on_end_turn( mock_agent._append_messages.assert_called_once() args = mock_agent._append_messages.call_args[0][0] assert args["role"] == "user" + # Should use the default prompt + assert args["content"][0]["text"] == DEFAULT_STRUCTURED_OUTPUT_PROMPT # Should have called recurse_event_loop with the context mock_recurse.assert_called_once() @@ -197,6 +202,55 @@ async def test_event_loop_forces_structured_output_on_end_turn( assert call_kwargs["structured_output_context"] == 
structured_output_context +@pytest.mark.asyncio +async def test_event_loop_forces_structured_output_with_custom_prompt(mock_agent, agenerator, alist): + """Test that event loop uses custom prompt when forcing structured output.""" + custom_prompt = "Please format your response as structured data using the output schema." + structured_output_context = StructuredOutputContext( + structured_output_model=UserModel, + structured_output_prompt=custom_prompt, + ) + + # First call returns end_turn without using structured output tool + mock_agent.model.stream.side_effect = [ + agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "Here is the user info"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ), + ] + + # Mock recurse_event_loop to return final result + with patch("strands.event_loop.event_loop.recurse_event_loop") as mock_recurse: + mock_stop_event = Mock() + mock_stop_event.stop = ( + "end_turn", + {"role": "assistant", "content": [{"text": "Done"}]}, + mock_agent.event_loop_metrics, + {}, + None, + UserModel(name="John", age=30, email="john@example.com"), + ) + mock_stop_event.__getitem__ = lambda self, key: {"stop": self.stop}[key] + + mock_recurse.return_value = agenerator([mock_stop_event]) + + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + await alist(stream) + + # Should have appended a message with the custom prompt + mock_agent._append_messages.assert_called_once() + args = mock_agent._append_messages.call_args[0][0] + assert args["role"] == "user" + assert args["content"][0]["text"] == custom_prompt + + @pytest.mark.asyncio async def test_structured_output_tool_execution_extracts_result( mock_agent, structured_output_context, agenerator, alist diff --git a/tests/strands/tools/structured_output/test_structured_output_context.py b/tests/strands/tools/structured_output/test_structured_output_context.py index 0f1c7ffff..6d75852d1 
100644 --- a/tests/strands/tools/structured_output/test_structured_output_context.py +++ b/tests/strands/tools/structured_output/test_structured_output_context.py @@ -2,7 +2,10 @@ from pydantic import BaseModel, Field -from strands.tools.structured_output._structured_output_context import StructuredOutputContext +from strands.tools.structured_output._structured_output_context import ( + DEFAULT_STRUCTURED_OUTPUT_PROMPT, + StructuredOutputContext, +) from strands.tools.structured_output.structured_output_tool import StructuredOutputTool @@ -35,6 +38,7 @@ def test_initialization_with_structured_output_model(self): assert context.forced_mode is False assert context.tool_choice is None assert context.stop_loop is False + assert context.structured_output_prompt == DEFAULT_STRUCTURED_OUTPUT_PROMPT def test_initialization_without_structured_output_model(self): """Test initialization without a structured output model.""" @@ -47,6 +51,31 @@ def test_initialization_without_structured_output_model(self): assert context.forced_mode is False assert context.tool_choice is None assert context.stop_loop is False + assert context.structured_output_prompt == DEFAULT_STRUCTURED_OUTPUT_PROMPT + + def test_initialization_with_custom_prompt(self): + """Test initialization with a custom structured output prompt.""" + custom_prompt = "Please format your response using the output schema." 
+ context = StructuredOutputContext( + structured_output_model=SampleModel, + structured_output_prompt=custom_prompt, + ) + + assert context.structured_output_model == SampleModel + assert context.structured_output_prompt == custom_prompt + + def test_initialization_with_none_prompt_uses_default(self): + """Test that None prompt falls back to default.""" + context = StructuredOutputContext( + structured_output_model=SampleModel, + structured_output_prompt=None, + ) + + assert context.structured_output_prompt == DEFAULT_STRUCTURED_OUTPUT_PROMPT + + def test_default_prompt_constant_value(self): + """Test the default prompt constant has expected value.""" + assert DEFAULT_STRUCTURED_OUTPUT_PROMPT == "You must format the previous response as structured output." def test_is_enabled_property(self): """Test the is_enabled property.""" From ba1822ca56bdc6999c64703b91262af7b818ff8d Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Wed, 4 Feb 2026 10:37:17 -0500 Subject: [PATCH 113/279] fix: Update retry_strategy=None to turn off retries (#1630) Co-authored-by: Mackenzie Zastrow --- src/strands/agent/agent.py | 32 +++++++++++++++++++------ tests/strands/agent/test_agent_retry.py | 30 ++++++++++++++++++++++- 2 files changed, 54 insertions(+), 8 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index a76017e75..299ca2d38 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -78,7 +78,14 @@ class _DefaultCallbackHandlerSentinel: pass +class _DefaultRetryStrategySentinel: + """Sentinel class to distinguish between explicit None and default parameter value for retry_strategy.""" + + pass + + _DEFAULT_CALLBACK_HANDLER = _DefaultCallbackHandlerSentinel() +_DEFAULT_RETRY_STRATEGY = _DefaultRetryStrategySentinel() _DEFAULT_AGENT_NAME = "Strands Agents" _DEFAULT_AGENT_ID = "default" @@ -120,7 +127,7 @@ def __init__( session_manager: SessionManager | None = None, 
structured_output_prompt: str | None = None, tool_executor: ToolExecutor | None = None, - retry_strategy: ModelRetryStrategy | None = None, + retry_strategy: ModelRetryStrategy | _DefaultRetryStrategySentinel | None = _DEFAULT_RETRY_STRATEGY, ): """Initialize the Agent with the specified configuration. @@ -258,14 +265,25 @@ def __init__( # In the future, we'll have a RetryStrategy base class but until # that API is determined we only allow ModelRetryStrategy - if retry_strategy and type(retry_strategy) is not ModelRetryStrategy: + if ( + retry_strategy is not None + and not isinstance(retry_strategy, _DefaultRetryStrategySentinel) + and type(retry_strategy) is not ModelRetryStrategy + ): raise ValueError("retry_strategy must be an instance of ModelRetryStrategy") - self._retry_strategy = ( - retry_strategy - if retry_strategy is not None - else ModelRetryStrategy(max_attempts=MAX_ATTEMPTS, max_delay=MAX_DELAY, initial_delay=INITIAL_DELAY) - ) + # If not provided (using the default), create a new ModelRetryStrategy instance + # If explicitly set to None, disable retries (max_attempts=1 means no retries) + # Otherwise use the passed retry_strategy + if isinstance(retry_strategy, _DefaultRetryStrategySentinel): + self._retry_strategy = ModelRetryStrategy( + max_attempts=MAX_ATTEMPTS, max_delay=MAX_DELAY, initial_delay=INITIAL_DELAY + ) + elif retry_strategy is None: + # If no retry strategy is passed in, then we turn retries off + self._retry_strategy = ModelRetryStrategy(max_attempts=1) + else: + self._retry_strategy = retry_strategy # Initialize session management functionality self._session_manager = session_manager diff --git a/tests/strands/agent/test_agent_retry.py b/tests/strands/agent/test_agent_retry.py index 1b3bc5e9c..15757865a 100644 --- a/tests/strands/agent/test_agent_retry.py +++ b/tests/strands/agent/test_agent_retry.py @@ -14,7 +14,7 @@ def test_agent_with_default_retry_strategy(): - """Test that Agent uses ModelRetryStrategy by default when 
retry_strategy=None.""" + """Test that Agent uses ModelRetryStrategy by default when retry_strategy is not provided.""" agent = Agent() # Should have a retry_strategy @@ -27,6 +27,16 @@ def test_agent_with_default_retry_strategy(): assert agent._retry_strategy._max_delay == 240 +def test_agent_with_retry_strategy_none_disables_retries(): + """Test that Agent disables retries when retry_strategy=None is explicitly passed.""" + agent = Agent(retry_strategy=None) + + # Should have a retry_strategy with max_attempts=1 (no retries) + assert agent._retry_strategy is not None + assert isinstance(agent._retry_strategy, ModelRetryStrategy) + assert agent._retry_strategy._max_attempts == 1 + + def test_agent_with_custom_model_retry_strategy(): """Test Agent initialization with custom ModelRetryStrategy parameters.""" custom_strategy = ModelRetryStrategy(max_attempts=3, initial_delay=2, max_delay=60) @@ -159,3 +169,21 @@ async def test_event_loop_throttle_event_emitted(mock_sleep): # Should have the correct delay value assert throttle_events[0]["event_loop_throttled_delay"] > 0 + + +@pytest.mark.asyncio +async def test_agent_no_retry_when_retry_strategy_none(mock_sleep): + """Test that Agent does not retry when retry_strategy=None.""" + # Create a model that fails with throttling + model = Mock() + model.stream.side_effect = ModelThrottledException("ThrottlingException") + + # Explicitly disable retries + agent = Agent(model=model, retry_strategy=None) + + with pytest.raises(ModelThrottledException): + result = agent.stream_async("test prompt") + _ = [event async for event in result] + + # Should not have slept at all (no retries) + assert len(mock_sleep.sleep_calls) == 0 From 1c5818dd606cc907facb3d3ddb5419c58f5a2e25 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 4 Feb 2026 11:01:45 -0500 Subject: [PATCH 114/279] feat(graph): Add AgentBase support for A2AAgent compatibility (#1615) --- src/strands/agent/agent.py | 3 +- src/strands/multiagent/graph.py | 38 
+++++++++-------- tests/strands/multiagent/test_graph.py | 59 ++++++++++++++++++++++++-- tests_integ/a2a/test_multiagent_a2a.py | 32 ++++++++++++++ 4 files changed, 110 insertions(+), 22 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 299ca2d38..567a92b4a 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -59,6 +59,7 @@ from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException from ..types.traces import AttributeValue from .agent_result import AgentResult +from .base import AgentBase from .conversation_manager import ( ConversationManager, SlidingWindowConversationManager, @@ -90,7 +91,7 @@ class _DefaultRetryStrategySentinel: _DEFAULT_AGENT_ID = "default" -class Agent: +class Agent(AgentBase): """Core Agent implementation. An agent orchestrates the following workflow: diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 6b135d1a7..966d2a0b3 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -26,6 +26,7 @@ from .._async import run_async from ..agent import Agent +from ..agent.base import AgentBase from ..agent.state import AgentState from ..hooks.events import ( AfterMultiAgentInvocationEvent, @@ -161,7 +162,7 @@ class GraphNode: """Represents a node in the graph.""" node_id: str - executor: Agent | MultiAgentBase + executor: AgentBase | MultiAgentBase dependencies: set["GraphNode"] = field(default_factory=set) execution_status: Status = Status.PENDING result: NodeResult | None = None @@ -206,7 +207,7 @@ def __eq__(self, other: Any) -> bool: def _validate_node_executor( - executor: Agent | MultiAgentBase, existing_nodes: dict[str, GraphNode] | None = None + executor: AgentBase | MultiAgentBase, existing_nodes: dict[str, GraphNode] | None = None ) -> None: """Validate a node executor for graph compatibility. 
@@ -245,8 +246,8 @@ def __init__(self) -> None: self._session_manager: SessionManager | None = None self._hooks: list[HookProvider] | None = None - def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode: - """Add an Agent or MultiAgentBase instance as a node to the graph.""" + def add_node(self, executor: AgentBase | MultiAgentBase, node_id: str | None = None) -> GraphNode: + """Add an AgentBase or MultiAgentBase instance as a node to the graph.""" _validate_node_executor(executor, self.nodes) # Auto-generate node_id if not provided @@ -864,9 +865,8 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) logger.debug("node_id=<%s> | executing node", node.node_id) # Emit node start event - start_event = MultiAgentNodeStartEvent( - node_id=node.node_id, node_type="agent" if isinstance(node.executor, Agent) else "multiagent" - ) + node_type = "multiagent" if isinstance(node.executor, MultiAgentBase) else "agent" + start_event = MultiAgentNodeStartEvent(node_id=node.node_id, node_type=node_type) yield start_event before_event, interrupts = await self.hooks.invoke_callbacks_async( @@ -916,8 +916,8 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) interrupts=multi_agent_result.interrupts, ) - elif isinstance(node.executor, Agent): - # For agents, stream their events and collect result + elif isinstance(node.executor, AgentBase): + # For AgentBase implementations (Agent, A2AAgent, etc.), stream events and collect result agent_response = None async for event in node.executor.stream_async(node_input, invocation_state=invocation_state): # Forward agent events with node context @@ -938,14 +938,18 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) ) metrics = getattr(response_metrics, "accumulated_metrics", Metrics(latencyMs=0)) + # Handle stop_reason and interrupts (use getattr for AgentBase compatibility) + stop_reason = 
getattr(agent_response, "stop_reason", "end_turn") + interrupts = getattr(agent_response, "interrupts", None) or [] + node_result = NodeResult( result=agent_response, execution_time=round((time.time() - start_time) * 1000), - status=Status.INTERRUPTED if agent_response.stop_reason == "interrupt" else Status.COMPLETED, + status=Status.INTERRUPTED if stop_reason == "interrupt" else Status.COMPLETED, accumulated_usage=usage, accumulated_metrics=metrics, execution_count=1, - interrupts=agent_response.interrupts or [], + interrupts=interrupts, ) else: raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported") @@ -1056,13 +1060,13 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: if response["interruptResponse"]["interruptId"] in node_context["interrupt_ids"] ] - if isinstance(node.executor, MultiAgentBase): - return node_responses + # Restore Agent-specific state for interrupt resumption + # Only Agent (not generic AgentBase) supports interrupt state restoration + if isinstance(node.executor, Agent): + node.executor.messages = node_context["messages"] + node.executor.state = AgentState(node_context["state"]) + node.executor._interrupt_state = _InterruptState.from_dict(node_context["interrupt_state"]) - agent = node.executor - agent.messages = node_context["messages"] - agent.state = AgentState(node_context["state"]) - agent._interrupt_state = _InterruptState.from_dict(node_context["interrupt_state"]) return node_responses # Get satisfied dependencies diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 0fbb102a4..8158bf4b1 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -4,7 +4,7 @@ import pytest -from strands.agent import Agent, AgentResult +from strands.agent import Agent, AgentBase, AgentResult from strands.agent.state import AgentState from strands.hooks import AgentInitializedEvent, BeforeNodeCallEvent from 
strands.hooks.registry import HookProvider, HookRegistry @@ -1103,9 +1103,6 @@ async def test_state_reset_only_with_cycles_enabled(): # Create GraphNode node = GraphNode("test_node", agent) - # Simulate agent being in completed_nodes (as if revisited) - from strands.multiagent.graph import GraphState - state = GraphState() state.completed_nodes.add(node) @@ -2354,3 +2351,57 @@ def test_graph_interrupt_on_multiagent(agenerator): assert len(multiagent_result.results) == 1 multiagent.stream_async.assert_called_once_with(responses, {}) + + +@pytest.mark.asyncio +async def test_graph_with_agentbase_implementation(mock_strands_tracer, mock_use_span): + """Test that Graph accepts any AgentBase implementation (not just Agent).""" + + # Create a minimal AgentBase implementation + class CustomAgentBase: + """Custom AgentBase implementation for testing.""" + + def __init__(self, name: str, response_text: str): + self.name = name + self.id = f"{name}_id" + self._response_text = response_text + + def __call__(self, prompt=None, **kwargs): + return AgentResult( + message={"role": "assistant", "content": [{"text": self._response_text}]}, + stop_reason="end_turn", + state={}, + metrics=Mock( + accumulated_usage={"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, + accumulated_metrics={"latencyMs": 100.0}, + ), + ) + + async def invoke_async(self, prompt=None, **kwargs): + return self(prompt, **kwargs) + + async def stream_async(self, prompt=None, **kwargs): + yield {"start": True} + yield {"result": self(prompt, **kwargs)} + + # Verify it satisfies AgentBase protocol + custom_agent = CustomAgentBase("custom", "Custom response") + assert isinstance(custom_agent, AgentBase) + + # Create a regular mock agent + regular_agent = create_mock_agent("regular", "Regular response") + + # Build graph with both + builder = GraphBuilder() + builder.add_node(custom_agent, "custom_node") + builder.add_node(regular_agent, "regular_node") + builder.add_edge("custom_node", "regular_node") + 
builder.set_entry_point("custom_node") + graph = builder.build() + + result = await graph.invoke_async("Test task") + + assert result.status == Status.COMPLETED + assert result.completed_nodes == 2 + assert "custom_node" in result.results + assert "regular_node" in result.results diff --git a/tests_integ/a2a/test_multiagent_a2a.py b/tests_integ/a2a/test_multiagent_a2a.py index 60cbc9ce5..8b0186bc5 100644 --- a/tests_integ/a2a/test_multiagent_a2a.py +++ b/tests_integ/a2a/test_multiagent_a2a.py @@ -6,7 +6,9 @@ import pytest from a2a.client import ClientConfig, ClientFactory +from strands import Agent from strands.agent.a2a_agent import A2AAgent +from strands.multiagent.graph import GraphBuilder, Status @pytest.fixture @@ -70,3 +72,33 @@ async def test_a2a_agent_with_non_streaming_client_config(a2a_server): assert result.stop_reason == "end_turn" finally: await httpx_client.aclose() + + +@pytest.mark.asyncio +async def test_graph_with_a2a_agent_and_regular_agent(a2a_server): + """Test Graph execution with both A2AAgent and regular Agent nodes.""" + # Create A2AAgent pointing to the test server + a2a_agent = A2AAgent(endpoint=a2a_server, name="remote_agent") + + # Create a regular Agent + regular_agent = Agent( + model="us.amazon.nova-lite-v1:0", + system_prompt="You are a summarizer. 
Summarize the input briefly.", + name="summarizer", + ) + + # Build graph with both agent types + builder = GraphBuilder() + builder.add_node(a2a_agent, "remote") + builder.add_node(regular_agent, "summarizer") + builder.add_edge("remote", "summarizer") + builder.set_entry_point("remote") + graph = builder.build() + + # Execute the graph + result = await graph.invoke_async("Say hello in one sentence") + + assert result.status == Status.COMPLETED + assert result.completed_nodes == 2 + assert "remote" in result.results + assert "summarizer" in result.results From 5c05dcf6b537a43fdbe006e79dc3398eac8e4276 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Wed, 4 Feb 2026 13:29:55 -0500 Subject: [PATCH 115/279] Fix openai test (#1624) --- tests_integ/models/test_model_openai.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py index 99ac49148..d31ef3333 100644 --- a/tests_integ/models/test_model_openai.py +++ b/tests_integ/models/test_model_openai.py @@ -1,11 +1,11 @@ import os -import unittest.mock import pydantic import pytest import strands from strands import Agent, tool +from strands.event_loop._retry import ModelRetryStrategy from strands.models.openai import OpenAIModel from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException from tests_integ.models import providers @@ -206,20 +206,19 @@ def test_rate_limit_throttling_integration_no_retries(model): to avoid waiting for the exponential backoff during testing. 
""" # Patch the event loop constants to disable retries for this test - with unittest.mock.patch("strands.event_loop.event_loop.MAX_ATTEMPTS", 1): - agent = Agent(model=model) + agent = Agent(model=model, retry_strategy=ModelRetryStrategy(max_attempts=1)) - # Create a message that's very long to trigger token-per-minute rate limits - # This should be large enough to exceed TPM limits immediately - very_long_text = "Really long text " * 20000 + # Create a message that's very long to trigger token-per-minute rate limits + # This should be large enough to exceed TPM limits immediately + very_long_text = "Really long text " * 600000 - # This should raise ModelThrottledException without retries - with pytest.raises(ModelThrottledException) as exc_info: - agent(very_long_text) + # This should raise ModelThrottledException without retries + with pytest.raises(ModelThrottledException) as exc_info: + agent(very_long_text) - # Verify it's a rate limit error - error_message = str(exc_info.value).lower() - assert "rate limit" in error_message or "tokens per min" in error_message + # Verify it's a rate limit error + error_message = str(exc_info.value).lower() + assert "rate_limit_exceeded" in error_message def test_content_blocks_handling(model): From a3c2b77d426a6b83db6ec1aa05ba63a25354eaff Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 13:33:37 -0500 Subject: [PATCH 116/279] ci: bump actions/setup-python from 4 to 6 (#1548) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/publish-lambda-layer.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish-lambda-layer.yml b/.github/workflows/publish-lambda-layer.yml index 3ad9e9abf..859ddfb76 100644 --- a/.github/workflows/publish-lambda-layer.yml +++ b/.github/workflows/publish-lambda-layer.yml @@ -64,7 +64,7 @@ jobs: steps: - name: 
Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} From aa229122b9792c3910f80ec796bbb60d0b3b82d2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 13:34:01 -0500 Subject: [PATCH 117/279] ci: bump aws-actions/configure-aws-credentials from 4 to 5 (#1547) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/publish-lambda-layer.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/publish-lambda-layer.yml b/.github/workflows/publish-lambda-layer.yml index 859ddfb76..c913abac9 100644 --- a/.github/workflows/publish-lambda-layer.yml +++ b/.github/workflows/publish-lambda-layer.yml @@ -69,7 +69,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@v5 with: role-to-assume: ${{ secrets.STRANDS_LAMBDA_LAYER_PUBLISHER_ROLE }} aws-region: ${{ matrix.region }} @@ -118,7 +118,7 @@ jobs: steps: - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@v5 with: role-to-assume: ${{ secrets.STRANDS_LAMBDA_LAYER_PUBLISHER_ROLE }} aws-region: ${{ matrix.region }} From 18ee1b29697b0091d0c5aaa676a2653305807ab2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 13:35:53 -0500 Subject: [PATCH 118/279] ci: bump actions/download-artifact from 4 to 7 (#1609) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/integration-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration-test.yml 
b/.github/workflows/integration-test.yml index 00fda1262..8f651018e 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -90,7 +90,7 @@ jobs: persist-credentials: false - name: Download test results - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: name: test-results From 56ad50de9d28eab818894d4737c12be1798c4ce1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 13:36:08 -0500 Subject: [PATCH 119/279] ci: bump actions/upload-artifact from 4 to 6 (#1608) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/integration-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 8f651018e..f85d23761 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -61,7 +61,7 @@ jobs: - name: Upload test results if: always() - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: test-results path: ./build/test-results.xml From 133434970207ff6644fa83e74f93425d53dd5c76 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 4 Feb 2026 13:57:18 -0500 Subject: [PATCH 120/279] =?UTF-8?q?fix:=20update=20agent=20card=20URL=20wh?= =?UTF-8?q?en=20host/port=20overridden=20in=20A2AServer.ser=E2=80=A6=20(#1?= =?UTF-8?q?626)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Containerized Agent --- src/strands/multiagent/a2a/server.py | 19 ++- tests/strands/multiagent/a2a/test_server.py | 147 ++++++++++++++++++++ 2 files changed, 164 insertions(+), 2 deletions(-) diff --git a/src/strands/multiagent/a2a/server.py b/src/strands/multiagent/a2a/server.py index 7b4c4c73a..fd90e9787 100644 --- a/src/strands/multiagent/a2a/server.py +++ 
b/src/strands/multiagent/a2a/server.py @@ -79,6 +79,7 @@ def __init__( # Parse the provided URL to extract components for mounting self.public_base_url, self.mount_path = self._parse_public_url(http_url) self.http_url = http_url.rstrip("/") + "/" + self._http_url_explicit = True # Override mount path if serve_at_root is requested if serve_at_root: @@ -88,6 +89,7 @@ def __init__( self.public_base_url = f"http://{host}:{port}" self.http_url = f"{self.public_base_url}/" self.mount_path = "" + self._http_url_explicit = False self.strands_agent = agent self.name = self.strands_agent.name @@ -253,12 +255,25 @@ def serve( port: The port number to bind the server to. Defaults to 9000. **kwargs: Additional keyword arguments to pass to uvicorn.run. """ + # Update host/port if overridden, and recalculate URLs if http_url wasn't explicitly set + if host is not None: + self.host = host + if port is not None: + self.port = port + + if host is not None or port is not None: + # Only update the URL if it wasn't explicitly set via http_url parameter + # (i.e., if the URL was auto-generated from host/port in __init__) + if not self._http_url_explicit: + self.public_base_url = f"http://{self.host}:{self.port}" + self.http_url = f"{self.public_base_url}/" + try: logger.info("Starting Strands A2A server...") if app_type == "fastapi": - uvicorn.run(self.to_fastapi_app(), host=host or self.host, port=port or self.port, **kwargs) + uvicorn.run(self.to_fastapi_app(), host=self.host, port=self.port, **kwargs) else: - uvicorn.run(self.to_starlette_app(), host=host or self.host, port=port or self.port, **kwargs) + uvicorn.run(self.to_starlette_app(), host=self.host, port=self.port, **kwargs) except KeyboardInterrupt: logger.warning("Strands A2A server shutdown requested (KeyboardInterrupt).") except Exception: diff --git a/tests/strands/multiagent/a2a/test_server.py b/tests/strands/multiagent/a2a/test_server.py index 647fce230..aeb882b19 100644 --- a/tests/strands/multiagent/a2a/test_server.py 
+++ b/tests/strands/multiagent/a2a/test_server.py @@ -876,3 +876,150 @@ def test_to_fastapi_app_with_app_kwargs(mock_strands_agent): assert isinstance(app, FastAPI) assert app.title == "Custom Agent Title" + + +@patch("uvicorn.run") +def test_serve_with_overridden_host_port_updates_agent_card_url(mock_run, mock_strands_agent): + """Test that serve() with host/port overrides updates the agent card URL. + + This test verifies the fix for issue #1258 where specifying host/port in serve() + did not update the agent card URL, causing clients to fail when trying to connect. + """ + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, skills=[]) + + # Verify initial URL from constructor defaults + assert a2a_agent.http_url == "http://127.0.0.1:9000/" + assert a2a_agent.public_base_url == "http://127.0.0.1:9000" + + # Call serve with different host and port + a2a_agent.serve(host="localhost", port=9210) + + # Verify URL was updated to match the actual serve parameters + assert a2a_agent.http_url == "http://localhost:9210/" + assert a2a_agent.public_base_url == "http://localhost:9210" + assert a2a_agent.host == "localhost" + assert a2a_agent.port == 9210 + + # Verify the agent card reflects the updated URL + card = a2a_agent.public_agent_card + assert card.url == "http://localhost:9210/" + + # Verify uvicorn was called with the overridden parameters + mock_run.assert_called_once() + _, kwargs = mock_run.call_args + assert kwargs["host"] == "localhost" + assert kwargs["port"] == 9210 + + +@patch("uvicorn.run") +def test_serve_with_overridden_port_only_updates_url(mock_run, mock_strands_agent): + """Test that serve() with only port override updates the agent card URL.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, skills=[]) + + # Call serve with different port only + a2a_agent.serve(port=8080) + + # Verify URL was updated with the new 
port + assert a2a_agent.http_url == "http://127.0.0.1:8080/" + assert a2a_agent.port == 8080 + + # Verify uvicorn was called with the correct parameters + mock_run.assert_called_once() + _, kwargs = mock_run.call_args + assert kwargs["host"] == "127.0.0.1" + assert kwargs["port"] == 8080 + + +@patch("uvicorn.run") +def test_serve_with_overridden_host_only_updates_url(mock_run, mock_strands_agent): + """Test that serve() with only host override updates the agent card URL.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, skills=[]) + + # Call serve with different host only + a2a_agent.serve(host="0.0.0.0") + + # Verify URL was updated with the new host + assert a2a_agent.http_url == "http://0.0.0.0:9000/" + assert a2a_agent.host == "0.0.0.0" + + # Verify uvicorn was called with the correct parameters + mock_run.assert_called_once() + _, kwargs = mock_run.call_args + assert kwargs["host"] == "0.0.0.0" + assert kwargs["port"] == 9000 + + +@patch("uvicorn.run") +def test_serve_with_explicit_http_url_does_not_override_url(mock_run, mock_strands_agent): + """Test that serve() with host/port does not override explicitly set http_url. + + When a user explicitly sets http_url in the constructor (e.g., for load balancer scenarios), + the serve() method should NOT override the URL even if host/port are provided. 
+ """ + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + # Create server with explicit http_url (simulating load balancer scenario) + a2a_agent = A2AServer( + mock_strands_agent, + host="0.0.0.0", + port=8080, + http_url="https://my-alb.amazonaws.com/agent1", + skills=[], + ) + + # Verify initial URL is the explicit one + assert a2a_agent.http_url == "https://my-alb.amazonaws.com/agent1/" + assert a2a_agent._http_url_explicit is True + + # Call serve with different host/port (the local binding) + a2a_agent.serve(host="0.0.0.0", port=9000) + + # Verify URL was NOT changed (explicit http_url should be preserved) + assert a2a_agent.http_url == "https://my-alb.amazonaws.com/agent1/" + assert a2a_agent.public_base_url == "https://my-alb.amazonaws.com" + + # But host/port should still be updated for the actual binding + assert a2a_agent.host == "0.0.0.0" + assert a2a_agent.port == 9000 + + # Verify the agent card still shows the public URL + card = a2a_agent.public_agent_card + assert card.url == "https://my-alb.amazonaws.com/agent1/" + + +@patch("uvicorn.run") +def test_serve_without_overrides_does_not_change_url(mock_run, mock_strands_agent): + """Test that serve() without host/port parameters does not modify the URL.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + a2a_agent = A2AServer(mock_strands_agent, host="localhost", port=8000, skills=[]) + + # Verify initial URL + assert a2a_agent.http_url == "http://localhost:8000/" + + # Call serve without overrides + a2a_agent.serve() + + # Verify URL was NOT changed + assert a2a_agent.http_url == "http://localhost:8000/" + assert a2a_agent.host == "localhost" + assert a2a_agent.port == 8000 + + +def test_http_url_explicit_flag_set_correctly(mock_strands_agent): + """Test that _http_url_explicit flag is set correctly during initialization.""" + mock_strands_agent.tool_registry.get_all_tools_config.return_value = {} + + # Without explicit http_url + server1 = 
A2AServer(mock_strands_agent, skills=[]) + assert server1._http_url_explicit is False + + # With explicit http_url + server2 = A2AServer(mock_strands_agent, http_url="http://example.com/agent", skills=[]) + assert server2._http_url_explicit is True From 4f1a8b391b85a0e9479bd8ce380fe0b61dd09d7c Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 5 Feb 2026 10:39:38 -0500 Subject: [PATCH 121/279] test: remove broken MCP transport timeout test (#1635) --- tests_integ/mcp/test_mcp_client.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/tests_integ/mcp/test_mcp_client.py b/tests_integ/mcp/test_mcp_client.py index 4e192c935..130b35529 100644 --- a/tests_integ/mcp/test_mcp_client.py +++ b/tests_integ/mcp/test_mcp_client.py @@ -398,30 +398,6 @@ def slow_transport(): assert len(tools) >= 0 # Should work now -@pytest.mark.skipif( - condition=os.environ.get("GITHUB_ACTIONS") == "true", - reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue", -) -@pytest.mark.asyncio -async def test_streamable_http_mcp_client_times_out_before_tool(): - """Test an mcp server that timesout before the tool is able to respond.""" - server_thread = threading.Thread( - target=start_comprehensive_mcp_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True - ) - server_thread.start() - time.sleep(2) # wait for server to startup completely - - def transport_callback() -> MCPTransport: - return streamablehttp_client(sse_read_timeout=2, url="http://127.0.0.1:8001/mcp") - - streamable_http_client = MCPClient(transport_callback) - with streamable_http_client: - # Test tools - result = await streamable_http_client.call_tool_async(tool_use_id="123", name="timeout_tool") - assert result["status"] == "error" - assert result["content"][0]["text"] == "Tool execution failed: Connection closed" - - def start_5xx_proxy_for_tool_calls(target_url: str, proxy_port: int): """Starts a proxy that throws a 5XX when a tool 
call is invoked""" import aiohttp From 42f15c275b96446fb228160321ff00f3eea4e112 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 5 Feb 2026 23:00:32 -0500 Subject: [PATCH 122/279] ci: bump aws-actions/configure-aws-credentials from 5 to 6 (#1632) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/integration-test.yml | 4 ++-- .github/workflows/issue-responder.yml | 2 +- .github/workflows/publish-lambda-layer.yml | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index f85d23761..5b154385a 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -32,7 +32,7 @@ jobs: contents: read steps: - name: Configure Credentials - uses: aws-actions/configure-aws-credentials@v5 + uses: aws-actions/configure-aws-credentials@v6 with: role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }} aws-region: us-east-1 @@ -75,7 +75,7 @@ jobs: contents: read steps: - name: Configure Credentials - uses: aws-actions/configure-aws-credentials@v5 + uses: aws-actions/configure-aws-credentials@v6 with: role-to-assume: ${{ secrets.STRANDS_INTEG_TEST_ROLE }} aws-region: us-east-1 diff --git a/.github/workflows/issue-responder.yml b/.github/workflows/issue-responder.yml index c6cba59ab..2efa03117 100644 --- a/.github/workflows/issue-responder.yml +++ b/.github/workflows/issue-responder.yml @@ -14,7 +14,7 @@ jobs: steps: - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v5 + uses: aws-actions/configure-aws-credentials@v6 with: role-to-assume: ${{ secrets.STRANDS_AGENTCORE_ACTIONS_ROLE }} aws-region: us-west-2 diff --git a/.github/workflows/publish-lambda-layer.yml b/.github/workflows/publish-lambda-layer.yml index c913abac9..73252f0ff 100644 --- 
a/.github/workflows/publish-lambda-layer.yml +++ b/.github/workflows/publish-lambda-layer.yml @@ -69,7 +69,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v5 + uses: aws-actions/configure-aws-credentials@v6 with: role-to-assume: ${{ secrets.STRANDS_LAMBDA_LAYER_PUBLISHER_ROLE }} aws-region: ${{ matrix.region }} @@ -118,7 +118,7 @@ jobs: steps: - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v5 + uses: aws-actions/configure-aws-credentials@v6 with: role-to-assume: ${{ secrets.STRANDS_LAMBDA_LAYER_PUBLISHER_ROLE }} aws-region: ${{ matrix.region }} From cc4afb3b53909e9f0f89a6e6d643ee1ff581d1a3 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Fri, 6 Feb 2026 16:35:32 +0200 Subject: [PATCH 123/279] docs: add guidance on using Protocol instead of Callable for extensible interfaces (#1637) --- docs/STYLE_GUIDE.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/docs/STYLE_GUIDE.md b/docs/STYLE_GUIDE.md index 51dc0a73a..c17fb2b76 100644 --- a/docs/STYLE_GUIDE.md +++ b/docs/STYLE_GUIDE.md @@ -57,3 +57,20 @@ logger.warning("Retry limit approaching! attempt=%d max_attempts=%d", attempt, m ``` By following these log formatting guidelines, we ensure that logs are both human-readable and machine-parseable, making debugging and monitoring more efficient. + +## Type Annotations + +### Avoid `Callable` for Extensible Interfaces + +Do not use `Callable` for function type annotations that may need additional parameters in the future. `Callable` signatures are fixed and cannot be expanded without breaking existing implementations. + +```python +# Bad: Cannot add parameters later without breaking all existing implementations +EdgeCondition = Callable[[GraphState], bool] + +# Good: Protocol allows adding optional keyword arguments in the future +class EdgeCondition(Protocol): + def __call__(self, state: GraphState, **kwargs: Any) -> bool: ... 
+``` + +Using `Protocol` with `**kwargs` allows the interface to evolve by adding new keyword arguments without breaking existing implementations that don't use them. From ecfb864953f51bf92da4da1cb8d3098780854925 Mon Sep 17 00:00:00 2001 From: Luca Chang <131398524+LucaButBoring@users.noreply.github.com> Date: Tue, 10 Feb 2026 07:26:05 -0800 Subject: [PATCH 124/279] feat(mcp): Implement basic support for Tasks (#1475) --- AGENTS.md | 55 ++++ src/strands/tools/mcp/__init__.py | 3 +- src/strands/tools/mcp/mcp_client.py | 308 +++++++++++++++++- src/strands/tools/mcp/mcp_tasks.py | 33 ++ tests/strands/tools/mcp/conftest.py | 59 ++++ tests/strands/tools/mcp/test_mcp_client.py | 32 +- .../tools/mcp/test_mcp_client_contextvar.py | 2 + .../tools/mcp/test_mcp_client_tasks.py | 216 ++++++++++++ tests_integ/mcp/task_echo_server.py | 139 ++++++++ tests_integ/mcp/test_mcp_client_tasks.py | 147 +++++++++ 10 files changed, 945 insertions(+), 49 deletions(-) create mode 100644 src/strands/tools/mcp/mcp_tasks.py create mode 100644 tests/strands/tools/mcp/conftest.py create mode 100644 tests/strands/tools/mcp/test_mcp_client_tasks.py create mode 100644 tests_integ/mcp/task_echo_server.py create mode 100644 tests_integ/mcp/test_mcp_client_tasks.py diff --git a/AGENTS.md b/AGENTS.md index a57286941..9199d50fa 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -72,6 +72,7 @@ strands-agents/ │ │ │ ├── mcp_client.py # MCP client implementation │ │ │ ├── mcp_agent_tool.py # MCP tool wrapper │ │ │ ├── mcp_types.py # MCP type definitions +│ │ │ ├── mcp_tasks.py # Task-augmented execution config │ │ │ └── mcp_instrumentation.py # MCP telemetry │ │ └── structured_output/ # Structured output handling │ │ ├── structured_output_tool.py @@ -413,6 +414,60 @@ hatch test --all # Test all Python versions (3.10-3.13) - Use `pytest.mark.asyncio` for async tests - Keep tests focused and independent +## MCP Tasks (Experimental) + +The SDK supports MCP task-augmented execution for long-running tools. 
This feature is experimental and aligns with the MCP specification 2025-11-25. + +### Overview + +Task-augmented execution allows tools to run asynchronously with a workflow: +1. Create task via `call_tool_as_task` +2. Poll for completion via `poll_task` +3. Get result via `get_task_result` + +### Configuration + +Enable tasks by passing a `TasksConfig` to `MCPClient`: + +```python +from datetime import timedelta +from strands.tools.mcp import MCPClient, TasksConfig + +# Enable with defaults (ttl=1min, poll_timeout=5min) +client = MCPClient(transport, tasks_config={}) + +# Or configure explicitly +client = MCPClient( + transport, + tasks_config=TasksConfig( + ttl=timedelta(minutes=2), # Task time-to-live + poll_timeout=timedelta(minutes=10), # Polling timeout + ), +) +``` + +### Tool Support Levels + +MCP tools declare their task support via `execution.taskSupport`: +- `TASK_REQUIRED`: Tool must use task-augmented execution +- `TASK_OPTIONAL`: Tool can use tasks if client opts in +- `TASK_FORBIDDEN`: Tool does not support tasks (default) + +### Decision Logic + +Task-augmented execution is used when ALL conditions are met: +1. Client opts in via `tasks_config` (not None) +2. Server advertises task capability (`tasks.requests.tools.call`) +3. 
Tool's `taskSupport` is `required` or `optional` + +### Key Files + +- `src/strands/tools/mcp/mcp_tasks.py` - `TasksConfig` and defaults +- `src/strands/tools/mcp/mcp_client.py` - Task execution logic (`_call_tool_as_task_and_poll_async`) +- `tests/strands/tools/mcp/test_mcp_client_tasks.py` - Unit tests +- `tests_integ/mcp/test_mcp_client_tasks.py` - Integration tests +- `tests_integ/mcp/task_echo_server.py` - Test server with task support + ## Things to Do - Use explicit return types for all functions diff --git a/src/strands/tools/mcp/__init__.py b/src/strands/tools/mcp/__init__.py index cfa841c46..8d2c1daa2 100644 --- a/src/strands/tools/mcp/__init__.py +++ b/src/strands/tools/mcp/__init__.py @@ -8,6 +8,7 @@ from .mcp_agent_tool import MCPAgentTool from .mcp_client import MCPClient, ToolFilters +from .mcp_tasks import TasksConfig from .mcp_types import MCPTransport -__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "ToolFilters"] +__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "TasksConfig", "ToolFilters"] diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 833d55e07..f064f7def 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -47,6 +47,7 @@ from ..tool_provider import ToolProvider from .mcp_agent_tool import MCPAgentTool from .mcp_instrumentation import mcp_instrumentation +from .mcp_tasks import DEFAULT_TASK_CONFIG, DEFAULT_TASK_POLL_TIMEOUT, DEFAULT_TASK_TTL, TasksConfig from .mcp_types import MCPToolResult, MCPTransport logger = logging.getLogger(__name__) @@ -116,6 +117,7 @@ def __init__( tool_filters: ToolFilters | None = None, prefix: str | None = None, elicitation_callback: ElicitationFnT | None = None, + tasks_config: TasksConfig | None = None, ) -> None: """Initialize a new MCP Server connection. @@ -126,6 +128,9 @@ def __init__( tool_filters: Optional filters to apply to tools. prefix: Optional prefix for tool names. 
elicitation_callback: Optional callback function to handle elicitation requests from the MCP server. + tasks_config: Configuration for MCP task-augmented execution for long-running tools. + If provided (not None), enables task-augmented execution for tools that support it. + See TasksConfig for details. This feature is experimental and subject to change. """ self._startup_timeout = startup_timeout self._tool_filters = tool_filters @@ -150,6 +155,16 @@ def __init__( self._tool_provider_started = False self._consumers: set[Any] = set() + # Task support configuration and caching + self._tasks_config = tasks_config + self._server_task_capable: bool | None = None + + # Conditionally set up the task support cache (old SDK versions don't expose TaskExecutionMode) + if self._is_tasks_enabled(): + from mcp.types import TaskExecutionMode + + self._tool_task_support_cache: dict[str, TaskExecutionMode] = {} + def __enter__(self) -> "MCPClient": """Context manager entry point which initializes the MCP server connection. 
@@ -354,6 +369,8 @@ async def _set_close_event() -> None: self._loaded_tools = None self._tool_provider_started = False self._consumers = set() + self._server_task_capable = None + self._tool_task_support_cache = {} if self._close_exception: exception = self._close_exception @@ -396,6 +413,13 @@ async def _list_tools_async() -> ListToolsResult: mcp_tools = [] for tool in list_tools_response.tools: + if self._is_tasks_enabled(): + # Cache taskSupport for task-augmented execution decisions + task_support = None + if tool.execution is not None and tool.execution.taskSupport is not None: + task_support = tool.execution.taskSupport + self._tool_task_support_cache[tool.name] = task_support or "forbidden" + # Apply prefix if specified if effective_prefix: prefixed_name = f"{effective_prefix}_{tool.name}" @@ -535,6 +559,46 @@ async def _list_resource_templates_async() -> ListResourceTemplatesResult: return list_resource_templates_result + def _create_call_tool_coroutine( + self, + name: str, + arguments: dict[str, Any] | None, + read_timeout_seconds: timedelta | None, + ) -> Coroutine[Any, Any, MCPCallToolResult]: + """Create the appropriate coroutine for calling a tool. + + This method encapsulates the decision logic for whether to use task-augmented + execution or direct call_tool, returning the appropriate coroutine. + + Args: + name: Name of the tool to call. + arguments: Optional arguments to pass to the tool. + read_timeout_seconds: Optional timeout for the tool call. + + Returns: + A coroutine that will execute the tool call. + """ + use_task = self._should_use_task(name) + + if use_task: + self._log_debug_with_thread("tool=<%s> | using task-augmented execution", name) + + async def _call_as_task() -> MCPCallToolResult: + # When task-augmented execution is used, use the read_timeout_seconds parameter + # (which is a timedelta) for the polling timeout. 
+ return await self._call_tool_as_task_and_poll_async(name, arguments, poll_timeout=read_timeout_seconds) + + return _call_as_task() + else: + self._log_debug_with_thread("tool=<%s> | using direct call_tool", name) + + async def _call_tool_direct() -> MCPCallToolResult: + return await cast(ClientSession, self._background_thread_session).call_tool( + name, arguments, read_timeout_seconds + ) + + return _call_tool_direct() + def call_tool_sync( self, tool_use_id: str, @@ -544,10 +608,8 @@ def call_tool_sync( ) -> MCPToolResult: """Synchronously calls a tool on the MCP server. - This method calls the asynchronous call_tool method on the MCP session - and converts the result to the ToolResult format. If the MCP tool returns - structured content, it will be included as the last item in the content array - of the returned ToolResult. + This method automatically uses task-augmented execution when appropriate, + based on server capabilities and tool-level taskSupport settings. Args: tool_use_id: Unique identifier for this tool use @@ -562,13 +624,9 @@ def call_tool_sync( if not self._is_session_active(): raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) - async def _call_tool_async() -> MCPCallToolResult: - return await cast(ClientSession, self._background_thread_session).call_tool( - name, arguments, read_timeout_seconds - ) - try: - call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(_call_tool_async()).result() + coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds) + call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(coro).result() return self._handle_tool_result(tool_use_id, call_tool_result) except Exception as e: logger.exception("tool execution failed") @@ -583,8 +641,8 @@ async def call_tool_async( ) -> MCPToolResult: """Asynchronously calls a tool on the MCP server. 
- This method calls the asynchronous call_tool method on the MCP session - and converts the result to the MCPToolResult format. + This method automatically uses task-augmented execution when appropriate, + based on server capabilities and tool-level taskSupport settings. Args: tool_use_id: Unique identifier for this tool use @@ -599,13 +657,9 @@ async def call_tool_async( if not self._is_session_active(): raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) - async def _call_tool_async() -> MCPCallToolResult: - return await cast(ClientSession, self._background_thread_session).call_tool( - name, arguments, read_timeout_seconds - ) - try: - future = self._invoke_on_background_thread(_call_tool_async()) + coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds) + future = self._invoke_on_background_thread(coro) call_tool_result: MCPCallToolResult = await asyncio.wrap_future(future) return self._handle_tool_result(tool_use_id, call_tool_result) except Exception as e: @@ -683,6 +737,21 @@ async def _async_background_thread(self) -> None: self._log_debug_with_thread("session initialized successfully") # Store the session for use while we await the close event self._background_thread_session = session + + # Cache server task capability immediately after initialization + # Capabilities are exchanged during session.initialize(), so this is available now + caps = session.get_server_capabilities() + self._server_task_capable = ( + caps is not None + and caps.tasks is not None + and caps.tasks.requests is not None + and caps.tasks.requests.tools is not None + and caps.tasks.requests.tools.call is not None + ) + self._log_debug_with_thread( + "server_task_capable=<%s> | cached server task capability", self._server_task_capable + ) + # Signal that the session has been created and is ready for use self._init_future.set_result(None) @@ -894,3 +963,206 @@ def _is_session_active(self) -> bool: return False return True + + def 
_is_tasks_enabled(self) -> bool: + """Check if tasks feature is enabled. + + Tasks are enabled if tasks config is defined and not None. + + Returns: + True if task-augmented execution is enabled, False otherwise. + """ + return self._tasks_config is not None + + def _get_task_config(self) -> TasksConfig: + """Returns the task execution configuration, configured with defaults if not specified.""" + task_config = self._tasks_config or DEFAULT_TASK_CONFIG + return TasksConfig( + ttl=task_config.get("ttl", DEFAULT_TASK_TTL), + poll_timeout=task_config.get("poll_timeout", DEFAULT_TASK_POLL_TIMEOUT), + ) + + def _has_server_task_support(self) -> bool: + """Check if the MCP server supports task-augmented tool calls. + + Returns the capability value that was cached immediately after session initialization. + Server capabilities are exchanged during the MCP handshake, so this is available + as soon as start() completes. + + Returns: + True if server supports task-augmented tool calls, False otherwise. + """ + return self._server_task_capable or False + + def _should_use_task(self, tool_name: str) -> bool: + """Determine if task-augmented execution should be used for a tool. + + Task-augmented execution requires: + 1. tasks config is enabled (opt-in check) + 2. Server supports tasks (capability check) + 3. Tool taskSupport is 'required' or 'optional' + + Args: + tool_name: Name of the tool to check. + + Returns: + True if task-augmented execution should be used, False otherwise. 
+ """ + # Opt-in check: tasks must be explicitly enabled via tasks config + if not self._is_tasks_enabled(): + return False + + # Local import to avoid errors on old SDK versions that don't support Tasks + from mcp.types import TASK_OPTIONAL, TASK_REQUIRED + + # Server capability check (per MCP spec) + if not self._has_server_task_support(): + return False + + # Tool-level capability check (cached during list_tools_sync) + task_support = self._tool_task_support_cache.get(tool_name) + + # Use tasks for TASK_REQUIRED or TASK_OPTIONAL when server supports + if task_support == TASK_REQUIRED or task_support == TASK_OPTIONAL: + return True + + # Default: 'forbidden', None, or unknown -> don't use tasks + return False + + def _create_task_error_result(self, message: str) -> MCPCallToolResult: + """Create an error MCPCallToolResult with consistent formatting. + + This helper reduces duplication in task error handling paths. + + Args: + message: The error message to include in the result. + + Returns: + MCPCallToolResult with isError=True and the message as text content. + """ + return MCPCallToolResult( + isError=True, + content=[MCPTextContent(type="text", text=message)], + ) + + # ================================================================================== + # Task-Augmented Tool Execution + # ================================================================================== + # + # The MCP spec defines task-augmented execution for long-running tools. The flow is: + # + # 1. Check server capability (tasks.requests.tools.call) and tool setting (taskSupport) + # 2. If using tasks: call_tool_as_task() -> poll_task() -> get_task_result() + # 3. 
If not using tasks: call_tool() directly + # + # See: https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks + # ================================================================================== + + async def _call_tool_as_task_and_poll_async( + self, + name: str, + arguments: dict[str, Any] | None = None, + ttl: timedelta | None = None, + poll_timeout: timedelta | None = None, + ) -> MCPCallToolResult: + """Call a tool using task-augmented execution and poll until completion. + + This method implements the MCP task workflow: + 1. Creates a task via call_tool_as_task + 2. Polls using poll_task until terminal status (with timeout protection) + 3. Gets the final result using get_task_result + + Args: + name: Name of the tool to call. + arguments: Optional arguments to pass to the tool. + ttl: Task time-to-live. Uses configured value if not specified. + poll_timeout: Timeout for polling. Uses configured value if not specified. + + Returns: + MCPCallToolResult: The final tool result after task completion. 
+ """ + # Local import to avoid errors on old SDK versions that don't support Tasks + from mcp.types import TASK_STATUS_CANCELLED, TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, GetTaskResult + + session = cast(ClientSession, self._background_thread_session) + + # Precedence: arg > config > default + timeout = poll_timeout or self._get_task_config().get("poll_timeout", DEFAULT_TASK_POLL_TIMEOUT) + ttl = ttl or self._get_task_config().get("ttl", DEFAULT_TASK_TTL) + ttl_ms = int(ttl.total_seconds() * 1000) + + # Step 1: Create the task + self._log_debug_with_thread("tool=<%s> | calling tool as task with ttl=%d ms", name, ttl_ms) + create_result = await session.experimental.call_tool_as_task( + name=name, + arguments=arguments, + ttl=ttl_ms, + ) + task_id = create_result.task.taskId + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task created", name, task_id) + + # Step 2: Poll until terminal status (with timeout protection) + # Note: Using asyncio.wait_for() instead of asyncio.timeout() for Python 3.10 compatibility + async def _poll_until_terminal() -> GetTaskResult | None: + """Inner function to poll task status until terminal state.""" + final = None + async for task in session.experimental.poll_task(task_id): + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, status=<%s> | task status update", + name, + task_id, + task.status, + ) + final = task + return final + + try: + final_status = await asyncio.wait_for(_poll_until_terminal(), timeout=timeout.total_seconds()) + except asyncio.TimeoutError: + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, timeout_seconds=<%s> | task polling timed out", + name, + task_id, + timeout.total_seconds(), + ) + return self._create_task_error_result( + f"Task {task_id} polling timed out after {timeout.total_seconds()} seconds" + ) + + # Step 3: Handle terminal status + if final_status is None: + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | polling completed without status", name, task_id) + return 
self._create_task_error_result(f"Task {task_id} polling completed without status") + + if final_status.status == TASK_STATUS_FAILED: + error_msg = final_status.statusMessage or "Task failed" + self._log_debug_with_thread("tool=<%s>, task_id=<%s>, error=<%s> | task failed", name, task_id, error_msg) + return self._create_task_error_result(error_msg) + + if final_status.status == TASK_STATUS_CANCELLED: + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task was cancelled", name, task_id) + return self._create_task_error_result("Task was cancelled") + + # Step 4: Get the actual result for completed tasks (with error handling for race conditions) + if final_status.status == TASK_STATUS_COMPLETED: + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task completed, fetching result", name, task_id) + try: + result = await session.experimental.get_task_result(task_id, MCPCallToolResult) + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task result retrieved", name, task_id) + return result + except Exception as e: + # Handle race condition: task completed but result retrieval failed + # (e.g., result expired, network error, server restarted) + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, error=<%s> | failed to retrieve task result", name, task_id, str(e) + ) + return self._create_task_error_result(f"Task completed but result retrieval failed: {str(e)}") + + # Unexpected status - return as error + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, status=<%s> | unexpected task status", + name, + task_id, + final_status.status, + ) + return self._create_task_error_result(f"Unexpected task status: {final_status.status}") diff --git a/src/strands/tools/mcp/mcp_tasks.py b/src/strands/tools/mcp/mcp_tasks.py new file mode 100644 index 000000000..36537f7df --- /dev/null +++ b/src/strands/tools/mcp/mcp_tasks.py @@ -0,0 +1,33 @@ +"""Task-augmented tool execution configuration for MCP. 
+ +This module provides configuration types and defaults for the experimental MCP Tasks feature. +""" + +from datetime import timedelta + +from typing_extensions import TypedDict + + +class TasksConfig(TypedDict, total=False): + """Configuration for MCP Tasks (task-augmented tool execution). + + When enabled, supported tool calls use the MCP task workflow: + create task -> poll for completion -> get result. + + Warning: + This is an experimental feature in the 2025-11-25 MCP specification and + both the specification and the Strands Agents implementation of this + feature are subject to change. + + Attributes: + ttl: Task time-to-live. Defaults to 1 minute. + poll_timeout: Timeout for polling task completion. Defaults to 5 minutes. + """ + + ttl: timedelta + poll_timeout: timedelta + + +DEFAULT_TASK_TTL = timedelta(minutes=1) +DEFAULT_TASK_POLL_TIMEOUT = timedelta(minutes=5) +DEFAULT_TASK_CONFIG = TasksConfig(ttl=DEFAULT_TASK_TTL, poll_timeout=DEFAULT_TASK_POLL_TIMEOUT) diff --git a/tests/strands/tools/mcp/conftest.py b/tests/strands/tools/mcp/conftest.py new file mode 100644 index 000000000..0cfce470a --- /dev/null +++ b/tests/strands/tools/mcp/conftest.py @@ -0,0 +1,59 @@ +"""Shared fixtures and helpers for MCP client tests.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +@pytest.fixture +def mock_transport(): + """Create a mock MCP transport.""" + mock_read_stream = AsyncMock() + mock_write_stream = AsyncMock() + mock_transport_cm = AsyncMock() + mock_transport_cm.__aenter__.return_value = (mock_read_stream, mock_write_stream) + mock_transport_callable = MagicMock(return_value=mock_transport_cm) + + return { + "read_stream": mock_read_stream, + "write_stream": mock_write_stream, + "transport_cm": mock_transport_cm, + "transport_callable": mock_transport_callable, + } + + +@pytest.fixture +def mock_session(): + """Create a mock MCP session.""" + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + # Default: no 
task support (get_server_capabilities is sync, not async!) + mock_session.get_server_capabilities = MagicMock(return_value=None) + + # Create a mock context manager for ClientSession + mock_session_cm = AsyncMock() + mock_session_cm.__aenter__.return_value = mock_session + + # Patch ClientSession to return our mock session + with patch("strands.tools.mcp.mcp_client.ClientSession", return_value=mock_session_cm): + yield mock_session + + +def create_server_capabilities(has_task_support: bool) -> MagicMock: + """Create mock server capabilities. + + Args: + has_task_support: Whether the server should advertise task support. + + Returns: + MagicMock representing server capabilities. + """ + caps = MagicMock() + if has_task_support: + caps.tasks = MagicMock() + caps.tasks.requests = MagicMock() + caps.tasks.requests.tools = MagicMock() + caps.tasks.requests.tools.call = MagicMock() + else: + caps.tasks = None + return caps diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index a2ef369ea..e477c64d5 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -1,6 +1,6 @@ import base64 import time -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest from mcp import ListToolsResult @@ -25,35 +25,7 @@ from strands.tools.mcp.mcp_types import MCPToolResult from strands.types.exceptions import MCPClientInitializationError - -@pytest.fixture -def mock_transport(): - mock_read_stream = AsyncMock() - mock_write_stream = AsyncMock() - mock_transport_cm = AsyncMock() - mock_transport_cm.__aenter__.return_value = (mock_read_stream, mock_write_stream) - mock_transport_callable = MagicMock(return_value=mock_transport_cm) - - return { - "read_stream": mock_read_stream, - "write_stream": mock_write_stream, - "transport_cm": mock_transport_cm, - "transport_callable": mock_transport_callable, - } - - -@pytest.fixture -def 
mock_session(): - mock_session = AsyncMock() - mock_session.initialize = AsyncMock() - - # Create a mock context manager for ClientSession - mock_session_cm = AsyncMock() - mock_session_cm.__aenter__.return_value = mock_session - - # Patch ClientSession to return our mock session - with patch("strands.tools.mcp.mcp_client.ClientSession", return_value=mock_session_cm): - yield mock_session +# Fixtures mock_transport and mock_session are imported from conftest.py @pytest.fixture diff --git a/tests/strands/tools/mcp/test_mcp_client_contextvar.py b/tests/strands/tools/mcp/test_mcp_client_contextvar.py index d95929b02..739796366 100644 --- a/tests/strands/tools/mcp/test_mcp_client_contextvar.py +++ b/tests/strands/tools/mcp/test_mcp_client_contextvar.py @@ -37,6 +37,8 @@ def mock_session(): """Create mock MCP session.""" mock_session = AsyncMock() mock_session.initialize = AsyncMock() + # get_server_capabilities is sync, not async + mock_session.get_server_capabilities = MagicMock(return_value=None) mock_session_cm = AsyncMock() mock_session_cm.__aenter__.return_value = mock_session diff --git a/tests/strands/tools/mcp/test_mcp_client_tasks.py b/tests/strands/tools/mcp/test_mcp_client_tasks.py new file mode 100644 index 000000000..01d3b2763 --- /dev/null +++ b/tests/strands/tools/mcp/test_mcp_client_tasks.py @@ -0,0 +1,216 @@ +"""Tests for MCP task-augmented execution support in MCPClient.""" + +import asyncio +from datetime import timedelta +from unittest.mock import AsyncMock, MagicMock + +import pytest +from mcp import ListToolsResult +from mcp.types import CallToolResult as MCPCallToolResult +from mcp.types import TextContent as MCPTextContent +from mcp.types import Tool as MCPTool +from mcp.types import ToolExecution + +from strands.tools.mcp import MCPClient, TasksConfig +from strands.tools.mcp.mcp_tasks import DEFAULT_TASK_POLL_TIMEOUT, DEFAULT_TASK_TTL + +from .conftest import create_server_capabilities + + +class TestTasksOptIn: + """Tests for task opt-in 
behavior via tasks config.""" + + @pytest.mark.parametrize( + "tasks_config,expected_enabled", + [ + (None, False), + ({}, True), + ], + ) + def test_tasks_enabled_state(self, mock_transport, mock_session, tasks_config, expected_enabled): + """Test _is_tasks_enabled based on tasks config.""" + with MCPClient(mock_transport["transport_callable"], tasks_config=tasks_config) as client: + assert client._is_tasks_enabled() is expected_enabled + + def test_should_use_task_requires_opt_in(self, mock_transport, mock_session): + """Test that _should_use_task returns False without opt-in even with server/tool support.""" + with MCPClient(mock_transport["transport_callable"]) as client: + client._server_task_capable = True + assert client._should_use_task("test_tool") is False + + with MCPClient(mock_transport["transport_callable"], tasks_config={}) as client: + client._server_task_capable = True + client._tool_task_support_cache["test_tool"] = "required" + assert client._should_use_task("test_tool") is True + + +class TestTaskConfiguration: + """Tests for task-related configuration options.""" + + @pytest.mark.parametrize( + "config,expected_ttl,expected_timeout", + [ + ({}, DEFAULT_TASK_TTL, DEFAULT_TASK_POLL_TIMEOUT), + ({"ttl": timedelta(seconds=120)}, timedelta(seconds=120), DEFAULT_TASK_POLL_TIMEOUT), + ({"poll_timeout": timedelta(seconds=60)}, DEFAULT_TASK_TTL, timedelta(seconds=60)), + ( + {"ttl": timedelta(seconds=120), "poll_timeout": timedelta(seconds=60)}, + timedelta(seconds=120), + timedelta(seconds=60), + ), + ], + ) + def test_task_config_values(self, mock_transport, mock_session, config, expected_ttl, expected_timeout): + """Test task configuration values with various configs.""" + with MCPClient(mock_transport["transport_callable"], tasks_config=config) as client: + config_actual = client._get_task_config() + assert config_actual.get("ttl") == expected_ttl + assert config_actual.get("poll_timeout") == expected_timeout + + def 
test_stop_resets_task_caches(self, mock_transport, mock_session): + """Test that stop() resets the task support caches.""" + with MCPClient(mock_transport["transport_callable"], tasks_config={}) as client: + client._server_task_capable = True + client._tool_task_support_cache["tool1"] = "required" + assert client._server_task_capable is None + assert client._tool_task_support_cache == {} + + +class TestTaskExecution: + """Tests for task execution and error handling.""" + + def _setup_task_tool(self, mock_session, tool_name: str) -> None: + """Helper to set up a mock task-enabled tool.""" + mock_session.get_server_capabilities = MagicMock(return_value=create_server_capabilities(True)) + mock_tool = MCPTool( + name=tool_name, + description="A test tool", + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport="optional"), + ) + mock_session.list_tools = AsyncMock(return_value=ListToolsResult(tools=[mock_tool], nextCursor=None)) + mock_create_result = MagicMock() + mock_create_result.task.taskId = "test-task-id" + mock_session.experimental = MagicMock() + mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result) + + @pytest.mark.parametrize( + "status,status_message,expected_text", + [ + ("failed", "Something went wrong", "Something went wrong"), + ("cancelled", None, "cancelled"), + ("unknown_status", None, "unexpected task status"), + ], + ) + def test_terminal_status_handling(self, mock_transport, mock_session, status, status_message, expected_text): + """Test handling of terminal task statuses.""" + mock_create_result = MagicMock() + mock_create_result.task.taskId = f"task-{status}" + mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result) + + async def mock_poll_task(task_id): + yield MagicMock(status=status, statusMessage=status_message) + + mock_session.experimental.poll_task = mock_poll_task + + with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as 
client: + client._server_task_capable = True + client._tool_task_support_cache["test_tool"] = "required" + result = client.call_tool_sync(tool_use_id="test-id", name="test_tool", arguments={}) + assert result["status"] == "error" + assert expected_text.lower() in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_polling_timeout(self, mock_transport, mock_session): + """Test that task polling times out properly.""" + self._setup_task_tool(mock_session, "slow_tool") + + async def infinite_poll(task_id): + while True: + await asyncio.sleep(1) + yield MagicMock(status="running") + + mock_session.experimental.poll_task = infinite_poll + + with MCPClient( + mock_transport["transport_callable"], tasks_config=TasksConfig(poll_timeout=timedelta(seconds=0.1)) + ) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="t", name="slow_tool", arguments={}) + assert result["status"] == "error" + assert "timed out" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_explicit_timeout_overrides_default(self, mock_transport, mock_session): + """Test that read_timeout_seconds overrides the default poll timeout.""" + self._setup_task_tool(mock_session, "timeout_tool") + + async def infinite_poll(task_id): + while True: + await asyncio.sleep(1) + yield MagicMock(status="running") + + mock_session.experimental.poll_task = infinite_poll + + with MCPClient( + mock_transport["transport_callable"], tasks_config=TasksConfig(poll_timeout=timedelta(minutes=5)) + ) as client: + client.list_tools_sync() + result = await client.call_tool_async( + tool_use_id="t", name="timeout_tool", arguments={}, read_timeout_seconds=timedelta(seconds=0.1) + ) + assert result["status"] == "error" + assert "timed out" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_result_retrieval_failure(self, mock_transport, mock_session): + """Test that get_task_result 
failures are handled gracefully.""" + self._setup_task_tool(mock_session, "failing_tool") + + async def successful_poll(task_id): + yield MagicMock(status="completed", statusMessage=None) + + mock_session.experimental.poll_task = successful_poll + mock_session.experimental.get_task_result = AsyncMock(side_effect=Exception("Network error")) + + with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="t", name="failing_tool", arguments={}) + assert result["status"] == "error" + assert "result retrieval failed" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_empty_poll_result(self, mock_transport, mock_session): + """Test handling when poll_task yields nothing.""" + self._setup_task_tool(mock_session, "empty_poll_tool") + + async def empty_poll(task_id): + return + yield # noqa: B901 + + mock_session.experimental.poll_task = empty_poll + + with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="t", name="empty_poll_tool", arguments={}) + assert result["status"] == "error" + assert "without status" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_successful_completion(self, mock_transport, mock_session): + """Test successful task completion.""" + self._setup_task_tool(mock_session, "success_tool") + + async def poll(task_id): + yield MagicMock(status="completed", statusMessage=None) + + mock_session.experimental.poll_task = poll + mock_session.experimental.get_task_result = AsyncMock( + return_value=MCPCallToolResult(content=[MCPTextContent(type="text", text="Done")], isError=False) + ) + + with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: + client.list_tools_sync() + result = await 
client.call_tool_async(tool_use_id="t", name="success_tool", arguments={}) + assert result["status"] == "success" + assert "Done" in result["content"][0].get("text", "") diff --git a/tests_integ/mcp/task_echo_server.py b/tests_integ/mcp/task_echo_server.py new file mode 100644 index 000000000..4a8edc97d --- /dev/null +++ b/tests_integ/mcp/task_echo_server.py @@ -0,0 +1,139 @@ +"""MCP server with task-augmented tool execution support for integration testing.""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + +import click +import mcp.types as types +from mcp.server.experimental.task_context import ServerTaskContext +from mcp.server.lowlevel import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from starlette.applications import Starlette +from starlette.routing import Mount + + +def create_task_server() -> Server: + """Create and configure the task-supporting MCP server.""" + server = Server("task-echo-server") + server.experimental.enable_tasks() + + # Workaround: MCP Python SDK's enable_tasks() doesn't properly set tasks.requests.tools.call capability + original_update_capabilities = server.experimental.update_capabilities + + def patched_update_capabilities(capabilities: types.ServerCapabilities) -> None: + original_update_capabilities(capabilities) + if capabilities.tasks and capabilities.tasks.requests and capabilities.tasks.requests.tools: + capabilities.tasks.requests.tools.call = types.TasksCallCapability() + + server.experimental.update_capabilities = patched_update_capabilities # type: ignore[method-assign] + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="task_required_echo", + description="Echo that requires task-augmented execution", + inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + 
execution=types.ToolExecution(taskSupport=types.TASK_REQUIRED), + ), + types.Tool( + name="task_optional_echo", + description="Echo that optionally supports task-augmented execution", + inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + execution=types.ToolExecution(taskSupport=types.TASK_OPTIONAL), + ), + types.Tool( + name="task_forbidden_echo", + description="Echo that does not support task-augmented execution", + inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + execution=types.ToolExecution(taskSupport=types.TASK_FORBIDDEN), + ), + types.Tool( + name="echo", + description="Simple echo without task support setting", + inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + ), + ] + + async def handle_task_required_echo(arguments: dict[str, Any]) -> types.CreateTaskResult: + ctx = server.request_context + ctx.experimental.validate_task_mode(types.TASK_REQUIRED) + message = arguments.get("message", "") + + async def work(task: ServerTaskContext) -> types.CallToolResult: + await task.update_status("Processing echo...") + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Task echo: {message}")]) + + return await ctx.experimental.run_task(work) + + async def handle_task_optional_echo(arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: + ctx = server.request_context + message = arguments.get("message", "") + + if ctx.experimental.is_task: + + async def work(task: ServerTaskContext) -> types.CallToolResult: + await task.update_status("Processing optional task echo...") + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Task optional echo: {message}")] + ) + + return await ctx.experimental.run_task(work) + else: + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Direct optional echo: {message}")] + ) + + 
async def handle_task_forbidden_echo(arguments: dict[str, Any]) -> types.CallToolResult: + message = arguments.get("message", "") + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Forbidden echo: {message}")]) + + async def handle_simple_echo(arguments: dict[str, Any]) -> types.CallToolResult: + message = arguments.get("message", "") + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Simple echo: {message}")]) + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: + handlers = { + "task_required_echo": handle_task_required_echo, + "task_optional_echo": handle_task_optional_echo, + "task_forbidden_echo": handle_task_forbidden_echo, + "echo": handle_simple_echo, + } + if name in handlers: + return await handlers[name](arguments) + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Unknown tool: {name}")], isError=True + ) + + return server + + +def create_starlette_app(port: int) -> tuple[Starlette, StreamableHTTPSessionManager]: + """Create the Starlette app with MCP session manager.""" + server = create_task_server() + session_manager = StreamableHTTPSessionManager(app=server) + + @asynccontextmanager + async def app_lifespan(app: Starlette) -> AsyncIterator[None]: + async with session_manager.run(): + yield + + return Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)], lifespan=app_lifespan), session_manager + + +@click.command() +@click.option("--port", default=8010, help="Port to listen on") +def main(port: int) -> int: + """Start the task echo server.""" + import uvicorn + + starlette_app, _ = create_starlette_app(port) + print(f"Starting task echo server on http://localhost:{port}/mcp") + uvicorn.run(starlette_app, host="127.0.0.1", port=port) + return 0 + + +if __name__ == "__main__": + main() diff --git a/tests_integ/mcp/test_mcp_client_tasks.py 
b/tests_integ/mcp/test_mcp_client_tasks.py new file mode 100644 index 000000000..b2623c6a1 --- /dev/null +++ b/tests_integ/mcp/test_mcp_client_tasks.py @@ -0,0 +1,147 @@ +"""Integration tests for MCP task-augmented tool execution.""" + +import os +import socket +import threading +import time +from typing import Any + +import pytest +from mcp.client.streamable_http import streamablehttp_client + +from strands.tools.mcp import MCPClient, MCPTransport, TasksConfig + + +def _find_available_port() -> int: + """Find an available port.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + s.listen(1) + return s.getsockname()[1] + + +def start_task_server(port: int) -> None: + """Start the task echo server in a thread.""" + import uvicorn + + from tests_integ.mcp.task_echo_server import create_starlette_app + + starlette_app, _ = create_starlette_app(port) + uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="warning") + + +@pytest.fixture(scope="module") +def task_server_port() -> int: + return _find_available_port() + + +@pytest.fixture(scope="module") +def task_server(task_server_port: int) -> Any: + """Start the task server for the test module.""" + server_thread = threading.Thread(target=start_task_server, kwargs={"port": task_server_port}, daemon=True) + server_thread.start() + time.sleep(2) + yield + + +@pytest.fixture +def task_mcp_client(task_server: Any, task_server_port: int) -> MCPClient: + """Create an MCP client with tasks enabled.""" + + def transport_callback() -> MCPTransport: + return streamablehttp_client(url=f"http://127.0.0.1:{task_server_port}/mcp") + + return MCPClient(transport_callback, tasks_config=TasksConfig()) + + +@pytest.fixture +def task_mcp_client_disabled(task_server: Any, task_server_port: int) -> MCPClient: + """Create an MCP client with tasks disabled (default).""" + + def transport_callback() -> MCPTransport: + return 
streamablehttp_client(url=f"http://127.0.0.1:{task_server_port}/mcp") + + return MCPClient(transport_callback) + + +@pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="streamable transport failing in CI") +class TestMCPTaskSupport: + """Integration tests for MCP task-augmented execution.""" + + def test_direct_call_tools(self, task_mcp_client: MCPClient) -> None: + """Test tools that use direct call_tool (forbidden or no taskSupport).""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + + # Tool with taskSupport='forbidden' + r1 = task_mcp_client.call_tool_sync( + tool_use_id="t1", name="task_forbidden_echo", arguments={"message": "Hello!"} + ) + assert r1["status"] == "success" + assert "Forbidden echo: Hello!" in r1["content"][0].get("text", "") + + # Tool without taskSupport + r2 = task_mcp_client.call_tool_sync(tool_use_id="t2", name="echo", arguments={"message": "Simple!"}) + assert r2["status"] == "success" + assert "Simple echo: Simple!" in r2["content"][0].get("text", "") + + def test_task_augmented_tools(self, task_mcp_client: MCPClient) -> None: + """Test tools that use task-augmented execution (required or optional).""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + + # Tool with taskSupport='required' + r1 = task_mcp_client.call_tool_sync( + tool_use_id="t1", name="task_required_echo", arguments={"message": "Required!"} + ) + assert r1["status"] == "success" + assert "Task echo: Required!" in r1["content"][0].get("text", "") + + # Tool with taskSupport='optional' + r2 = task_mcp_client.call_tool_sync( + tool_use_id="t2", name="task_optional_echo", arguments={"message": "Optional!"} + ) + assert r2["status"] == "success" + assert "Task optional echo: Optional!" 
in r2["content"][0].get("text", "") + + def test_task_support_tool_detection(self, task_mcp_client: MCPClient) -> None: + """Test tool-level task support detection.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + + # Verify decision logic + assert task_mcp_client._should_use_task("task_required_echo") is True + assert task_mcp_client._should_use_task("task_optional_echo") is True + assert task_mcp_client._should_use_task("task_forbidden_echo") is False + assert task_mcp_client._should_use_task("echo") is False + + def test_server_capabilities(self, task_mcp_client: MCPClient) -> None: + """Test server task capability detection.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + assert task_mcp_client._has_server_task_support() is True + + def test_tasks_disabled_by_default(self, task_mcp_client_disabled: MCPClient) -> None: + """Test that tasks are disabled when experimental.tasks is not configured.""" + with task_mcp_client_disabled: + task_mcp_client_disabled.list_tools_sync() + + assert task_mcp_client_disabled._is_tasks_enabled() is False + assert task_mcp_client_disabled._should_use_task("task_required_echo") is False + + # Tool calls still work via direct call_tool + result = task_mcp_client_disabled.call_tool_sync( + tool_use_id="t", name="task_required_echo", arguments={"message": "Direct!"} + ) + assert result["status"] == "success" + + @pytest.mark.asyncio + async def test_async_tool_call(self, task_mcp_client: MCPClient) -> None: + """Test async tool calls.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + result = await task_mcp_client.call_tool_async( + tool_use_id="t", name="task_forbidden_echo", arguments={"message": "Async!"} + ) + assert result["status"] == "success" + assert "Forbidden echo: Async!" 
in result["content"][0].get("text", "") From 3348099ff8186f59e5a44653e6e328bea16a48b7 Mon Sep 17 00:00:00 2001 From: punkyoon <11442383+punkyoon@users.noreply.github.com> Date: Wed, 11 Feb 2026 05:42:49 +0900 Subject: [PATCH 125/279] fix(multiagent): set empty text part data in `parts` for `Artifact` (#1643) Co-authored-by: Aaron Farntrog --- src/strands/multiagent/a2a/executor.py | 5 +- tests/strands/multiagent/a2a/test_executor.py | 81 +++++++++++++++++++ 2 files changed, 83 insertions(+), 3 deletions(-) diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index 58dfcc045..2f8de99f7 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -191,16 +191,15 @@ async def _handle_agent_result(self, result: SAAgentResult | None, updater: Task if self.enable_a2a_compliant_streaming: if self._is_first_chunk: final_content = str(result) if result else "" - parts = [Part(root=TextPart(text=final_content))] if final_content else [] await updater.add_artifact( - parts, + [Part(root=TextPart(text=final_content))], artifact_id=self._current_artifact_id, name="agent_response", last_chunk=True, ) else: await updater.add_artifact( - [], + [Part(root=TextPart(text=""))], artifact_id=self._current_artifact_id, name="agent_response", append=True, diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py index bb039bdce..932f26247 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -1116,3 +1116,84 @@ async def test_a2a_compliant_mode_uses_add_artifact(mock_strands_agent): assert mock_updater.add_artifact.call_args[1]["artifact_id"] == "artifact-123" assert mock_updater.add_artifact.call_args[1]["append"] is False mock_updater.update_status.assert_not_called() + + +@pytest.mark.asyncio +async def test_a2a_compliant_handle_result_first_chunk_with_content(mock_strands_agent): + """Test that 
A2A-compliant mode sends a TextPart with content when first chunk and result has content.""" + executor = StrandsA2AExecutor(mock_strands_agent, enable_a2a_compliant_streaming=True) + executor._current_artifact_id = "artifact-456" + executor._is_first_chunk = True + + mock_updater = MagicMock() + mock_updater.add_artifact = AsyncMock() + mock_updater.complete = AsyncMock() + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.__str__ = MagicMock(return_value="Final response") + + await executor._handle_agent_result(mock_result, mock_updater) + + mock_updater.add_artifact.assert_called_once() + parts = mock_updater.add_artifact.call_args[0][0] + assert len(parts) == 1 + assert parts[0].root.text == "Final response" + assert mock_updater.add_artifact.call_args[1]["artifact_id"] == "artifact-456" + assert mock_updater.add_artifact.call_args[1]["last_chunk"] is True + mock_updater.complete.assert_called_once() + + +@pytest.mark.asyncio +async def test_a2a_compliant_handle_result_first_chunk_with_none_result(mock_strands_agent): + """Test that A2A-compliant mode sends a TextPart with empty string when first chunk and result is None. + + Per the A2A spec, parts must contain at least one part, so even with no result + we should send a TextPart with an empty string rather than an empty list. 
+ """ + executor = StrandsA2AExecutor(mock_strands_agent, enable_a2a_compliant_streaming=True) + executor._current_artifact_id = "artifact-789" + executor._is_first_chunk = True + + mock_updater = MagicMock() + mock_updater.add_artifact = AsyncMock() + mock_updater.complete = AsyncMock() + + await executor._handle_agent_result(None, mock_updater) + + mock_updater.add_artifact.assert_called_once() + parts = mock_updater.add_artifact.call_args[0][0] + assert len(parts) == 1 + assert parts[0].root.text == "" + assert mock_updater.add_artifact.call_args[1]["artifact_id"] == "artifact-789" + assert mock_updater.add_artifact.call_args[1]["last_chunk"] is True + mock_updater.complete.assert_called_once() + + +@pytest.mark.asyncio +async def test_a2a_compliant_handle_result_not_first_chunk(mock_strands_agent): + """Test that A2A-compliant mode sends a TextPart with empty string when not the first chunk. + + Per the A2A spec, parts must contain at least one part, so the final marker + chunk should include a TextPart with an empty string rather than an empty list. 
+ """ + executor = StrandsA2AExecutor(mock_strands_agent, enable_a2a_compliant_streaming=True) + executor._current_artifact_id = "artifact-abc" + executor._is_first_chunk = False + + mock_updater = MagicMock() + mock_updater.add_artifact = AsyncMock() + mock_updater.complete = AsyncMock() + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.__str__ = MagicMock(return_value="Some content") + + await executor._handle_agent_result(mock_result, mock_updater) + + mock_updater.add_artifact.assert_called_once() + parts = mock_updater.add_artifact.call_args[0][0] + assert len(parts) == 1 + assert parts[0].root.text == "" + assert mock_updater.add_artifact.call_args[1]["artifact_id"] == "artifact-abc" + assert mock_updater.add_artifact.call_args[1]["append"] is True + assert mock_updater.add_artifact.call_args[1]["last_chunk"] is True + mock_updater.complete.assert_called_once() From 18a349cb67b079645aed244754574a78e3339646 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 11 Feb 2026 10:32:40 -0500 Subject: [PATCH 126/279] fix(summarizing_conversation_manager): use model stream to generate summary (#1653) --- .../summarizing_conversation_manager.py | 95 ++++++-- .../test_summarizing_conversation_manager.py | 220 ++++++++++++++---- 2 files changed, 256 insertions(+), 59 deletions(-) diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index cc71e4d88..12b04dcea 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -5,6 +5,8 @@ from typing_extensions import override +from ..._async import run_async +from ...event_loop.streaming import process_stream from ...tools._tool_helpers import noop_tool from ...tools.registry import ToolRegistry from ...types.content import Message @@ -176,9 +178,17 @@ def reduce_context(self, agent: 
"Agent", e: Exception | None = None, **kwargs: A def _generate_summary(self, messages: list[Message], agent: "Agent") -> Message: """Generate a summary of the provided messages. + When a dedicated summarization_agent was provided at init time, it is invoked as before + (full agent pipeline, tool execution, etc.). + + In the default case (no summarization_agent), the parent agent's *model* is called + directly via ``model.stream()``. This avoids re-entering the agent pipeline which + would deadlock on ``_invocation_lock`` and corrupt metrics / traces / interrupt state. + Args: messages: The messages to summarize. - agent: The agent instance to use for summarization. + agent: The agent instance whose model will be used for summarization when no + dedicated summarization_agent was configured. Returns: A message containing the conversation summary. @@ -186,26 +196,32 @@ def _generate_summary(self, messages: list[Message], agent: "Agent") -> Message: Raises: Exception: If summary generation fails. """ - # Choose which agent to use for summarization - summarization_agent = self.summarization_agent if self.summarization_agent is not None else agent + if self.summarization_agent is not None: + return self._generate_summary_with_agent(messages) + + return self._generate_summary_with_model(messages, agent) + + # ------------------------------------------------------------------ + # Path 1 – dedicated summarization agent (backward-compatible) + # ------------------------------------------------------------------ + + def _generate_summary_with_agent(self, messages: list[Message]) -> Message: + """Generate a summary using the dedicated summarization agent. + + Args: + messages: The messages to summarize. + + Returns: + A message containing the conversation summary. 
+ """ + summarization_agent = self.summarization_agent + assert summarization_agent is not None # guaranteed by caller - # Save original system prompt, messages, and tool registry to restore later original_system_prompt = summarization_agent.system_prompt original_messages = summarization_agent.messages.copy() original_tool_registry = summarization_agent.tool_registry try: - # Only override system prompt if no agent was provided during initialization - if self.summarization_agent is None: - # Use custom system prompt if provided, otherwise use default - system_prompt = ( - self.summarization_system_prompt - if self.summarization_system_prompt is not None - else DEFAULT_SUMMARIZATION_PROMPT - ) - # Temporarily set the system prompt for summarization - summarization_agent.system_prompt = system_prompt - # Add no-op tool if agent has no tools to satisfy tool spec requirement if not summarization_agent.tool_names: tool_registry = ToolRegistry() @@ -214,16 +230,61 @@ def _generate_summary(self, messages: list[Message], agent: "Agent") -> Message: summarization_agent.messages = messages - # Use the agent to generate summary with rich content (can use tools if needed) result = summarization_agent("Please summarize this conversation.") return cast(Message, {**result.message, "role": "user"}) finally: - # Restore original agent state summarization_agent.system_prompt = original_system_prompt summarization_agent.messages = original_messages summarization_agent.tool_registry = original_tool_registry + # ------------------------------------------------------------------ + # Path 2 – default case: call model.stream() directly + # ------------------------------------------------------------------ + + def _generate_summary_with_model(self, messages: list[Message], agent: "Agent") -> Message: + """Generate a summary by calling the agent's model directly. 
+ + This bypasses the full agent pipeline (lock, metrics, traces, tool loop) and + simply asks the underlying model to summarize the conversation. + + Args: + messages: The messages to summarize. + agent: The parent agent whose model is used. + + Returns: + A message containing the conversation summary. + """ + system_prompt = ( + self.summarization_system_prompt + if self.summarization_system_prompt is not None + else DEFAULT_SUMMARIZATION_PROMPT + ) + + # Build the message list: conversation history + summarization request + summarization_messages = list(messages) + [ + {"role": "user", "content": [{"text": "Please summarize this conversation."}]} + ] + + async def _call_model() -> Message: + chunks = agent.model.stream( + summarization_messages, + tool_specs=None, + system_prompt=system_prompt, + ) + + result_message: Message | None = None + async for event in process_stream(chunks): + if "stop" in event: + _, result_message, _, _ = event["stop"] + + if result_message is None: + raise RuntimeError("Failed to generate summary: no response from model") + return result_message + + message = run_async(_call_model) + return cast(Message, {**message, "role": "user"}) + def _adjust_split_point_for_tool_pairs(self, messages: list[Message], split_point: int) -> int: """Adjust the split point to avoid breaking ToolUse/ToolResult pairs. 
diff --git a/tests/strands/agent/test_summarizing_conversation_manager.py b/tests/strands/agent/test_summarizing_conversation_manager.py index 4b69e6653..8347e07b4 100644 --- a/tests/strands/agent/test_summarizing_conversation_manager.py +++ b/tests/strands/agent/test_summarizing_conversation_manager.py @@ -4,20 +4,49 @@ import pytest from strands.agent.agent import Agent -from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager +from strands.agent.conversation_manager.summarizing_conversation_manager import ( + DEFAULT_SUMMARIZATION_PROMPT, + SummarizingConversationManager, +) from strands.types.content import Messages from strands.types.exceptions import ContextWindowOverflowException from tests.fixtures.mocked_model_provider import MockedModelProvider +async def _mock_model_stream(response_text): + """Create an async generator that yields stream events for a text response. + + This simulates what a real Model.stream() returns so that process_stream() can + reconstruct the assistant message. + """ + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": response_text}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + +async def _mock_model_stream_error(error): + """Async generator that raises an exception, simulating a model failure.""" + raise error + yield # pragma: no cover – makes this a generator + + class MockAgent: - """Mock agent for testing summarization.""" + """Mock agent for testing summarization. + + In the default path (no summarization_agent) the manager now calls + ``agent.model.stream()`` directly, so the model attribute must return a + proper async iterable. When used as a *summarization_agent* the manager + still calls ``agent("…")``, so the ``__call__`` interface is kept. 
+ """ def __init__(self, summary_response="This is a summary of the conversation."): self.summary_response = summary_response self.system_prompt = None self.messages = [] self.model = Mock() + self.model.stream = Mock(side_effect=lambda *a, **kw: _mock_model_stream(self.summary_response)) self.call_tracker = Mock() self.tool_registry = Mock() self.tool_names = [] @@ -149,11 +178,12 @@ def test_reduce_context_insufficient_messages_for_summarization(mock_agent): def test_reduce_context_raises_on_summarization_failure(): - """Test that reduce_context raises exception when summarization fails.""" - # Create an agent that will fail + """Test that reduce_context raises exception when model.stream() fails.""" failing_agent = Mock() - failing_agent.side_effect = Exception("Agent failed") - failing_agent.system_prompt = None + failing_agent.model = Mock() + failing_agent.model.stream = Mock( + side_effect=lambda *a, **kw: _mock_model_stream_error(Exception("Agent failed")) + ) failing_agent_messages: Messages = [ {"role": "user", "content": [{"text": "Message 1"}]}, {"role": "assistant", "content": [{"text": "Response 1"}]}, @@ -207,13 +237,13 @@ def test_generate_summary_with_tool_content(summarizing_manager, mock_agent): assert "text" in summary_content and summary_content["text"] == "This is a summary of the conversation." 
-def test_generate_summary_raises_on_agent_failure(): - """Test that _generate_summary raises exception when agent fails.""" +def test_generate_summary_raises_on_model_failure(): + """Test that _generate_summary raises exception when model.stream() fails.""" failing_agent = Mock() - failing_agent.side_effect = Exception("Agent failed") - failing_agent.system_prompt = None - empty_failing_messages: Messages = [] - failing_agent.messages = empty_failing_messages + failing_agent.model = Mock() + failing_agent.model.stream = Mock( + side_effect=lambda *a, **kw: _mock_model_stream_error(Exception("Agent failed")) + ) manager = SummarizingConversationManager() @@ -222,7 +252,7 @@ def test_generate_summary_raises_on_agent_failure(): {"role": "assistant", "content": [{"text": "Hi there"}]}, ] - # Should raise the exception from the agent + # Should raise the exception from the model with pytest.raises(Exception, match="Agent failed"): manager._generate_summary(messages, failing_agent) @@ -325,8 +355,8 @@ def test_uses_summarization_agent_when_provided(): summary_agent.call_tracker.assert_called_once() -def test_uses_parent_agent_when_no_summarization_agent(): - """Test that parent agent is used when no summarization_agent is provided.""" +def test_default_path_calls_model_directly(): + """Test that the default path (no summarization_agent) calls model.stream() directly.""" manager = SummarizingConversationManager() messages: Messages = [ @@ -337,16 +367,36 @@ def test_uses_parent_agent_when_no_summarization_agent(): parent_agent = create_mock_agent("Parent agent summary") summary = manager._generate_summary(messages, parent_agent) - # Should use the parent agent + # Should use the model directly (via model.stream) summary_content = summary["content"][0] assert "text" in summary_content and summary_content["text"] == "Parent agent summary" - # Assert that the parent agent was called - parent_agent.call_tracker.assert_called_once() + # model.stream() should have been called 
+ parent_agent.model.stream.assert_called_once() + + # The agent itself should NOT have been called (no re-entrant invocation) + parent_agent.call_tracker.assert_not_called() + + +def test_default_path_passes_correct_system_prompt(): + """Test that the default path passes the correct system prompt to model.stream().""" + manager = SummarizingConversationManager() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + parent_agent = create_mock_agent() + manager._generate_summary(messages, parent_agent) + # Verify model.stream() was called with the default summarization system prompt + call_kwargs = parent_agent.model.stream.call_args + assert call_kwargs.kwargs["system_prompt"] == DEFAULT_SUMMARIZATION_PROMPT -def test_uses_custom_system_prompt(): - """Test that custom system prompt is used when provided.""" + +def test_default_path_uses_custom_system_prompt(): + """Test that custom system prompt is passed to model.stream() in default path.""" custom_prompt = "Custom system prompt for summarization" manager = SummarizingConversationManager(summarization_system_prompt=custom_prompt) mock_agent = create_mock_agent() @@ -356,16 +406,15 @@ def test_uses_custom_system_prompt(): {"role": "assistant", "content": [{"text": "Hi there"}]}, ] - # Capture the agent's system prompt changes - original_prompt = mock_agent.system_prompt manager._generate_summary(messages, mock_agent) - # The agent's system prompt should be restored after summarization - assert mock_agent.system_prompt == original_prompt + # Verify model.stream() was called with the custom system prompt + call_kwargs = mock_agent.model.stream.call_args + assert call_kwargs.kwargs["system_prompt"] == custom_prompt -def test_agent_state_restoration(): - """Test that agent state is properly restored after summarization.""" +def test_default_path_does_not_modify_agent_state(): + """Test that the default path does not modify any 
agent state.""" manager = SummarizingConversationManager() mock_agent = create_mock_agent() @@ -374,6 +423,7 @@ def test_agent_state_restoration(): original_messages: Messages = [{"role": "user", "content": [{"text": "Original message"}]}] mock_agent.system_prompt = original_system_prompt mock_agent.messages = original_messages.copy() + original_tool_registry = mock_agent.tool_registry messages: Messages = [ {"role": "user", "content": [{"text": "Hello"}]}, @@ -382,33 +432,99 @@ def test_agent_state_restoration(): manager._generate_summary(messages, mock_agent) - # State should be restored + # Agent state should be completely untouched assert mock_agent.system_prompt == original_system_prompt assert mock_agent.messages == original_messages + assert mock_agent.tool_registry is original_tool_registry -def test_agent_state_restoration_on_exception(): - """Test that agent state is restored even when summarization fails.""" +def test_default_path_does_not_modify_agent_state_on_exception(): + """Test that agent state is untouched when model.stream() fails in default path.""" manager = SummarizingConversationManager() - # Create an agent that fails during summarization mock_agent = Mock() mock_agent.system_prompt = "Original prompt" agent_messages: Messages = [{"role": "user", "content": [{"text": "Original"}]}] mock_agent.messages = agent_messages - mock_agent.side_effect = Exception("Summarization failed") + mock_agent.model = Mock() + mock_agent.model.stream = Mock( + side_effect=lambda *a, **kw: _mock_model_stream_error(Exception("Summarization failed")) + ) messages: Messages = [ {"role": "user", "content": [{"text": "Hello"}]}, {"role": "assistant", "content": [{"text": "Hi there"}]}, ] - # Should restore state even on exception with pytest.raises(Exception, match="Summarization failed"): manager._generate_summary(messages, mock_agent) - # State should still be restored + # Agent state should be untouched (default path never modifies it) assert 
mock_agent.system_prompt == "Original prompt" + assert mock_agent.messages == agent_messages + + +def test_default_path_passes_no_tool_specs(): + """Test that model.stream() is called with tool_specs=None in default path.""" + manager = SummarizingConversationManager() + + messages: Messages = [{"role": "user", "content": [{"text": "test"}]}] + agent = create_mock_agent() + + manager._generate_summary(messages, agent) + + # model.stream() should be called with tool_specs=None + call_kwargs = agent.model.stream.call_args + assert call_kwargs.kwargs.get("tool_specs") is None or call_kwargs[0][1] is None + + +def test_agent_path_state_restoration_with_summarization_agent(): + """Test that summarization_agent state is properly restored after summarization.""" + summary_agent = create_mock_agent("Summary from dedicated agent") + manager = SummarizingConversationManager(summarization_agent=summary_agent) + + # Set initial state on the summarization agent + original_system_prompt = "Agent original prompt" + original_messages: Messages = [{"role": "user", "content": [{"text": "Agent original message"}]}] + summary_agent.system_prompt = original_system_prompt + summary_agent.messages = original_messages.copy() + original_tool_registry = summary_agent.tool_registry + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + parent_agent = create_mock_agent("Should not be used") + manager._generate_summary(messages, parent_agent) + + # Summarization agent state should be restored + assert summary_agent.system_prompt == original_system_prompt + assert summary_agent.messages == original_messages + assert summary_agent.tool_registry is original_tool_registry + + +def test_agent_path_state_restoration_on_exception(): + """Test that summarization_agent state is restored even when it fails.""" + summary_agent = Mock() + summary_agent.system_prompt = "Original prompt" + agent_messages: Messages = 
[{"role": "user", "content": [{"text": "Original"}]}] + summary_agent.messages = agent_messages + summary_agent.side_effect = Exception("Summarization failed") + summary_agent.tool_names = [] + + manager = SummarizingConversationManager(summarization_agent=cast("Agent", summary_agent)) + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + with pytest.raises(Exception, match="Summarization failed"): + manager._generate_summary(messages, cast("Agent", Mock())) + + # State should still be restored + assert summary_agent.system_prompt == "Original prompt" def test_reduce_context_tool_pair_adjustment_works_with_forward_search(): @@ -613,27 +729,47 @@ def test_summarizing_conversation_manager_properly_records_removed_message_count @patch("strands.agent.conversation_manager.summarizing_conversation_manager.ToolRegistry") -def test_summarizing_conversation_manager_generate_summary_with_noop_tool(mock_registry_cls, summarizing_manager): +def test_summarizing_conversation_manager_generate_summary_with_noop_tool_agent_path( + mock_registry_cls, +): + """Test noop tool registration when using the agent path (summarization_agent provided).""" mock_registry = mock_registry_cls.return_value + summary_agent = create_mock_agent() + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + summarization_agent=summary_agent, + ) + messages = [{"role": "user", "content": [{"text": "test"}]}] - agent = create_mock_agent() + parent_agent = create_mock_agent() - original_tool_registry = agent.tool_registry - summarizing_manager._generate_summary(messages, agent) + original_tool_registry = summary_agent.tool_registry + manager._generate_summary(messages, parent_agent) - assert original_tool_registry == agent.tool_registry + assert original_tool_registry == summary_agent.tool_registry mock_registry.register_tool.assert_called_once() 
@patch("strands.agent.conversation_manager.summarizing_conversation_manager.ToolRegistry") -def test_summarizing_conversation_manager_generate_summary_with_tools(mock_registry_cls, summarizing_manager): +def test_summarizing_conversation_manager_generate_summary_with_tools_agent_path( + mock_registry_cls, +): + """Test no noop tool registration when summarization_agent has tools.""" mock_registry = mock_registry_cls.return_value + summary_agent = create_mock_agent() + summary_agent.tool_names = ["test_tool"] + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + summarization_agent=summary_agent, + ) + messages = [{"role": "user", "content": [{"text": "test"}]}] - agent = create_mock_agent() - agent.tool_names = ["test_tool"] + parent_agent = create_mock_agent() - summarizing_manager._generate_summary(messages, agent) + manager._generate_summary(messages, parent_agent) mock_registry.register_tool.assert_not_called() From 66fb30852810a55f58597ffc052b84669b801711 Mon Sep 17 00:00:00 2001 From: Elad Ben Avraham Date: Wed, 11 Feb 2026 18:03:13 +0200 Subject: [PATCH 127/279] =?UTF-8?q?fix(bedrock):=20add=20'prompt=20is=20to?= =?UTF-8?q?o=20long'=20to=20context=20window=20overflow=20mes=E2=80=A6=20(?= =?UTF-8?q?#1663)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/strands/models/bedrock.py | 1 + tests/strands/models/test_bedrock.py | 27 ++++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 596936e6f..db1878108 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -44,6 +44,7 @@ "Input is too long for requested model", "input length and `max_tokens` exceed context limit", "too many total text bytes", + "prompt is too long", ] # Models that should include tool result status (include_tool_result_status = True) diff --git 
a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 1410e129b..228d6c138 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -21,7 +21,7 @@ DEFAULT_BEDROCK_REGION, DEFAULT_READ_TIMEOUT, ) -from strands.types.exceptions import ModelThrottledException +from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException from strands.types.tools import ToolSpec FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID.format("us") @@ -1517,6 +1517,31 @@ async def test_add_note_on_validation_exception_throughput(bedrock_client, model ] +@pytest.mark.parametrize( + "overflow_message", + [ + "Input is too long for requested model", + "input length and `max_tokens` exceed context limit", + "too many total text bytes", + "prompt is too long: 903884 tokens > 200000 maximum", + ], +) +@pytest.mark.asyncio +async def test_stream_context_window_overflow(overflow_message, bedrock_client, model, alist, messages): + """Test that ClientError with overflow messages raises ContextWindowOverflowException.""" + error_response = { + "Error": { + "Code": "ValidationException", + "Message": f"An error occurred (ValidationException) when calling the ConverseStream operation: " + f"The model returned the following errors: {overflow_message}", + } + } + bedrock_client.converse_stream.side_effect = ClientError(error_response, "ConverseStream") + + with pytest.raises(ContextWindowOverflowException): + await alist(model.stream(messages)) + + @pytest.mark.asyncio async def test_stream_logging(bedrock_client, model, messages, caplog, alist): """Test that stream method logs debug messages at the expected stages.""" From a43e936b8db026a1acc13c9441d612f2ec8b5895 Mon Sep 17 00:00:00 2001 From: afarntrog <47332252+afarntrog@users.noreply.github.com> Date: Wed, 11 Feb 2026 13:53:20 -0500 Subject: [PATCH 128/279] fix: fix mcp tests (#1664) --- tests_integ/mcp/test_mcp_client_tasks.py | 10 ++++++++-- 1 file 
changed, 8 insertions(+), 2 deletions(-) diff --git a/tests_integ/mcp/test_mcp_client_tasks.py b/tests_integ/mcp/test_mcp_client_tasks.py index b2623c6a1..751fb655f 100644 --- a/tests_integ/mcp/test_mcp_client_tasks.py +++ b/tests_integ/mcp/test_mcp_client_tasks.py @@ -129,12 +129,18 @@ def test_tasks_disabled_by_default(self, task_mcp_client_disabled: MCPClient) -> assert task_mcp_client_disabled._is_tasks_enabled() is False assert task_mcp_client_disabled._should_use_task("task_required_echo") is False - # Tool calls still work via direct call_tool + # Direct call_tool still works for tools that support it result = task_mcp_client_disabled.call_tool_sync( - tool_use_id="t", name="task_required_echo", arguments={"message": "Direct!"} + tool_use_id="t", name="task_optional_echo", arguments={"message": "Direct!"} ) assert result["status"] == "success" + # Task-required tools fail gracefully via direct call + result2 = task_mcp_client_disabled.call_tool_sync( + tool_use_id="t2", name="task_required_echo", arguments={"message": "Direct!"} + ) + assert result2["status"] == "error" + @pytest.mark.asyncio async def test_async_tool_call(self, task_mcp_client: MCPClient) -> None: """Test async tool calls.""" From c4503d1f37400d26a87947d006b39d114ed7c000 Mon Sep 17 00:00:00 2001 From: Charles Duffy Date: Wed, 11 Feb 2026 15:12:33 -0600 Subject: [PATCH 129/279] feat: Propagate exceptions to AfterToolCallEvent for decorated tools (#1565) (#1566) --- src/strands/tools/decorator.py | 9 +- src/strands/tools/executors/_executor.py | 7 +- src/strands/types/_events.py | 15 ++- .../strands/tools/executors/test_executor.py | 92 +++++++++++++++++++ tests/strands/tools/test_decorator.py | 77 ++++++++++++++++ 5 files changed, 191 insertions(+), 9 deletions(-) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 04c14e452..70552d6ba 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -620,6 +620,7 @@ async def stream(self, 
tool_use: ToolUse, invocation_state: dict[str, Any], **kw "status": "error", "content": [{"text": f"Error: {error_msg}"}], }, + exception=e, ) except Exception as e: # Return error result with exception details for any other error @@ -632,14 +633,15 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw "status": "error", "content": [{"text": f"Error: {error_type} - {error_msg}"}], }, + exception=e, ) - def _wrap_tool_result(self, tool_use_d: str, result: Any) -> ToolResultEvent: + def _wrap_tool_result(self, tool_use_d: str, result: Any, exception: Exception | None = None) -> ToolResultEvent: # FORMAT THE RESULT for Strands Agent if isinstance(result, dict) and "status" in result and "content" in result: # Result is already in the expected format, just add toolUseId result["toolUseId"] = tool_use_d - return ToolResultEvent(cast(ToolResult, result)) + return ToolResultEvent(cast(ToolResult, result), exception=exception) else: # Wrap any other return value in the standard format # Always include at least one content item for consistency @@ -648,7 +650,8 @@ def _wrap_tool_result(self, tool_use_d: str, result: Any) -> ToolResultEvent: "toolUseId": tool_use_d, "status": "success", "content": [{"text": str(result)}], - } + }, + exception=exception, ) @property diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index ef000fbd6..0da6b5715 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -215,6 +215,9 @@ async def _stream( return if structured_output_context.is_enabled: kwargs["structured_output_context"] = structured_output_context + + exception: Exception | None = None + async for event in selected_tool.stream(tool_use, invocation_state, **kwargs): # Internal optimization; for built-in AgentTools, we yield TypedEvents out of .stream() # so that we don't needlessly yield ToolStreamEvents for non-generator callbacks. 
@@ -227,6 +230,8 @@ async def _stream( return if isinstance(event, ToolResultEvent): + # Preserve exception from decorated tools before extracting tool_result + exception = event.exception # below the last "event" must point to the tool_result event = event.tool_result break @@ -239,7 +244,7 @@ async def _stream( result = cast(ToolResult, event) after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( - agent, selected_tool, tool_use, invocation_state, result + agent, selected_tool, tool_use, invocation_state, result, exception=exception ) # Check if retry requested (getattr for BidiAfterToolCallEvent compatibility) diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 0896d48e1..5b0ae78f6 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -276,13 +276,18 @@ def prepare(self, invocation_state: dict) -> None: class ToolResultEvent(TypedEvent): """Event emitted when a tool execution completes.""" - def __init__(self, tool_result: ToolResult) -> None: - """Initialize with the completed tool result. + def __init__(self, tool_result: ToolResult, exception: Exception | None = None) -> None: + """Initialize tool result event.""" + super().__init__({"type": "tool_result", "tool_result": tool_result}) + self._exception = exception - Args: - tool_result: Final result from the tool execution + @property + def exception(self) -> Exception | None: + """The original exception that occurred, if any. + + Can be used for re-raising or type-based error handling. 
""" - super().__init__({"type": "tool_result", "tool_result": tool_result}) + return self._exception @property def tool_use_id(self) -> str: diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 78e35c2aa..4a5479503 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -482,6 +482,98 @@ async def test_executor_stream_updates_invocation_state_with_agent( assert empty_invocation_state["agent"] is agent +@pytest.mark.asyncio +async def test_executor_stream_decorated_tool_exception_in_hook( + executor, agent, tool_results, invocation_state, hook_events, alist +): + """Test that exceptions from @tool-decorated functions reach AfterToolCallEvent.""" + exception = ValueError("decorated tool error") + + @strands.tool(name="decorated_error_tool") + def failing_tool(): + """A tool that raises an exception.""" + raise exception + + agent.tool_registry.register_tool(failing_tool) + tool_use = {"name": "decorated_error_tool", "toolUseId": "1", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + after_event = hook_events[-1] + assert isinstance(after_event, AfterToolCallEvent) + assert after_event.exception is exception + + +@pytest.mark.asyncio +async def test_executor_stream_decorated_tool_runtime_error_in_hook( + executor, agent, tool_results, invocation_state, hook_events, alist +): + """Test that RuntimeError from @tool-decorated functions reach AfterToolCallEvent.""" + exception = RuntimeError("runtime error from decorated tool") + + @strands.tool(name="runtime_error_tool") + def runtime_error_tool(): + """A tool that raises a RuntimeError.""" + raise exception + + agent.tool_registry.register_tool(runtime_error_tool) + tool_use = {"name": "runtime_error_tool", "toolUseId": "1", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await 
alist(stream) + + after_event = hook_events[-1] + assert isinstance(after_event, AfterToolCallEvent) + assert after_event.exception is exception + + +@pytest.mark.asyncio +async def test_executor_stream_decorated_tool_no_exception_on_success( + executor, agent, tool_results, invocation_state, hook_events, alist +): + """Test that AfterToolCallEvent.exception is None when decorated tool succeeds.""" + + @strands.tool(name="success_decorated_tool") + def success_tool(): + """A tool that succeeds.""" + return "success" + + agent.tool_registry.register_tool(success_tool) + tool_use = {"name": "success_decorated_tool", "toolUseId": "1", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + after_event = hook_events[-1] + assert isinstance(after_event, AfterToolCallEvent) + assert after_event.exception is None + assert after_event.result["status"] == "success" + + +@pytest.mark.asyncio +async def test_executor_stream_decorated_tool_error_result_without_exception( + executor, agent, tool_results, invocation_state, hook_events, alist +): + """Test that exception is None when a tool returns an error result without throwing.""" + + @strands.tool(name="error_result_tool") + def error_result_tool(): + """A tool that returns an error result dict without raising.""" + return {"status": "error", "content": [{"text": "something went wrong"}]} + + agent.tool_registry.register_tool(error_result_tool) + tool_use = {"name": "error_result_tool", "toolUseId": "1", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + after_event = hook_events[-1] + assert isinstance(after_event, AfterToolCallEvent) + assert after_event.exception is None + assert after_event.result["status"] == "error" + + @pytest.mark.asyncio async def test_executor_stream_no_retry_set(executor, agent, tool_results, invocation_state, alist): """Test default behavior when retry is not set - tool 
executes once.""" diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 42213fcb8..f3d6eda02 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -1825,6 +1825,83 @@ def inner_default_tool(name: str, level: Annotated[int, Field(description="A lev return f"{name} is at level {level}" +@pytest.mark.asyncio +async def test_tool_result_event_carries_exception_runtime_error(alist): + """Test that ToolResultEvent carries exception when tool raises RuntimeError.""" + + @strands.tool + def error_tool(): + """Tool that raises a RuntimeError.""" + raise RuntimeError("test runtime error") + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(error_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert hasattr(result_event, "exception") + assert isinstance(result_event.exception, RuntimeError) + assert str(result_event.exception) == "test runtime error" + assert result_event.tool_result["status"] == "error" + + +@pytest.mark.asyncio +async def test_tool_result_event_carries_exception_value_error(alist): + """Test that ToolResultEvent carries exception when tool raises ValueError.""" + + @strands.tool + def validation_error_tool(): + """Tool that raises a ValueError.""" + raise ValueError("validation failed") + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(validation_error_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert hasattr(result_event, "exception") + assert isinstance(result_event.exception, ValueError) + assert str(result_event.exception) == "validation failed" + assert result_event.tool_result["status"] == "error" + + +@pytest.mark.asyncio +async def test_tool_result_event_no_exception_on_success(alist): + """Test that ToolResultEvent.exception is None when tool succeeds.""" + + @strands.tool + def 
success_tool(): + """Tool that succeeds.""" + return "success" + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(success_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert result_event.exception is None + assert result_event.tool_result["status"] == "success" + + +@pytest.mark.asyncio +async def test_tool_result_event_carries_exception_assertion_error(alist): + """Test that ToolResultEvent carries AssertionError for unexpected failures.""" + + @strands.tool + def assertion_error_tool(): + """Tool that raises an AssertionError.""" + raise AssertionError("unexpected assertion failure") + + tool_use = {"toolUseId": "test-id", "input": {}} + events = await alist(assertion_error_tool.stream(tool_use, {})) + + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert isinstance(result_event.exception, AssertionError) + assert "unexpected assertion failure" in str(result_event.exception) + assert result_event.tool_result["status"] == "error" + + def test_tool_nullable_required_field_preserves_anyof(): """Test that a required nullable field preserves anyOf so the model can pass null. 
From 723ee6a1f0eb1a919fb8f840322b2ad351c3d346 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 12 Feb 2026 13:23:33 -0500 Subject: [PATCH 130/279] feat(workflows): add conventional commit workflow in PR (#1645) --- .github/workflows/pr-title.yml | 37 ++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 .github/workflows/pr-title.yml diff --git a/.github/workflows/pr-title.yml b/.github/workflows/pr-title.yml new file mode 100644 index 000000000..14b18afa6 --- /dev/null +++ b/.github/workflows/pr-title.yml @@ -0,0 +1,37 @@ +name: PR Title Conventional Commits + +on: + pull_request: + branches: [main] + types: [opened, edited, synchronize, reopened] + +jobs: + validate-pr-title: + runs-on: ubuntu-latest + permissions: + pull-requests: read + steps: + - name: Check PR title follows conventional commits + uses: amannn/action-semantic-pull-request@v5 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + types: | + feat + fix + docs + style + refactor + perf + test + build + ci + chore + revert + requireScope: false + subjectPattern: ^[a-z].+$ + subjectPatternError: | + The subject "{subject}" must start with a lowercase letter. 
+ ignoreLabels: | + bot + dependencies From 174d787006a58b3232e6b2e5e4d76e1e4a8823b1 Mon Sep 17 00:00:00 2001 From: afarntrog <47332252+afarntrog@users.noreply.github.com> Date: Fri, 13 Feb 2026 10:29:48 -0500 Subject: [PATCH 131/279] fix: the A2AAgent returns empty AgentResult content (#1675) --- src/strands/multiagent/a2a/_converters.py | 5 ++- .../test_summarizing_conversation_manager.py | 8 +--- .../strands/multiagent/a2a/test_converters.py | 40 +++++++++++++++++++ 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/src/strands/multiagent/a2a/_converters.py b/src/strands/multiagent/a2a/_converters.py index b818c824b..22c2ffb72 100644 --- a/src/strands/multiagent/a2a/_converters.py +++ b/src/strands/multiagent/a2a/_converters.py @@ -105,8 +105,9 @@ def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: for part in update_event.status.message.parts: if hasattr(part, "root") and hasattr(part.root, "text"): content.append({"text": part.root.text}) - # Handle initial task or task without update event - elif update_event is None and task and hasattr(task, "artifacts") and task.artifacts is not None: + + # Use task.artifacts when no content was extracted from the event + if not content and task and hasattr(task, "artifacts") and task.artifacts is not None: for artifact in task.artifacts: if hasattr(artifact, "parts"): for part in artifact.parts: diff --git a/tests/strands/agent/test_summarizing_conversation_manager.py b/tests/strands/agent/test_summarizing_conversation_manager.py index 8347e07b4..b105eba86 100644 --- a/tests/strands/agent/test_summarizing_conversation_manager.py +++ b/tests/strands/agent/test_summarizing_conversation_manager.py @@ -181,9 +181,7 @@ def test_reduce_context_raises_on_summarization_failure(): """Test that reduce_context raises exception when model.stream() fails.""" failing_agent = Mock() failing_agent.model = Mock() - failing_agent.model.stream = Mock( - side_effect=lambda *a, **kw: 
_mock_model_stream_error(Exception("Agent failed")) - ) + failing_agent.model.stream = Mock(side_effect=lambda *a, **kw: _mock_model_stream_error(Exception("Agent failed"))) failing_agent_messages: Messages = [ {"role": "user", "content": [{"text": "Message 1"}]}, {"role": "assistant", "content": [{"text": "Response 1"}]}, @@ -241,9 +239,7 @@ def test_generate_summary_raises_on_model_failure(): """Test that _generate_summary raises exception when model.stream() fails.""" failing_agent = Mock() failing_agent.model = Mock() - failing_agent.model.stream = Mock( - side_effect=lambda *a, **kw: _mock_model_stream_error(Exception("Agent failed")) - ) + failing_agent.model.stream = Mock(side_effect=lambda *a, **kw: _mock_model_stream_error(Exception("Agent failed"))) manager = SummarizingConversationManager() diff --git a/tests/strands/multiagent/a2a/test_converters.py b/tests/strands/multiagent/a2a/test_converters.py index 002ebf6a6..c3b310065 100644 --- a/tests/strands/multiagent/a2a/test_converters.py +++ b/tests/strands/multiagent/a2a/test_converters.py @@ -182,6 +182,46 @@ def test_convert_task_status_update_event(): assert result.message["content"][0]["text"] == "Status message" +def test_convert_task_status_update_event_no_message_falls_back_to_task_artifacts(): + """Test that TaskStatusUpdateEvent with no message falls back to task.artifacts.""" + mock_task = MagicMock() + mock_part = MagicMock() + mock_part.root.text = "Artifact content" + mock_artifact = MagicMock() + mock_artifact.parts = [mock_part] + mock_task.artifacts = [mock_artifact] + + mock_event = MagicMock(spec=TaskStatusUpdateEvent) + mock_status = MagicMock() + mock_status.message = None + mock_event.status = mock_status + + result = convert_response_to_agent_result((mock_task, mock_event)) + + assert len(result.message["content"]) == 1 + assert result.message["content"][0]["text"] == "Artifact content" + + +def test_convert_task_artifact_update_event_empty_parts_falls_back_to_task_artifacts(): + 
"""Test that TaskArtifactUpdateEvent with empty parts falls back to task.artifacts.""" + mock_task = MagicMock() + mock_part = MagicMock() + mock_part.root.text = "Full artifact content" + mock_artifact = MagicMock() + mock_artifact.parts = [mock_part] + mock_task.artifacts = [mock_artifact] + + mock_event = MagicMock(spec=TaskArtifactUpdateEvent) + mock_event_artifact = MagicMock() + mock_event_artifact.parts = [] + mock_event.artifact = mock_event_artifact + + result = convert_response_to_agent_result((mock_task, mock_event)) + + assert len(result.message["content"]) == 1 + assert result.message["content"][0]["text"] == "Full artifact content" + + def test_convert_response_handles_missing_data(): """Test that response conversion handles missing/malformed data gracefully.""" # TaskArtifactUpdateEvent with no artifact From 5742d82d7643df05d8e101b309476162aae3938b Mon Sep 17 00:00:00 2001 From: mehtarac Date: Fri, 13 Feb 2026 10:56:48 -0500 Subject: [PATCH 132/279] auto run review workflow on maintainer PR (#1673) --- .github/workflows/auto-strands-review.yml | 49 +++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 .github/workflows/auto-strands-review.yml diff --git a/.github/workflows/auto-strands-review.yml b/.github/workflows/auto-strands-review.yml new file mode 100644 index 000000000..68190f7a0 --- /dev/null +++ b/.github/workflows/auto-strands-review.yml @@ -0,0 +1,49 @@ +name: Auto Strands Review + +on: + pull_request_target: + branches: [main] + types: [opened, synchronize, reopened, ready_for_review] + +jobs: + authorization-check: + name: Check access + permissions: read-all + runs-on: ubuntu-latest + outputs: + approval-env: ${{ steps.auth.outputs.result }} + steps: + - name: Check Authorization + id: auth + uses: strands-agents/devtools/authorization-check@main + with: + skip-check: false + username: ${{ github.event.pull_request.user.login || 'invalid' }} + allowed-roles: 'triage,write,admin' + + trigger-review: + name: Trigger 
Strands Review + needs: authorization-check + environment: ${{ needs.authorization-check.outputs.approval-env }} + permissions: + actions: write + contents: read + runs-on: ubuntu-latest + steps: + - name: Trigger Strands Command Workflow + uses: actions/github-script@v7 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + await github.rest.actions.createWorkflowDispatch({ + owner: context.repo.owner, + repo: context.repo.repo, + workflow_id: 'strands-command.yml', + ref: 'main', + inputs: { + issue_id: String(context.payload.pull_request.number), + command: 'review', + session_id: '' + } + }); + console.log(`Triggered /strands review for PR #${context.payload.pull_request.number}`); From 9d972f8219edc65b4bc438b71b7c69a99f3c7b6c Mon Sep 17 00:00:00 2001 From: afarntrog <47332252+afarntrog@users.noreply.github.com> Date: Fri, 13 Feb 2026 13:41:55 -0500 Subject: [PATCH 133/279] fix: correct output reference for approval-env in integration test (#1685) --- .github/workflows/integration-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 5b154385a..789f4506a 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -12,7 +12,7 @@ jobs: permissions: read-all runs-on: ubuntu-latest outputs: - approval-env: ${{ steps.auth.outputs.result }} + approval-env: ${{ steps.auth.outputs.approval-env }} steps: - name: Check Authorization id: auth From 634d604076ad76081a8118087e35f5230de72ee6 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Mon, 16 Feb 2026 09:53:25 -0500 Subject: [PATCH 134/279] fix: update approval env var for strands agent workflows (#1701) --- .github/workflows/auto-strands-review.yml | 2 +- .github/workflows/strands-command.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/auto-strands-review.yml b/.github/workflows/auto-strands-review.yml index 
68190f7a0..44ca1ea03 100644 --- a/.github/workflows/auto-strands-review.yml +++ b/.github/workflows/auto-strands-review.yml @@ -11,7 +11,7 @@ jobs: permissions: read-all runs-on: ubuntu-latest outputs: - approval-env: ${{ steps.auth.outputs.result }} + approval-env: ${{ steps.auth.outputs.approval-env }} steps: - name: Check Authorization id: auth diff --git a/.github/workflows/strands-command.yml b/.github/workflows/strands-command.yml index 6cd43c5c0..18fb6dfaa 100644 --- a/.github/workflows/strands-command.yml +++ b/.github/workflows/strands-command.yml @@ -27,7 +27,7 @@ jobs: permissions: read-all runs-on: ubuntu-latest outputs: - approval-env: ${{ steps.auth.outputs.result }} + approval-env: ${{ steps.auth.outputs.approval-env }} steps: - name: Check Authorization id: auth From 25c2aa445c850eed89bbc1cb2b6a5de0dc52d1eb Mon Sep 17 00:00:00 2001 From: afarntrog <47332252+afarntrog@users.noreply.github.com> Date: Mon, 16 Feb 2026 11:03:04 -0500 Subject: [PATCH 135/279] fix: update allowed roles to include maintainer (#1704) --- .github/workflows/integration-test.yml | 2 +- .github/workflows/strands-command.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 789f4506a..a40eb0f45 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -20,7 +20,7 @@ jobs: with: skip-check: ${{ github.event_name == 'merge_group' }} username: ${{ github.event.pull_request.user.login || 'invalid' }} - allowed-roles: 'triage,write,admin' + allowed-roles: 'maintain,triage,write,admin' check-access-and-checkout: runs-on: ubuntu-latest diff --git a/.github/workflows/strands-command.yml b/.github/workflows/strands-command.yml index 18fb6dfaa..fc11efcb7 100644 --- a/.github/workflows/strands-command.yml +++ b/.github/workflows/strands-command.yml @@ -35,7 +35,7 @@ jobs: with: skip-check: ${{ github.event_name == 'workflow_dispatch' }} 
username: ${{ github.event.comment.user.login || 'invalid' }} - allowed-roles: 'triage,write,admin' + allowed-roles: 'maintain,triage,write,admin' setup-and-process: needs: [authorization-check] From 2281a204db41b59985af6c401fd46ccd66855f82 Mon Sep 17 00:00:00 2001 From: afarntrog <47332252+afarntrog@users.noreply.github.com> Date: Mon, 16 Feb 2026 15:03:40 -0500 Subject: [PATCH 136/279] fix: propagate reasoningSignature on Gemini tool use (#1703) --- src/strands/event_loop/streaming.py | 4 + src/strands/models/gemini.py | 32 +++++-- src/strands/types/content.py | 4 +- src/strands/types/tools.py | 2 + tests/strands/event_loop/test_streaming.py | 37 ++++++++ tests/strands/models/test_gemini.py | 104 ++++++++++++++++++++- 6 files changed, 170 insertions(+), 13 deletions(-) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 954633807..b157f740e 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -186,6 +186,8 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]: current_tool_use["toolUseId"] = tool_use_data["toolUseId"] current_tool_use["name"] = tool_use_data["name"] current_tool_use["input"] = "" + if "reasoningSignature" in tool_use_data: + current_tool_use["reasoningSignature"] = tool_use_data["reasoningSignature"] return current_tool_use @@ -286,6 +288,8 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: name=tool_use_name, input=current_tool_use["input"], ) + if "reasoningSignature" in current_tool_use: + tool_use["reasoningSignature"] = current_tool_use["reasoningSignature"] content.append({"toolUse": tool_use}) state["current_tool_use"] = {} diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 6a6535999..c94570293 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -3,6 +3,7 @@ - Docs: https://ai.google.dev/api """ +import base64 import json import logging import 
mimetypes @@ -14,7 +15,7 @@ from google import genai from typing_extensions import Required, Unpack, override -from ..types.content import ContentBlock, Messages +from ..types.content import ContentBlock, ContentBlockStartToolUse, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec @@ -173,7 +174,7 @@ def _format_request_content_part( return genai.types.Part( text=content["reasoningContent"]["reasoningText"]["text"], thought=True, - thought_signature=thought_signature.encode("utf-8") if thought_signature else None, + thought_signature=base64.b64decode(thought_signature) if thought_signature else None, ) if "text" in content: @@ -202,14 +203,18 @@ def _format_request_content_part( ) if "toolUse" in content: - tool_use_id_to_name[content["toolUse"]["toolUseId"]] = content["toolUse"]["name"] + tool_use_id = content["toolUse"]["toolUseId"] + tool_use_id_to_name[tool_use_id] = content["toolUse"]["name"] + + reasoning_signature = content["toolUse"].get("reasoningSignature") return genai.types.Part( function_call=genai.types.FunctionCall( args=content["toolUse"]["input"], - id=content["toolUse"]["toolUseId"], + id=tool_use_id, name=content["toolUse"]["name"], ), + thought_signature=base64.b64decode(reasoning_signature) if reasoning_signature else None, ) raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") @@ -349,13 +354,18 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: # Use Gemini's provided ID or generate one if missing tool_use_id = function_call.id or f"tooluse_{secrets.token_urlsafe(16)}" + tool_use_start: ContentBlockStartToolUse = { + "name": function_call.name, + "toolUseId": tool_use_id, + } + if event["data"].thought_signature: + tool_use_start["reasoningSignature"] = base64.b64encode( + event["data"].thought_signature + ).decode("ascii") return { "contentBlockStart": { "start": { - 
"toolUse": { - "name": function_call.name, - "toolUseId": tool_use_id, - }, + "toolUse": tool_use_start, }, }, } @@ -379,7 +389,11 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: "reasoningContent": { "text": event["data"].text, **( - {"signature": event["data"].thought_signature.decode("utf-8")} + { + "signature": base64.b64encode(event["data"].thought_signature).decode( + "ascii" + ) + } if event["data"].thought_signature else {} ), diff --git a/src/strands/types/content.py b/src/strands/types/content.py index d75dbb87f..2b0714bee 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -8,7 +8,7 @@ from typing import Literal -from typing_extensions import TypedDict +from typing_extensions import NotRequired, TypedDict from .citations import CitationsContentBlock from .media import DocumentContent, ImageContent, VideoContent @@ -129,10 +129,12 @@ class ContentBlockStartToolUse(TypedDict): Attributes: name: The name of the tool that the model is requesting to use. toolUseId: The ID for the tool request. + reasoningSignature: Token that ties the model's reasoning to this tool call. """ name: str toolUseId: str + reasoningSignature: NotRequired[str] class ContentBlockStart(TypedDict, total=False): diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 6fc0d703c..088c83bdb 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -58,11 +58,13 @@ class ToolUse(TypedDict): Can be any JSON-serializable type. name: The name of the tool to invoke. toolUseId: A unique identifier for this specific tool use request. + reasoningSignature: Token that ties the model's reasoning to this tool call. 
""" input: Any name: str toolUseId: str + reasoningSignature: NotRequired[str] class ToolResultContent(TypedDict, total=False): diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 0fe04f4b2..6d376450a 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -125,6 +125,10 @@ def test_handle_message_start(): {"start": {"toolUse": {"toolUseId": "test", "name": "test"}}}, {"toolUseId": "test", "name": "test", "input": ""}, ), + ( + {"start": {"toolUse": {"toolUseId": "test", "name": "test", "reasoningSignature": "YWJj"}}}, + {"toolUseId": "test", "name": "test", "input": "", "reasoningSignature": "YWJj"}, + ), ], ) def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use): @@ -310,6 +314,39 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, event_type, s "redactedContent": b"", }, ), + # Tool Use - With reasoningSignature + ( + { + "content": [], + "current_tool_use": { + "toolUseId": "123", + "name": "test", + "input": '{"key": "value"}', + "reasoningSignature": "YWJj", + }, + "text": "", + "reasoningText": "", + "citationsContent": [], + "redactedContent": b"", + }, + { + "content": [ + { + "toolUse": { + "toolUseId": "123", + "name": "test", + "input": {"key": "value"}, + "reasoningSignature": "YWJj", + } + } + ], + "current_tool_use": {}, + "text": "", + "reasoningText": "", + "citationsContent": [], + "redactedContent": b"", + }, + ), # Tool Use - Missing input ( { diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index d62c5a7c8..ba4b2b53f 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -203,7 +203,7 @@ async def test_stream_request_with_reasoning(gemini_client, model, model_id): { "reasoningContent": { "reasoningText": { - "signature": "abc", + "signature": "YWJj", # base64 of "abc" "text": "reasoning_text", }, }, @@ -260,6 +260,51 @@ 
async def test_stream_request_with_tool_spec(gemini_client, model, model_id, too @pytest.mark.asyncio async def test_stream_request_with_tool_use(gemini_client, model, model_id): + """Test toolUse with reasoningSignature is sent as function_call with thought_signature.""" + messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "c1", + "name": "calculator", + "input": {"expression": "2+2"}, + "reasoningSignature": "YWJj", # base64 of "abc" + }, + }, + ], + }, + ] + await anext(model.stream(messages)) + + exp_request = { + "config": { + "tools": [{"function_declarations": []}], + }, + "contents": [ + { + "parts": [ + { + "function_call": { + "args": {"expression": "2+2"}, + "id": "c1", + "name": "calculator", + }, + "thought_signature": "YWJj", + }, + ], + "role": "model", + }, + ], + "model": model_id, + } + gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request) + + +@pytest.mark.asyncio +async def test_stream_request_with_tool_use_no_reasoning_signature(gemini_client, model, model_id): + """Test toolUse without reasoningSignature is sent as function_call without thought_signature.""" messages = [ { "role": "assistant", @@ -532,6 +577,55 @@ async def test_stream_response_tool_use(gemini_client, model, messages, agenerat assert tru_chunks == exp_chunks +@pytest.mark.asyncio +async def test_stream_response_tool_use_with_thought_signature(gemini_client, model, messages, agenerator, alist): + """Test that tool use responses with thought_signature include reasoningSignature.""" + gemini_client.aio.models.generate_content_stream.return_value = agenerator( + [ + genai.types.GenerateContentResponse( + candidates=[ + genai.types.Candidate( + content=genai.types.Content( + parts=[ + genai.types.Part( + function_call=genai.types.FunctionCall( + args={"expression": "2+2"}, + id="c1", + name="calculator", + ), + thought_signature=b"abc", + ), + ], + ), + finish_reason="STOP", + ), + ], + 
usage_metadata=genai.types.GenerateContentResponseUsageMetadata( + prompt_token_count=1, + total_token_count=3, + ), + ), + ] + ) + + tru_chunks = await alist(model.stream(messages)) + exp_chunks = [ + {"messageStart": {"role": "assistant"}}, + { + "contentBlockStart": { + "start": { + "toolUse": {"name": "calculator", "toolUseId": "c1", "reasoningSignature": "YWJj"}, + }, + }, + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + {"metadata": {"usage": {"inputTokens": 1, "outputTokens": 2, "totalTokens": 3}, "metrics": {"latencyMs": 0}}}, + ] + assert tru_chunks == exp_chunks + + @pytest.mark.asyncio async def test_stream_response_reasoning(gemini_client, model, messages, agenerator, alist): gemini_client.aio.models.generate_content_stream.return_value = agenerator( @@ -563,7 +657,7 @@ async def test_stream_response_reasoning(gemini_client, model, messages, agenera exp_chunks = [ {"messageStart": {"role": "assistant"}}, {"contentBlockStart": {"start": {}}}, - {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "abc", "text": "test reason"}}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "YWJj", "text": "test reason"}}}}, {"contentBlockStop": {}}, {"messageStop": {"stopReason": "end_turn"}}, {"metadata": {"usage": {"inputTokens": 1, "outputTokens": 2, "totalTokens": 3}, "metrics": {"latencyMs": 0}}}, @@ -622,7 +716,11 @@ async def test_stream_response_reasoning_and_text(gemini_client, model, messages exp_chunks = [ {"messageStart": {"role": "assistant"}}, {"contentBlockStart": {"start": {}}}, - {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "sig1", "text": "thinking about math"}}}}, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"signature": "c2lnMQ==", "text": "thinking about math"}} + } + }, {"contentBlockStop": {}}, {"contentBlockStart": {"start": {}}}, {"contentBlockDelta": 
{"delta": {"text": "2 + 2 = 4"}}}, From ac7244ad7e1d807060f8b9aac3fb401260293782 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 16 Feb 2026 16:26:36 -0500 Subject: [PATCH 137/279] ci: bump actions/github-script from 7 to 8 (#1699) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/auto-strands-review.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/auto-strands-review.yml b/.github/workflows/auto-strands-review.yml index 44ca1ea03..ebcbc1870 100644 --- a/.github/workflows/auto-strands-review.yml +++ b/.github/workflows/auto-strands-review.yml @@ -31,7 +31,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Trigger Strands Command Workflow - uses: actions/github-script@v7 + uses: actions/github-script@v8 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | From 2d8c20ec6d6e731f0b618ed364a9257baf97174e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 16 Feb 2026 16:27:01 -0500 Subject: [PATCH 138/279] ci: bump amannn/action-semantic-pull-request from 5 to 6 (#1684) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/pr-title.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr-title.yml b/.github/workflows/pr-title.yml index 14b18afa6..ada75b746 100644 --- a/.github/workflows/pr-title.yml +++ b/.github/workflows/pr-title.yml @@ -12,7 +12,7 @@ jobs: pull-requests: read steps: - name: Check PR title follows conventional commits - uses: amannn/action-semantic-pull-request@v5 + uses: amannn/action-semantic-pull-request@v6 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: From bdcc7179ff3dd654ed56969855ad37b96bb394b5 Mon Sep 17 00:00:00 2001 From: Clare Liguori Date: Tue, 17 Feb 2026 07:23:35 -0800 
Subject: [PATCH 139/279] fix: handle OpenAI model responses with tool calls and no other assistant content (#1562) --- src/strands/models/litellm.py | 2 +- src/strands/models/openai.py | 19 +++- tests/strands/models/test_litellm.py | 36 +++++++ tests/strands/models/test_openai.py | 138 ++++++++++++++++++++++++++- tests_integ/test_multiagent_swarm.py | 1 + 5 files changed, 190 insertions(+), 6 deletions(-) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index ec6579c58..be5337f0d 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -194,7 +194,7 @@ def format_request_messages( formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content) formatted_messages.extend(cls._format_regular_messages(messages)) - return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + return [message for message in formatted_messages if "content" in message or "tool_calls" in message] @override def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index ab421e6c7..2b217ad91 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -204,10 +204,18 @@ def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) -> ], ) + formatted_contents = [cls.format_request_message_content(content) for content in contents] + + # If single text content, use string format for better model compatibility + if len(formatted_contents) == 1 and formatted_contents[0].get("type") == "text": + content: str | list[dict[str, Any]] = formatted_contents[0]["text"] + else: + content = formatted_contents + return { "role": "tool", "tool_call_id": tool_result["toolUseId"], - "content": [cls.format_request_message_content(content) for content in contents], + "content": content, } @classmethod @@ -369,18 +377,21 @@ def 
_format_regular_messages(cls, messages: Messages, **kwargs: Any) -> list[dic formatted_message = { "role": message["role"], - "content": formatted_contents, + **({"content": formatted_contents} if formatted_contents else {}), **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}), } formatted_messages.append(formatted_message) # Process tool messages to extract images into separate user messages # OpenAI API requires images to be in user role messages only + # All tool messages must be grouped together before any user messages with images + user_messages_with_images = [] for tool_msg in formatted_tool_messages: tool_msg_clean, user_msg_with_images = cls._split_tool_message_images(tool_msg) formatted_messages.append(tool_msg_clean) if user_msg_with_images: - formatted_messages.append(user_msg_with_images) + user_messages_with_images.append(user_msg_with_images) + formatted_messages.extend(user_messages_with_images) return formatted_messages @@ -407,7 +418,7 @@ def format_request_messages( formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content) formatted_messages.extend(cls._format_regular_messages(messages)) - return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + return [message for message in formatted_messages if "content" in message or "tool_calls" in message] def format_request( self, diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index f5e1837bf..9bb0e09ca 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -812,3 +812,39 @@ def __init__(self, usage): assert metadata_events[0]["metadata"]["usage"]["inputTokens"] == 10 assert metadata_events[0]["metadata"]["usage"]["outputTokens"] == 5 assert metadata_events[0]["metadata"]["usage"]["totalTokens"] == 15 + + +def test_format_request_messages_with_tool_calls_no_content(): + """Test that assistant messages with 
only tool calls are included and have no content field.""" + messages = [ + {"role": "user", "content": [{"text": "Use the calculator"}]}, + { + "role": "assistant", + "content": [ + { + "toolUse": { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + }, + }, + ], + }, + ] + + tru_result = LiteLLMModel.format_request_messages(messages) + + exp_result = [ + {"role": "user", "content": [{"text": "Use the calculator", "type": "text"}]}, + { + "role": "assistant", + "tool_calls": [ + { + "function": {"arguments": '{"expression": "2+2"}', "name": "calculator"}, + "id": "c1", + "type": "function", + } + ], + }, + ] + assert tru_result == exp_result diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 4f8652632..241c22b64 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -180,6 +180,23 @@ def test_format_request_tool_message(): assert tru_result == exp_result +def test_format_request_tool_message_single_text_returns_string(): + """Test that single text content is returned as string for model compatibility.""" + tool_result = { + "content": [{"text": '{"result": "success"}'}], + "status": "success", + "toolUseId": "c1", + } + + tru_result = OpenAIModel.format_request_tool_message(tool_result) + exp_result = { + "content": '{"result": "success"}', + "role": "tool", + "tool_call_id": "c1", + } + assert tru_result == exp_result + + def test_split_tool_message_images_with_image(): """Test that images are extracted from tool messages.""" tool_message = { @@ -441,7 +458,7 @@ def test_format_request_messages(system_prompt): ], }, { - "content": [{"text": "4", "type": "text"}], + "content": "4", "role": "tool", "tool_call_id": "c1", }, @@ -1397,3 +1414,122 @@ def test_format_request_filters_location_source_document(model, caplog): assert len(formatted_content) == 1 assert formatted_content[0]["type"] == "text" assert "Location sources are not supported by OpenAI" in 
caplog.text + + +def test_format_request_messages_with_tool_calls_no_content(): + """Test that assistant messages with only tool calls are included and have no content field.""" + messages = [ + {"role": "user", "content": [{"text": "Use the calculator"}]}, + { + "role": "assistant", + "content": [ + { + "toolUse": { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + }, + }, + ], + }, + ] + + tru_result = OpenAIModel.format_request_messages(messages) + + exp_result = [ + {"role": "user", "content": [{"text": "Use the calculator", "type": "text"}]}, + { + "role": "assistant", + "tool_calls": [ + { + "function": {"arguments": '{"expression": "2+2"}', "name": "calculator"}, + "id": "c1", + "type": "function", + } + ], + }, + ] + assert tru_result == exp_result + + +def test_format_request_messages_multiple_tool_calls_with_images(): + """Test that multiple tool calls with image results are formatted correctly. + + OpenAI requires all tool response messages to immediately follow the assistant + message with tool_calls, before any other messages. When tools return images, + the images are moved to user messages, but these must come after ALL tool messages. + """ + messages = [ + {"role": "user", "content": [{"text": "Run the tools"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"input": {}, "name": "tool1", "toolUseId": "call_1"}}, + {"toolUse": {"input": {}, "name": "tool2", "toolUseId": "call_2"}}, + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "call_1", + "content": [{"image": {"format": "png", "source": {"bytes": b"img1"}}}], + "status": "success", + } + }, + { + "toolResult": { + "toolUseId": "call_2", + "content": [{"image": {"format": "png", "source": {"bytes": b"img2"}}}], + "status": "success", + } + }, + ], + }, + ] + + tru_result = OpenAIModel.format_request_messages(messages) + + image_placeholder = ( + "Tool successfully returned an image. 
The image is being provided in the following user message." + ) + exp_result = [ + {"role": "user", "content": [{"text": "Run the tools", "type": "text"}]}, + { + "role": "assistant", + "tool_calls": [ + {"function": {"arguments": "{}", "name": "tool1"}, "id": "call_1", "type": "function"}, + {"function": {"arguments": "{}", "name": "tool2"}, "id": "call_2", "type": "function"}, + ], + }, + { + "role": "tool", + "tool_call_id": "call_1", + "content": [{"type": "text", "text": image_placeholder}], + }, + { + "role": "tool", + "tool_call_id": "call_2", + "content": [{"type": "text", "text": image_placeholder}], + }, + { + "role": "user", + "content": [ + { + "image_url": {"detail": "auto", "format": "image/png", "url": "data:image/png;base64,aW1nMQ=="}, + "type": "image_url", + } + ], + }, + { + "role": "user", + "content": [ + { + "image_url": {"detail": "auto", "format": "image/png", "url": "data:image/png;base64,aW1nMg=="}, + "type": "image_url", + } + ], + }, + ] + assert tru_result == exp_result diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py index e9738d3d9..a244bf753 100644 --- a/tests_integ/test_multiagent_swarm.py +++ b/tests_integ/test_multiagent_swarm.py @@ -113,6 +113,7 @@ def capture_first_node(self, event): return VerifyHook() +@pytest.mark.timeout(120) def test_swarm_execution_with_string(researcher_agent, analyst_agent, writer_agent, hook_provider): """Test swarm execution with string input.""" # Create the swarm From 4e829574fbc4e63fe5d97a5b5afda5f238089ded Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Tue, 17 Feb 2026 15:02:21 -0500 Subject: [PATCH 140/279] fix: Update finalize condition for workflow execution (#1708) --- .github/workflows/strands-command.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/strands-command.yml b/.github/workflows/strands-command.yml index fc11efcb7..496ce025b 100644 --- a/.github/workflows/strands-command.yml +++ 
b/.github/workflows/strands-command.yml @@ -79,7 +79,7 @@ jobs: write_permission: 'false' finalize: - if: always() + if: always() && (startsWith(github.event.comment.body, '/strands') || github.event_name == 'workflow_dispatch') needs: [setup-and-process, execute-readonly-agent] permissions: contents: write From cb3d359b68cc670cbe225fc0c81365124dc121be Mon Sep 17 00:00:00 2001 From: Clare Liguori Date: Wed, 18 Feb 2026 06:53:08 -0800 Subject: [PATCH 141/279] fix: Upgrade mcp minimum dependency to 1.23.0 for Tasks support (#1674) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 048fab88f..d1679a91f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ "botocore>=1.29.0,<2.0.0", "docstring_parser>=0.15,<1.0", "jsonschema>=4.0.0,<5.0.0", - "mcp>=1.11.0,<2.0.0", + "mcp>=1.23.0,<2.0.0", "pydantic>=2.4.0,<3.0.0", "typing-extensions>=4.13.2,<5.0.0", "watchdog>=6.0.0,<7.0.0", From a01d9337e28b64a1b6c1b1d929a8f6becfe8068d Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Wed, 18 Feb 2026 11:27:16 -0500 Subject: [PATCH 142/279] feat(agent): add concurrent_invocation_mode parameter (#1707) Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> Co-authored-by: Mackenzie Zastrow --- src/strands/agent/agent.py | 25 ++++++--- src/strands/types/agent.py | 17 ++++++ tests/strands/agent/test_agent.py | 89 ++++++++++++++++++++++++++++--- 3 files changed, 116 insertions(+), 15 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 567a92b4a..e9739f473 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -54,7 +54,7 @@ from ..tools.structured_output._structured_output_context import StructuredOutputContext from ..tools.watcher import ToolWatcher from ..types._events import AgentResultEvent, EventLoopStopEvent, InitEventLoopEvent, ModelStreamChunkEvent, 
TypedEvent -from ..types.agent import AgentInput +from ..types.agent import AgentInput, ConcurrentInvocationMode from ..types.content import ContentBlock, Message, Messages, SystemContentBlock from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException from ..types.traces import AttributeValue @@ -129,6 +129,7 @@ def __init__( structured_output_prompt: str | None = None, tool_executor: ToolExecutor | None = None, retry_strategy: ModelRetryStrategy | _DefaultRetryStrategySentinel | None = _DEFAULT_RETRY_STRATEGY, + concurrent_invocation_mode: ConcurrentInvocationMode = ConcurrentInvocationMode.THROW, ): """Initialize the Agent with the specified configuration. @@ -186,6 +187,11 @@ def __init__( retry_strategy: Strategy for retrying model calls on throttling or other transient errors. Defaults to ModelRetryStrategy with max_attempts=6, initial_delay=4s, max_delay=240s. Implement a custom HookProvider for custom retry logic, or pass None to disable retries. + concurrent_invocation_mode: Mode controlling concurrent invocation behavior. + Defaults to "throw" which raises ConcurrencyException if concurrent invocation is attempted. + Set to "unsafe_reentrant" to skip lock acquisition entirely, allowing concurrent invocations. + Warning: "unsafe_reentrant" makes no guarantees about resulting behavior and is provided + only for advanced use cases where the caller understands the risks. Raises: ValueError: If agent id contains path separators. 
@@ -263,6 +269,7 @@ def __init__( # Using threading.Lock instead of asyncio.Lock because run_async() creates # separate event loops in different threads, so asyncio.Lock wouldn't work self._invocation_lock = threading.Lock() + self._concurrent_invocation_mode = concurrent_invocation_mode # In the future, we'll have a RetryStrategy base class but until # that API is determined we only allow ModelRetryStrategy @@ -622,14 +629,15 @@ async def stream_async( yield event["data"] ``` """ - # Acquire lock to prevent concurrent invocations + # Conditionally acquire lock based on concurrent_invocation_mode # Using threading.Lock instead of asyncio.Lock because run_async() creates # separate event loops in different threads - acquired = self._invocation_lock.acquire(blocking=False) - if not acquired: - raise ConcurrencyException( - "Agent is already processing a request. Concurrent invocations are not supported." - ) + if self._concurrent_invocation_mode == ConcurrentInvocationMode.THROW: + lock_acquired = self._invocation_lock.acquire(blocking=False) + if not lock_acquired: + raise ConcurrencyException( + "Agent is already processing a request. Concurrent invocations are not supported." + ) try: self._interrupt_state.resume(prompt) @@ -678,7 +686,8 @@ async def stream_async( raise finally: - self._invocation_lock.release() + if self._invocation_lock.locked(): + self._invocation_lock.release() async def _run_loop( self, diff --git a/src/strands/types/agent.py b/src/strands/types/agent.py index aa69149a6..cda01f8aa 100644 --- a/src/strands/types/agent.py +++ b/src/strands/types/agent.py @@ -3,9 +3,26 @@ This module defines the types used for an Agent. 
""" +from enum import Enum from typing import TypeAlias from .content import ContentBlock, Messages from .interrupt import InterruptResponseContent AgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponseContent] | Messages | None + + +class ConcurrentInvocationMode(str, Enum): + """Mode controlling concurrent invocation behavior. + + Values: + THROW: Raises ConcurrencyException if concurrent invocation is attempted (default). + UNSAFE_REENTRANT: Allows concurrent invocations without locking. + + Warning: + The ``UNSAFE_REENTRANT`` mode makes no guarantees about resulting behavior and is + provided only for advanced use cases where the caller understands the risks. + """ + + THROW = "throw" + UNSAFE_REENTRANT = "unsafe_reentrant" diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index eb039185c..d95d26f92 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -26,6 +26,7 @@ from strands.session.repository_session_manager import RepositorySessionManager from strands.telemetry.tracer import serialize from strands.types._events import EventLoopStopEvent, ModelStreamEvent +from strands.types.agent import ConcurrentInvocationMode from strands.types.content import Messages from strands.types.exceptions import ConcurrencyException, ContextWindowOverflowException, EventLoopException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType @@ -2231,20 +2232,17 @@ def test_agent_concurrent_call_raises_exception(): {"role": "assistant", "content": [{"text": "world"}]}, ] ) - agent = Agent(model=model) + agent = Agent(model=model, concurrent_invocation_mode="throw") results = [] errors = [] - lock = threading.Lock() def invoke(): try: result = agent("test") - with lock: - results.append(result) + results.append(result) except ConcurrencyException as e: - with lock: - errors.append(e) + errors.append(e) # Start first thread and wait for it to begin streaming t1 = 
threading.Thread(target=invoke) @@ -2282,7 +2280,7 @@ def test_agent_concurrent_structured_output_raises_exception(): {"role": "assistant", "content": [{"text": "response2"}]}, ], ) - agent = Agent(model=model) + agent = Agent(model=model, concurrent_invocation_mode="throw") results = [] errors = [] @@ -2320,6 +2318,83 @@ def invoke(): assert "concurrent" in str(errors[0]).lower() and "invocation" in str(errors[0]).lower() +def test_agent_concurrent_call_succeeds_with_unsafe_reentrant_mode(): + """Test that concurrent __call__() calls succeed when concurrent_invocation_mode is 'unsafe_reentrant'.""" + model = SyncEventMockedModel( + [ + {"role": "assistant", "content": [{"text": "hello"}]}, + {"role": "assistant", "content": [{"text": "world"}]}, + ] + ) + agent = Agent(model=model, concurrent_invocation_mode="unsafe_reentrant") + + results = [] + errors = [] + lock = threading.Lock() + + def invoke(): + try: + result = agent("test") + with lock: + results.append(result) + except ConcurrencyException as e: + with lock: + errors.append(e) + + # Start first thread and wait for it to begin streaming + t1 = threading.Thread(target=invoke) + t1.start() + model.started_event.wait() # Wait until first thread is in the model.stream() + + # Start second thread while first is still running + t2 = threading.Thread(target=invoke) + t2.start() + + # Let both threads proceed + model.proceed_event.set() + t1.join() + t2.join() + + # Both should succeed, no ConcurrencyException raised + assert len(errors) == 0, f"Expected 0 errors, got {len(errors)}: {errors}" + assert len(results) == 2, f"Expected 2 successes, got {len(results)}" + + +def test_agent_concurrent_invocation_mode_default_is_throw(): + """Test that the default concurrent_invocation_mode is 'throw'.""" + model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}]) + agent = Agent(model=model) + + # Verify the default mode + assert agent._concurrent_invocation_mode == "throw" + + +def 
test_agent_concurrent_invocation_mode_stores_value(): + """Test that concurrent_invocation_mode is stored correctly as instance variable.""" + model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}]) + + agent_throw = Agent(model=model, concurrent_invocation_mode="throw") + assert agent_throw._concurrent_invocation_mode == "throw" + + agent_reentrant = Agent(model=model, concurrent_invocation_mode="unsafe_reentrant") + assert agent_reentrant._concurrent_invocation_mode == "unsafe_reentrant" + + +def test_agent_concurrent_invocation_mode_accepts_enum(): + """Test that concurrent_invocation_mode accepts enum values as well as strings.""" + + model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}]) + + # Using enum values + agent_throw = Agent(model=model, concurrent_invocation_mode=ConcurrentInvocationMode.THROW) + assert agent_throw._concurrent_invocation_mode == "throw" + assert agent_throw._concurrent_invocation_mode == ConcurrentInvocationMode.THROW + + agent_reentrant = Agent(model=model, concurrent_invocation_mode=ConcurrentInvocationMode.UNSAFE_REENTRANT) + assert agent_reentrant._concurrent_invocation_mode == "unsafe_reentrant" + assert agent_reentrant._concurrent_invocation_mode == ConcurrentInvocationMode.UNSAFE_REENTRANT + + @pytest.mark.asyncio async def test_agent_sequential_invocations_work(): """Test that sequential invocations work correctly after lock is released.""" From 0eae8a761ac16f3e6e25613b1a519e8e3408b69c Mon Sep 17 00:00:00 2001 From: Arron <139703460+awsarron@users.noreply.github.com> Date: Wed, 18 Feb 2026 11:39:42 -0500 Subject: [PATCH 143/279] test: coverage for python 3.14 (#1178) --- .github/workflows/test-lint.yml | 8 +++++++- pyproject.toml | 3 ++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml index 8f393d5de..89cc459de 100644 --- a/.github/workflows/test-lint.yml +++ 
b/.github/workflows/test-lint.yml @@ -31,6 +31,9 @@ jobs: - os: ubuntu-latest os-name: 'linux' python-version: "3.13" + - os: ubuntu-latest + os-name: 'linux' + python-version: "3.14" # Windows - os: windows-latest os-name: 'windows' @@ -44,10 +47,13 @@ jobs: - os: windows-latest os-name: 'windows' python-version: "3.13" + - os: windows-latest + os-name: 'windows' + python-version: "3.14" # MacOS - latest only; not enough runners for macOS - os: macos-latest os-name: 'macOS' - python-version: "3.13" + python-version: "3.14" fail-fast: true runs-on: ${{ matrix.os }} env: diff --git a/pyproject.toml b/pyproject.toml index d1679a91f..2aa417b18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Software Development :: Libraries :: Python Modules", ] @@ -154,7 +155,7 @@ dependencies = [ ] [[tool.hatch.envs.hatch-test.matrix]] -python = ["3.13", "3.12", "3.11", "3.10"] +python = ["3.14", "3.13", "3.12", "3.11", "3.10"] [tool.hatch.envs.hatch-test.scripts] run = "pytest{env:HATCH_TEST_ARGS:} {args}" # Run with: hatch test From 0a318480d9c2eccec59bb20f69276204fd53f594 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Wed, 18 Feb 2026 12:03:49 -0500 Subject: [PATCH 144/279] feat(agent): add add_hook convenience method for hook callback registration (#1706) Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> Co-authored-by: Clare Liguori --- src/strands/agent/agent.py | 43 ++++++++++++++ src/strands/hooks/registry.py | 87 ++++++++++++++++++++++++++-- tests/strands/agent/test_agent.py | 71 ++++++++++++++++++++++- tests/strands/hooks/test_registry.py | 81 ++++++++++++++++++++++++++ 4 files changed, 277 insertions(+), 5 deletions(-) diff --git a/src/strands/agent/agent.py 
b/src/strands/agent/agent.py index e9739f473..e199608a2 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -37,10 +37,12 @@ AfterInvocationEvent, AgentInitializedEvent, BeforeInvocationEvent, + HookCallback, HookProvider, HookRegistry, MessageAddedEvent, ) +from ..hooks.registry import TEvent from ..interrupt import _InterruptState from ..models.bedrock import BedrockModel from ..models.model import Model @@ -574,6 +576,47 @@ def cleanup(self) -> None: """ self.tool_registry.cleanup() + def add_hook( + self, callback: HookCallback[TEvent], event_type: type[TEvent] | None = None, **kwargs: dict[str, Any] + ) -> None: + """Register a callback function for a specific event type. + + This method supports two call patterns: + 1. ``add_hook(callback)`` - Event type inferred from callback's type hint + 2. ``add_hook(callback, event_type)`` - Event type specified explicitly + + Callbacks can be either synchronous or asynchronous functions. + + Args: + callback: The callback function to invoke when events of this type occur. + event_type: The class type of events this callback should handle. + If not provided, the event type will be inferred from the callback's + first parameter type hint. + **kwargs: Additional arguments (ignored). + + + Raises: + ValueError: If event_type is not provided and cannot be inferred from + the callback's type hints. 
+ + Example: + ```python + def log_model_call(event: BeforeModelCallEvent) -> None: + print(f"Calling model for agent: {event.agent.name}") + + agent = Agent() + + # With event type inferred from type hint + agent.add_hook(log_model_call) + + # With explicit event type + agent.add_hook(log_model_call, BeforeModelCallEvent) + ``` + Docs: + https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/hooks/ + """ + self.hooks.add_callback(event_type, callback) + def __del__(self) -> None: """Clean up resources when agent is garbage collected.""" # __del__ is called even when an exception is thrown in the constructor, diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 309e3ba76..2f465a751 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -11,7 +11,15 @@ import logging from collections.abc import Awaitable, Generator from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Protocol, + TypeVar, + get_type_hints, + runtime_checkable, +) from ..interrupt import Interrupt, InterruptException @@ -157,28 +165,99 @@ def __init__(self) -> None: """Initialize an empty hook registry.""" self._registered_callbacks: dict[type, list[HookCallback]] = {} - def add_callback(self, event_type: type[TEvent], callback: HookCallback[TEvent]) -> None: + def add_callback( + self, + event_type: type[TEvent] | None, + callback: HookCallback[TEvent], + ) -> None: """Register a callback function for a specific event type. + If ``event_type`` is None, then this will check the callback handler type hint + for the lifecycle event type. + Args: event_type: The class type of events this callback should handle. callback: The callback function to invoke when events of this type occur. 
+ Raises: + ValueError: If event_type is not provided and cannot be inferred from + the callback's type hints, or if AgentInitializedEvent is registered + with an async callback. + Example: ```python def my_handler(event: StartRequestEvent): print("Request started") + # With explicit event type registry.add_callback(StartRequestEvent, my_handler) + + # With event type inferred from type hint + registry.add_callback(None, my_handler) ``` """ + resolved_event_type: type[TEvent] + + # Support both add_callback(None, callback) and add_callback(event_type, callback) + if event_type is None: + # callback provided but event_type is None - infer it + resolved_event_type = self._infer_event_type(callback) + else: + resolved_event_type = event_type + # Related issue: https://github.com/strands-agents/sdk-python/issues/330 - if event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback): + if resolved_event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback): raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback") - callbacks = self._registered_callbacks.setdefault(event_type, []) + callbacks = self._registered_callbacks.setdefault(resolved_event_type, []) callbacks.append(callback) + def _infer_event_type(self, callback: HookCallback[TEvent]) -> type[TEvent]: + """Infer the event type from a callback's type hints. + + Args: + callback: The callback function to inspect. + + Returns: + The event type inferred from the callback's first parameter type hint. + + Raises: + ValueError: If the event type cannot be inferred from the callback's type hints. 
+ """ + try: + hints = get_type_hints(callback) + except Exception as e: + logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e) + raise ValueError( + "failed to get type hints for callback | cannot infer event type, please provide event_type explicitly" + ) from e + + # Get the first parameter's type hint + sig = inspect.signature(callback) + params = list(sig.parameters.values()) + + if not params: + raise ValueError( + "callback has no parameters | cannot infer event type, please provide event_type explicitly" + ) + + first_param = params[0] + type_hint = hints.get(first_param.name) + + if type_hint is None: + raise ValueError( + f"parameter=<{first_param.name}> has no type hint | " + "cannot infer event type, please provide event_type explicitly" + ) + + # Handle single type + if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): + return type_hint # type: ignore[return-value] + + raise ValueError( + f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" + ) + def add_hook(self, hook: HookProvider) -> None: """Register all callbacks from a hook provider. 
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index d95d26f92..587735cec 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -20,7 +20,7 @@ from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.state import AgentState from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler -from strands.hooks import BeforeToolCallEvent +from strands.hooks import BeforeInvocationEvent, BeforeModelCallEvent, BeforeToolCallEvent from strands.interrupt import Interrupt from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel from strands.session.repository_session_manager import RepositorySessionManager @@ -2550,3 +2550,72 @@ def agent_tool(tool_context: ToolContext) -> str: ], "role": "user", } + + +def test_agent_add_hook_registers_callback(): + """Test that add_hook registers a callback with the hooks registry.""" + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + callback = unittest.mock.Mock() + + agent.add_hook(callback, BeforeModelCallEvent) + + # Verify callback was registered by checking it gets invoked + agent("test prompt") + callback.assert_called_once() + # Verify it was called with the correct event type + call_args = callback.call_args[0] + assert isinstance(call_args[0], BeforeModelCallEvent) + + +def test_agent_add_hook_delegates_to_hooks_add_callback(): + """Test that add_hook delegates to self.hooks.add_callback.""" + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + callback = unittest.mock.Mock() + + # Spy on the hooks.add_callback method + with unittest.mock.patch.object(agent.hooks, "add_callback") as mock_add_callback: + agent.add_hook(callback, BeforeInvocationEvent) + mock_add_callback.assert_called_once_with(BeforeInvocationEvent, callback) + + 
+@pytest.mark.asyncio +async def test_agent_add_hook_works_with_async_callback(): + """Test that add_hook works with async callbacks.""" + + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + async_callback = unittest.mock.AsyncMock() + + agent.add_hook(async_callback, BeforeModelCallEvent) + + # Use stream_async to invoke the agent with async support + _ = [event async for event in agent.stream_async("test prompt")] + async_callback.assert_called_once() + # Verify it was called with the correct event type + call_args = async_callback.call_args[0] + assert isinstance(call_args[0], BeforeModelCallEvent) + + +def test_agent_add_hook_infers_event_type_from_callback(): + """Test that add_hook infers event type from callback type hint.""" + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + callback_invoked = [] + + def typed_callback(event: BeforeModelCallEvent) -> None: + callback_invoked.append(event) + + agent.add_hook(typed_callback) + agent("test prompt") + + assert len(callback_invoked) == 1 + assert isinstance(callback_invoked[0], BeforeModelCallEvent) + + +def test_agent_add_hook_raises_error_when_no_type_hint(): + """Test that add_hook raises error when event type cannot be inferred.""" + agent = Agent(model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}])) + + def untyped_callback(event): + pass + + with pytest.raises(ValueError, match="cannot infer event type"): + agent.add_hook(untyped_callback) diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py index 3daf41734..8c8b794f7 100644 --- a/tests/strands/hooks/test_registry.py +++ b/tests/strands/hooks/test_registry.py @@ -87,3 +87,84 @@ def test_hook_registry_invoke_callbacks_coroutine(registry, agent): with pytest.raises(RuntimeError, match=r"use invoke_callbacks_async to invoke async callback"): 
registry.invoke_callbacks(BeforeInvocationEvent(agent=agent)) + + +def test_hook_registry_add_callback_infers_event_type(registry): + """Test that add_callback infers event type from callback type hint.""" + + def typed_callback(event: BeforeInvocationEvent) -> None: + pass + + # Register without explicit event_type - should infer from type hint + registry.add_callback(None, typed_callback) + + # Verify callback was registered + assert BeforeInvocationEvent in registry._registered_callbacks + assert typed_callback in registry._registered_callbacks[BeforeInvocationEvent] + + +def test_hook_registry_add_callback_raises_error_no_type_hint(registry): + """Test that add_callback raises error when type hint is missing.""" + + def untyped_callback(event): + pass + + with pytest.raises(ValueError, match="cannot infer event type"): + registry.add_callback(None, untyped_callback) + + +def test_hook_registry_add_callback_raises_error_invalid_type_hint(registry): + """Test that add_callback raises error when type hint is not a BaseHookEvent subclass.""" + + def invalid_callback(event: str) -> None: + pass + + with pytest.raises(ValueError, match="must be a subclass of BaseHookEvent"): + registry.add_callback(None, invalid_callback) + + +def test_hook_registry_add_callback_raises_error_no_parameters(registry): + """Test that add_callback raises error when callback has no parameters.""" + + def no_param_callback() -> None: + pass + + with pytest.raises(ValueError, match="callback has no parameters"): + registry.add_callback(None, no_param_callback) + + +def test_hook_registry_add_callback_infers_event_type_when_callback_provided_without_event_type(registry): + """Test that add_callback infers event type when callback is provided but event_type is None.""" + + def typed_callback(event: BeforeInvocationEvent) -> None: + pass + + registry.add_callback(None, typed_callback) + + assert BeforeInvocationEvent in registry._registered_callbacks + assert typed_callback in 
registry._registered_callbacks[BeforeInvocationEvent] + + +def test_hook_registry_add_callback_with_explicit_event_type_and_callback(registry): + """Test that add_callback works with explicit event_type and callback.""" + + def callback(event: BeforeInvocationEvent) -> None: + pass + + registry.add_callback(BeforeInvocationEvent, callback) + + assert BeforeInvocationEvent in registry._registered_callbacks + assert callback in registry._registered_callbacks[BeforeInvocationEvent] + + +def test_hook_registry_add_callback_raises_error_on_type_hints_failure(registry): + """Test that add_callback raises error when get_type_hints fails.""" + + class BadCallback: + def __call__(self, event: "NonExistentType") -> None: # noqa: F821 + pass + + callback = BadCallback() + + with pytest.raises(ValueError, match="failed to get type hints for callback"): + registry.add_callback(None, callback) From df98ee154426e5d2b40c9e233b694965222cc68d Mon Sep 17 00:00:00 2001 From: afarntrog <47332252+afarntrog@users.noreply.github.com> Date: Wed, 18 Feb 2026 15:33:13 -0500 Subject: [PATCH 145/279] fix: update region for agentcore in our new account (#1715) --- .github/workflows/issue-responder.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/issue-responder.yml b/.github/workflows/issue-responder.yml index 2efa03117..5b3ad7305 100644 --- a/.github/workflows/issue-responder.yml +++ b/.github/workflows/issue-responder.yml @@ -17,7 +17,7 @@ jobs: uses: aws-actions/configure-aws-credentials@v6 with: role-to-assume: ${{ secrets.STRANDS_AGENTCORE_ACTIONS_ROLE }} - aws-region: us-west-2 + aws-region: us-east-1 - name: Invoke AgentCore with issue details env: GH_ISSUE_AGENTCORE_RUNTIME_ARN: ${{ secrets.GH_ISSUE_AGENTCORE_RUNTIME_ARN }} @@ -48,7 +48,7 @@ jobs: console.log("Invoking AgentCore with payload:"); console.log(JSON.stringify(JSON.parse(payload), null, 2)); - const client = new BedrockAgentCoreClient({ region: "us-west-2" }); + const client = new 
BedrockAgentCoreClient({ region: "us-east-1" }); const sessionId = `github-issue-${process.env.ISSUE_NUMBER}-${Date.now()}-${Math.random().toString(36).slice(2)}`; From db6cd98cbc564c5f31622f49e183e1e7544bff40 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Wed, 18 Feb 2026 15:48:53 -0500 Subject: [PATCH 146/279] fix: remove test that fails for python 3.14 (#1717) --- tests/strands/hooks/test_registry.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py index 8c8b794f7..5331bfa43 100644 --- a/tests/strands/hooks/test_registry.py +++ b/tests/strands/hooks/test_registry.py @@ -155,16 +155,3 @@ def callback(event: BeforeInvocationEvent) -> None: assert BeforeInvocationEvent in registry._registered_callbacks assert callback in registry._registered_callbacks[BeforeInvocationEvent] - - -def test_hook_registry_add_callback_raises_error_on_type_hints_failure(registry): - """Test that add_callback raises error when get_type_hints fails.""" - - class BadCallback: - def __call__(self, event: "NonExistentType") -> None: # noqa: F821 - pass - - callback = BadCallback() - - with pytest.raises(ValueError, match="failed to get type hints for callback"): - registry.add_callback(None, callback) From 2456b71651aaa26630ea1a4d5f14788109b65bf1 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Thu, 19 Feb 2026 09:50:02 -0500 Subject: [PATCH 147/279] feat(hooks): support union types and list of types for add_hook (#1719) Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> --- src/strands/agent/agent.py | 29 +++-- src/strands/hooks/registry.py | 106 ++++++++++++++---- tests/strands/hooks/test_registry.py | 155 ++++++++++++++++++++++++++- 3 files changed, 260 insertions(+), 30 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index e199608a2..7350ab7ed 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -577,27 +577,30 
@@ def cleanup(self) -> None: self.tool_registry.cleanup() def add_hook( - self, callback: HookCallback[TEvent], event_type: type[TEvent] | None = None, **kwargs: dict[str, Any] + self, callback: HookCallback[TEvent], event_type: type[TEvent] | list[type[TEvent]] | None = None ) -> None: """Register a callback function for a specific event type. - This method supports two call patterns: + This method supports multiple call patterns: 1. ``add_hook(callback)`` - Event type inferred from callback's type hint 2. ``add_hook(callback, event_type)`` - Event type specified explicitly + 3. ``add_hook(callback, [TypeA, TypeB])`` - Register for multiple event types + + When the callback's type hint is a union type (``A | B`` or ``Union[A, B]``), + the callback is automatically registered for each event type in the union. Callbacks can be either synchronous or asynchronous functions. Args: callback: The callback function to invoke when events of this type occur. - event_type: The class type of events this callback should handle. - If not provided, the event type will be inferred from the callback's - first parameter type hint. - **kwargs: Additional arguments (ignored). - + event_type: The class type(s) of events this callback should handle. + Can be a single type, a list of types, or None to infer from + the callback's first parameter type hint. If a list is provided, + the callback is registered for each type in the list. Raises: ValueError: If event_type is not provided and cannot be inferred from - the callback's type hints. + the callback's type hints, or if the event_type list is empty. 
Example: ```python @@ -611,6 +614,16 @@ def log_model_call(event: BeforeModelCallEvent) -> None: # With explicit event type agent.add_hook(log_model_call, BeforeModelCallEvent) + + # With union type hint (registers for all types) + def log_event(event: BeforeModelCallEvent | AfterModelCallEvent) -> None: + print(f"Event: {type(event).__name__}") + agent.add_hook(log_event) + + # With list of event types + def multi_handler(event) -> None: + print(f"Event: {type(event).__name__}") + agent.add_hook(multi_handler, [BeforeModelCallEvent, AfterModelCallEvent]) ``` Docs: https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/hooks/ diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 2f465a751..886ea5644 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -9,6 +9,7 @@ import inspect import logging +import types from collections.abc import Awaitable, Generator from dataclasses import dataclass from typing import ( @@ -17,6 +18,10 @@ Generic, Protocol, TypeVar, + Union, + cast, + get_args, + get_origin, get_type_hints, runtime_checkable, ) @@ -167,22 +172,27 @@ def __init__(self) -> None: def add_callback( self, - event_type: type[TEvent] | None, + event_type: type[TEvent] | list[type[TEvent]] | None, callback: HookCallback[TEvent], ) -> None: """Register a callback function for a specific event type. If ``event_type`` is None, then this will check the callback handler type hint - for the lifecycle event type. + for the lifecycle event type. Union types (``A | B`` or ``Union[A, B]``) in + type hints will register the callback for each event type in the union. + + If ``event_type`` is a list, the callback will be registered for each event + type in the list (duplicates are ignored). Args: - event_type: The class type of events this callback should handle. + event_type: The lifecycle event type(s) this callback should handle. 
+ Can be a single type, a list of types, or None to infer from type hints. callback: The callback function to invoke when events of this type occur. Raises: ValueError: If event_type is not provided and cannot be inferred from the callback's type hints, or if AgentInitializedEvent is registered - with an async callback. + with an async callback, or if the event_type list is empty. Example: ```python @@ -194,35 +204,77 @@ def my_handler(event: StartRequestEvent): # With event type inferred from type hint registry.add_callback(None, my_handler) + + # With union type hint (registers for both types) + def union_handler(event: BeforeModelCallEvent | AfterModelCallEvent): + print(f"Event: {type(event).__name__}") + registry.add_callback(None, union_handler) + + # With list of event types + def multi_handler(event): + print(f"Event: {type(event).__name__}") + registry.add_callback([BeforeModelCallEvent, AfterModelCallEvent], multi_handler) ``` """ - resolved_event_type: type[TEvent] - - # Support both add_callback(None, callback) and add_callback(event_type, callback) - if event_type is None: - # callback provided but event_type is None - infer it - resolved_event_type = self._infer_event_type(callback) + resolved_event_types: list[type[TEvent]] + + # Handle list of event types + if isinstance(event_type, list): + if not event_type: + raise ValueError("event_type list cannot be empty") + resolved_event_types = self._validate_event_type_list(event_type) + elif event_type is None: + # Infer event type(s) from callback type hints + resolved_event_types = self._infer_event_types(callback) else: - resolved_event_type = event_type + # Single event type provided explicitly + resolved_event_types = [event_type] - # Related issue: https://github.com/strands-agents/sdk-python/issues/330 - if resolved_event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback): - raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback") 
+ # Deduplicate event types; a set does not preserve order, but callbacks are stored per event type + unique_event_types: set[type[TEvent]] = set(resolved_event_types) - callbacks = self._registered_callbacks.setdefault(resolved_event_type, []) - callbacks.append(callback) + # Register callback for each event type + for resolved_event_type in unique_event_types: + # Related issue: https://github.com/strands-agents/sdk-python/issues/330 + if resolved_event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback): + raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback") - def _infer_event_type(self, callback: HookCallback[TEvent]) -> type[TEvent]: - """Infer the event type from a callback's type hints. + callbacks = self._registered_callbacks.setdefault(resolved_event_type, []) + callbacks.append(callback) + + def _validate_event_type_list(self, event_types: list[type[TEvent]]) -> list[type[TEvent]]: + """Validate that all types in a list are valid BaseHookEvent subclasses. + + Args: + event_types: List of event types to validate. + + Returns: + The validated list of event types. + + Raises: + ValueError: If any type is not a valid BaseHookEvent subclass. + """ + validated: list[type[TEvent]] = [] + for et in event_types: + if not (isinstance(et, type) and issubclass(et, BaseHookEvent)): + raise ValueError(f"Invalid event type: {et} | must be a subclass of BaseHookEvent") + validated.append(et) + return validated + + def _infer_event_types(self, callback: HookCallback[TEvent]) -> list[type[TEvent]]: + """Infer the event type(s) from a callback's type hints. + + Supports both single types and union types (A | B or Union[A, B]). Args: callback: The callback function to inspect. Returns: - The event type inferred from the callback's first parameter type hint. + A list of event types inferred from the callback's first parameter type hint. Raises: - ValueError: If the event type cannot be inferred from the callback's type hints. 
+ ValueError: If the event type cannot be inferred from the callback's type hints, + or if a union contains None or non-BaseHookEvent types. """ try: hints = get_type_hints(callback) @@ -250,9 +302,21 @@ def _infer_event_type(self, callback: HookCallback[TEvent]) -> type[TEvent]: "cannot infer event type, please provide event_type explicitly" ) + # Check if it's a Union type (Union[A, B] or A | B) + origin = get_origin(type_hint) + if origin is Union or origin is types.UnionType: + event_types: list[type[TEvent]] = [] + for arg in get_args(type_hint): + if arg is type(None): + raise ValueError("None is not a valid event type in union") + if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)): + raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent") + event_types.append(cast(type[TEvent], arg)) + return event_types + # Handle single type if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): - return type_hint # type: ignore[return-value] + return [cast(type[TEvent], type_hint)] raise ValueError( f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py index 5331bfa43..79829b92b 100644 --- a/tests/strands/hooks/test_registry.py +++ b/tests/strands/hooks/test_registry.py @@ -1,8 +1,16 @@ import unittest.mock +from typing import Union import pytest -from strands.hooks import AgentInitializedEvent, BeforeInvocationEvent, BeforeToolCallEvent, HookRegistry +from strands.hooks import ( + AfterModelCallEvent, + AgentInitializedEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, + BeforeToolCallEvent, + HookRegistry, +) from strands.interrupt import Interrupt, _InterruptState @@ -155,3 +163,148 @@ def callback(event: BeforeInvocationEvent) -> None: assert BeforeInvocationEvent in registry._registered_callbacks assert callback in 
registry._registered_callbacks[BeforeInvocationEvent] + +# ========== Tests for union type support ========== + + +def test_hook_registry_add_callback_infers_union_types_pipe_syntax(registry): + """Test that add_callback registers callback for each type in A | B union.""" + + def union_callback(event: BeforeModelCallEvent | AfterModelCallEvent) -> None: + pass + + registry.add_callback(None, union_callback) + + # Callback should be registered for both event types + assert BeforeModelCallEvent in registry._registered_callbacks + assert AfterModelCallEvent in registry._registered_callbacks + assert union_callback in registry._registered_callbacks[BeforeModelCallEvent] + assert union_callback in registry._registered_callbacks[AfterModelCallEvent] + + +def test_hook_registry_add_callback_infers_union_types_union_syntax(registry): + """Test that add_callback registers callback for each type in Union[A, B].""" + + def union_callback(event: Union[BeforeModelCallEvent, AfterModelCallEvent]) -> None: # noqa: UP007 + pass + + registry.add_callback(None, union_callback) + + # Callback should be registered for both event types + assert BeforeModelCallEvent in registry._registered_callbacks + assert AfterModelCallEvent in registry._registered_callbacks + assert union_callback in registry._registered_callbacks[BeforeModelCallEvent] + assert union_callback in registry._registered_callbacks[AfterModelCallEvent] + + +def test_hook_registry_add_callback_union_with_none_raises_error(registry): + """Test that add_callback raises error when union contains None.""" + + def callback_with_none(event: BeforeModelCallEvent | None) -> None: + pass + + with pytest.raises(ValueError, match="None is not a valid event type"): + registry.add_callback(None, callback_with_none) + + +def test_hook_registry_add_callback_union_with_invalid_type_raises_error(registry): + """Test that add_callback raises error when union contains non-BaseHookEvent type.""" + + def callback_with_invalid_type(event: 
BeforeModelCallEvent | str) -> None: + pass + + with pytest.raises(ValueError, match="Invalid type in union"): + registry.add_callback(None, callback_with_invalid_type) + + +def test_hook_registry_add_callback_union_multiple_types(registry): + """Test that add_callback handles union with more than two types.""" + + def multi_union_callback(event: BeforeModelCallEvent | AfterModelCallEvent | BeforeInvocationEvent) -> None: + pass + + registry.add_callback(None, multi_union_callback) + + # Callback should be registered for all three event types + assert BeforeModelCallEvent in registry._registered_callbacks + assert AfterModelCallEvent in registry._registered_callbacks + assert BeforeInvocationEvent in registry._registered_callbacks + assert multi_union_callback in registry._registered_callbacks[BeforeModelCallEvent] + assert multi_union_callback in registry._registered_callbacks[AfterModelCallEvent] + assert multi_union_callback in registry._registered_callbacks[BeforeInvocationEvent] + + +# ========== Tests for list of types support ========== + + +def test_hook_registry_add_callback_with_list_of_types(registry): + """Test that add_callback registers callback for each type in a list.""" + + def my_callback(event) -> None: + pass + + registry.add_callback([BeforeModelCallEvent, AfterModelCallEvent], my_callback) + + # Callback should be registered for both event types + assert BeforeModelCallEvent in registry._registered_callbacks + assert AfterModelCallEvent in registry._registered_callbacks + assert my_callback in registry._registered_callbacks[BeforeModelCallEvent] + assert my_callback in registry._registered_callbacks[AfterModelCallEvent] + + +def test_hook_registry_add_callback_with_list_deduplicates(registry): + """Test that add_callback deduplicates event types in a list.""" + + def my_callback(event) -> None: + pass + + # Same type appears multiple times + registry.add_callback([BeforeModelCallEvent, BeforeModelCallEvent, AfterModelCallEvent], my_callback) + 
+ # Callback should be registered only once per event type + assert len(registry._registered_callbacks[BeforeModelCallEvent]) == 1 + assert len(registry._registered_callbacks[AfterModelCallEvent]) == 1 + + +def test_hook_registry_add_callback_with_list_validates_types(registry): + """Test that add_callback validates all types in a list are BaseHookEvent subclasses.""" + + def my_callback(event) -> None: + pass + + with pytest.raises(ValueError, match="Invalid event type"): + registry.add_callback([BeforeModelCallEvent, str], my_callback) + + +def test_hook_registry_add_callback_with_empty_list_raises_error(registry): + """Test that add_callback raises error when given an empty list.""" + + def my_callback(event) -> None: + pass + + with pytest.raises(ValueError, match="event_type list cannot be empty"): + registry.add_callback([], my_callback) + + +@pytest.mark.asyncio +async def test_hook_registry_union_callback_invoked_for_each_type(registry, agent): + """Test that a union-registered callback is invoked correctly for each event type.""" + call_count = {"before": 0, "after": 0} + + def union_callback(event: BeforeModelCallEvent | AfterModelCallEvent) -> None: + if isinstance(event, BeforeModelCallEvent): + call_count["before"] += 1 + elif isinstance(event, AfterModelCallEvent): + call_count["after"] += 1 + + registry.add_callback(None, union_callback) + + # Invoke BeforeModelCallEvent + before_event = BeforeModelCallEvent(agent=agent) + await registry.invoke_callbacks_async(before_event) + assert call_count["before"] == 1 + + # Invoke AfterModelCallEvent + after_event = AfterModelCallEvent(agent=agent) + await registry.invoke_callbacks_async(after_event) + assert call_count["after"] == 1 From a5d26e7edbab321993e5e085b752af8dd1961d2f Mon Sep 17 00:00:00 2001 From: mehtarac Date: Thu, 19 Feb 2026 11:40:57 -0500 Subject: [PATCH 148/279] feat: make pyaudio an optional dependency by lazy loading (#1731) --- README.md | 14 +++++++++++++- pyproject.toml | 6 ++++-- 
src/strands/experimental/bidi/__init__.py | 23 ++++++++++++++++++----- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 8e4d9d0e8..9ee7f6c56 100644 --- a/README.md +++ b/README.md @@ -208,6 +208,16 @@ Build real-time voice and audio conversations with persistent streaming connecti - Google Gemini Live - OpenAI Realtime API +**Installation:** + +```bash +# Server-side only (no audio I/O dependencies) +pip install strands-agents[bidi] + +# With audio I/O support (includes PyAudio dependency) +pip install strands-agents[bidi,bidi-io] +``` + **Quick Example:** ```python @@ -223,7 +233,7 @@ async def main(): model = BidiNovaSonicModel() agent = BidiAgent(model=model, tools=[calculator, stop_conversation]) - # Setup audio and text I/O + # Setup audio and text I/O (requires bidi-io extra) audio_io = BidiAudioIO() text_io = BidiTextIO() @@ -238,6 +248,8 @@ if __name__ == "__main__": asyncio.run(main()) ``` +> **Note**: `BidiAudioIO` and `BidiTextIO` require the `bidi-io` extra. For server-side deployments where audio I/O is handled by clients (browsers, mobile apps), install only `strands-agents[bidi]` and implement custom input/output handlers using the `BidiInput` and `BidiOutput` protocols. 
+ **Configuration Options:** ```python diff --git a/pyproject.toml b/pyproject.toml index 2aa417b18..b53194486 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,15 +73,17 @@ a2a = [ bidi = [ "aws_sdk_bedrock_runtime; python_version>='3.12'", + "smithy-aws-core>=0.0.1; python_version>='3.12'", +] +bidi-io = [ "prompt_toolkit>=3.0.0,<4.0.0", "pyaudio>=0.2.13,<1.0.0", - "smithy-aws-core>=0.0.1; python_version>='3.12'", ] bidi-gemini = ["google-genai>=1.32.0,<2.0.0"] bidi-openai = ["websockets>=15.0.0,<17.0.0"] all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] -bidi-all = ["strands-agents[a2a,bidi,bidi-gemini,bidi-openai,docs,otel]"] +bidi-all = ["strands-agents[a2a,bidi,bidi-io,bidi-gemini,bidi-openai,docs,otel]"] dev = [ "commitizen>=4.4.0,<5.0.0", diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py index 1c0e74aae..99fbacce1 100644 --- a/src/strands/experimental/bidi/__init__.py +++ b/src/strands/experimental/bidi/__init__.py @@ -1,5 +1,7 @@ """Bidirectional streaming package.""" +from typing import Any + # Main components - Primary user interface # Re-export standard agent events for tool handling from ...types._events import ( @@ -9,9 +11,6 @@ ) from .agent.agent import BidiAgent -# IO channels - Hardware abstraction -from .io.audio import BidiAudioIO - # Model interface (for custom implementations) from .models.model import BidiModel @@ -40,8 +39,6 @@ __all__ = [ # Main interface "BidiAgent", - # IO channels - "BidiAudioIO", # Built-in tools "stop_conversation", # Input Event types @@ -68,3 +65,19 @@ # Model interface "BidiModel", ] + + +def __getattr__(name: str) -> Any: + """Lazy load IO implementations only when accessed. + + This defers the import of optional dependencies until actually needed. 
+ """ + if name == "BidiAudioIO": + from .io.audio import BidiAudioIO + + return BidiAudioIO + if name == "BidiTextIO": + from .io.text import BidiTextIO + + return BidiTextIO + raise AttributeError(f"cannot import name '{name}' from '{__name__}' ({__file__})") From 029c77acae6f559212906a1ffe09368279ad976f Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Thu, 19 Feb 2026 12:07:46 -0500 Subject: [PATCH 149/279] feat(hooks): add Plugin Protocol for agent extensibility (#1733) Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> --- AGENTS.md | 5 + src/strands/__init__.py | 2 + src/strands/plugins/__init__.py | 25 ++++ src/strands/plugins/plugin.py | 43 ++++++ src/strands/plugins/registry.py | 73 ++++++++++ tests/strands/plugins/__init__.py | 1 + tests/strands/plugins/test_plugins.py | 192 ++++++++++++++++++++++++++ 7 files changed, 341 insertions(+) create mode 100644 src/strands/plugins/__init__.py create mode 100644 src/strands/plugins/plugin.py create mode 100644 src/strands/plugins/registry.py create mode 100644 tests/strands/plugins/__init__.py create mode 100644 tests/strands/plugins/test_plugins.py diff --git a/AGENTS.md b/AGENTS.md index 9199d50fa..6cd2155c1 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -126,6 +126,10 @@ strands-agents/ │ │ ├── events.py # Hook event definitions │ │ └── registry.py # Hook registration │ │ +│ ├── plugins/ # Plugin system +│ │ ├── plugin.py # Plugin Protocol definition +│ │ └── registry.py # PluginRegistry for tracking plugins +│ │ │ ├── handlers/ # Event handlers │ │ └── callback_handler.py # Callback handling │ │ @@ -171,6 +175,7 @@ strands-agents/ │ ├── session/ │ ├── telemetry/ │ ├── hooks/ +│ ├── plugins/ │ ├── handlers/ │ ├── experimental/ │ └── utils/ diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 6026d4240..be939d5b1 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -4,6 +4,7 @@ from .agent.agent import Agent from .agent.base import AgentBase from 
.event_loop._retry import ModelRetryStrategy +from .plugins import Plugin from .tools.decorator import tool from .types.tools import ToolContext @@ -13,6 +14,7 @@ "agent", "models", "ModelRetryStrategy", + "Plugin", "tool", "ToolContext", "types", diff --git a/src/strands/plugins/__init__.py b/src/strands/plugins/__init__.py new file mode 100644 index 000000000..33922e952 --- /dev/null +++ b/src/strands/plugins/__init__.py @@ -0,0 +1,25 @@ +"""Plugin system for extending agent functionality. + +This module provides a composable mechanism for building objects that can +extend agent behavior through a standardized initialization pattern. + +Example Usage: + ```python + from strands.plugins import Plugin + + class LoggingPlugin: + name = "logging" + + def init_plugin(self, agent: Agent) -> None: + agent.add_hook(self.on_model_call, BeforeModelCallEvent) + + def on_model_call(self, event: BeforeModelCallEvent) -> None: + print(f"Model called for {event.agent.name}") + ``` +""" + +from .plugin import Plugin + +__all__ = [ + "Plugin", +] diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py new file mode 100644 index 000000000..b6a8fd1d9 --- /dev/null +++ b/src/strands/plugins/plugin.py @@ -0,0 +1,43 @@ +"""Plugin protocol for extending agent functionality. + +This module defines the Plugin Protocol, which provides a composable way to +add behavior changes to agents through a standardized initialization pattern. +""" + +from collections.abc import Awaitable +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from ..agent import Agent + + +@runtime_checkable +class Plugin(Protocol): + """Protocol for objects that extend agent functionality. + + Plugins provide a composable way to add behavior changes to agents. + They are initialized with an agent instance and can register hooks, + modify agent attributes, or perform other setup tasks. 
+ + Attributes: + name: A stable string identifier for the plugin + + Example: + ```python + class MyPlugin: + name = "my-plugin" + + def init_plugin(self, agent: Agent) -> None: + agent.add_hook(self.on_model_call, BeforeModelCallEvent) + ``` + """ + + name: str + + def init_plugin(self, agent: "Agent") -> None | Awaitable[None]: + """Initialize the plugin with an agent instance. + + Args: + agent: The agent instance to extend. + """ + ... diff --git a/src/strands/plugins/registry.py b/src/strands/plugins/registry.py new file mode 100644 index 000000000..ffd73c2f4 --- /dev/null +++ b/src/strands/plugins/registry.py @@ -0,0 +1,73 @@ +"""Plugin registry for managing plugins attached to an agent. + +This module provides the _PluginRegistry class for tracking and managing +plugins that have been initialized with an agent instance. +""" + +import inspect +import logging +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, cast + +from .._async import run_async +from .plugin import Plugin + +if TYPE_CHECKING: + from ..agent import Agent + +logger = logging.getLogger(__name__) + + +class _PluginRegistry: + """Registry for managing plugins attached to an agent. + + The _PluginRegistry tracks plugins that have been initialized with an agent, + providing methods to add plugins and invoke their initialization. + + Example: + ```python + registry = _PluginRegistry(agent) + + class MyPlugin: + name = "my-plugin" + + def init_plugin(self, agent: Agent) -> None: + pass + + plugin = MyPlugin() + registry.add_and_init(plugin) + ``` + """ + + def __init__(self, agent: "Agent") -> None: + """Initialize a plugin registry with an agent reference. + + Args: + agent: The agent instance that plugins will be initialized with. + """ + self._agent = agent + self._plugins: dict[str, Plugin] = {} + + def add_and_init(self, plugin: Plugin) -> None: + """Add and initialize a plugin with the agent. 
+ + This method registers the plugin and calls its init_plugin method. + Handles both sync and async init_plugin implementations automatically. + + Args: + plugin: The plugin to add and initialize. + + Raises: + ValueError: If a plugin with the same name is already registered. + """ + if plugin.name in self._plugins: + raise ValueError(f"plugin_name=<{plugin.name}> | plugin already registered") + + logger.debug("plugin_name=<%s> | registering and initializing plugin", plugin.name) + self._plugins[plugin.name] = plugin + + if inspect.iscoroutinefunction(plugin.init_plugin): + async_plugin_init = cast(Callable[..., Awaitable[None]], plugin.init_plugin) + run_async(lambda: async_plugin_init(self._agent)) + else: + plugin.init_plugin(self._agent) diff --git a/tests/strands/plugins/__init__.py b/tests/strands/plugins/__init__.py new file mode 100644 index 000000000..6b722411e --- /dev/null +++ b/tests/strands/plugins/__init__.py @@ -0,0 +1 @@ +"""Tests for the plugins module.""" diff --git a/tests/strands/plugins/test_plugins.py b/tests/strands/plugins/test_plugins.py new file mode 100644 index 000000000..90f6a2545 --- /dev/null +++ b/tests/strands/plugins/test_plugins.py @@ -0,0 +1,192 @@ +"""Tests for the plugin system.""" + +import unittest.mock + +import pytest + +from strands.plugins import Plugin +from strands.plugins.registry import _PluginRegistry + +# Plugin Protocol Tests + + +def test_plugin_protocol_is_runtime_checkable(): + """Test that Plugin Protocol is runtime checkable with isinstance.""" + + class MyPlugin: + name = "my-plugin" + + def init_plugin(self, agent): + pass + + plugin = MyPlugin() + assert isinstance(plugin, Plugin) + + +def test_plugin_protocol_sync_implementation(): + """Test Plugin Protocol works with synchronous init_plugin.""" + + class SyncPlugin: + name = "sync-plugin" + + def init_plugin(self, agent): + agent.custom_attribute = "initialized by plugin" + + plugin = SyncPlugin() + mock_agent = unittest.mock.Mock() + + # Verify the 
plugin matches the protocol + assert isinstance(plugin, Plugin) + assert plugin.name == "sync-plugin" + + # Execute init_plugin synchronously + plugin.init_plugin(mock_agent) + assert mock_agent.custom_attribute == "initialized by plugin" + + +@pytest.mark.asyncio +async def test_plugin_protocol_async_implementation(): + """Test Plugin Protocol works with asynchronous init_plugin.""" + + class AsyncPlugin: + name = "async-plugin" + + async def init_plugin(self, agent): + agent.custom_attribute = "initialized by async plugin" + + plugin = AsyncPlugin() + mock_agent = unittest.mock.Mock() + + # Verify the plugin matches the protocol + assert isinstance(plugin, Plugin) + assert plugin.name == "async-plugin" + + # Execute init_plugin asynchronously + await plugin.init_plugin(mock_agent) + assert mock_agent.custom_attribute == "initialized by async plugin" + + +def test_plugin_protocol_requires_name(): + """Test that Plugin Protocol requires a name property.""" + + class PluginWithoutName: + def init_plugin(self, agent): + pass + + plugin = PluginWithoutName() + # A class without 'name' should not pass isinstance check + assert not isinstance(plugin, Plugin) + + +def test_plugin_protocol_requires_init_plugin_method(): + """Test that Plugin Protocol requires an init_plugin method.""" + + class PluginWithoutInitPlugin: + name = "incomplete-plugin" + + plugin = PluginWithoutInitPlugin() + # A class without 'init_plugin' should not pass isinstance check + assert not isinstance(plugin, Plugin) + + +def test_plugin_protocol_with_class_attribute_name(): + """Test Plugin Protocol works when name is a class attribute.""" + + class PluginWithClassAttribute: + name: str = "class-attr-plugin" + + def init_plugin(self, agent): + pass + + plugin = PluginWithClassAttribute() + assert isinstance(plugin, Plugin) + assert plugin.name == "class-attr-plugin" + + +def test_plugin_protocol_with_property_name(): + """Test Plugin Protocol works when name is a property.""" + + class 
PluginWithProperty: + @property + def name(self): + return "property-plugin" + + def init_plugin(self, agent): + pass + + plugin = PluginWithProperty() + assert isinstance(plugin, Plugin) + assert plugin.name == "property-plugin" + + +# _PluginRegistry Tests + + +@pytest.fixture +def mock_agent(): + """Create a mock agent for testing.""" + return unittest.mock.Mock() + + +@pytest.fixture +def registry(mock_agent): + """Create a fresh _PluginRegistry for each test.""" + return _PluginRegistry(mock_agent) + + +def test_plugin_registry_add_and_init_calls_init_plugin(registry, mock_agent): + """Test adding a plugin calls its init_plugin method.""" + + class TestPlugin: + name = "test-plugin" + + def __init__(self): + self.initialized = False + + def init_plugin(self, agent): + self.initialized = True + agent.plugin_initialized = True + + plugin = TestPlugin() + registry.add_and_init(plugin) + + assert plugin.initialized + assert mock_agent.plugin_initialized + + +def test_plugin_registry_add_duplicate_raises_error(registry, mock_agent): + """Test that adding a duplicate plugin raises an error.""" + + class TestPlugin: + name = "test-plugin" + + def init_plugin(self, agent): + pass + + plugin1 = TestPlugin() + plugin2 = TestPlugin() + + registry.add_and_init(plugin1) + + with pytest.raises(ValueError, match="plugin_name= | plugin already registered"): + registry.add_and_init(plugin2) + + +def test_plugin_registry_add_and_init_with_async_plugin(registry, mock_agent): + """Test that add_and_init handles async plugins using run_async.""" + + class AsyncPlugin: + name = "async-plugin" + + def __init__(self): + self.initialized = False + + async def init_plugin(self, agent): + self.initialized = True + agent.async_plugin_initialized = True + + plugin = AsyncPlugin() + registry.add_and_init(plugin) + + assert plugin.initialized + assert mock_agent.async_plugin_initialized From 30e302036a964fc92450c600bec9ba467e3e6b91 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Thu, 19 Feb 
2026 13:02:06 -0500 Subject: [PATCH 150/279] feat: add plugins parameter to Agent (#1734) Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> --- src/strands/agent/agent.py | 14 +++++++ tests/strands/agent/test_agent.py | 70 +++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 7350ab7ed..ebead3b7d 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -46,6 +46,8 @@ from ..interrupt import _InterruptState from ..models.bedrock import BedrockModel from ..models.model import Model +from ..plugins import Plugin +from ..plugins.registry import _PluginRegistry from ..session.session_manager import SessionManager from ..telemetry.metrics import EventLoopMetrics from ..telemetry.tracer import get_tracer, serialize @@ -126,6 +128,7 @@ def __init__( name: str | None = None, description: str | None = None, state: AgentState | dict | None = None, + plugins: list[Plugin] | None = None, hooks: list[HookProvider] | None = None, session_manager: SessionManager | None = None, structured_output_prompt: str | None = None, @@ -176,6 +179,10 @@ def __init__( Defaults to None. state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict. Defaults to an empty AgentState object. + plugins: List of Plugin instances to extend agent functionality. + Plugins are initialized with the agent instance after construction and can register hooks, + modify agent attributes, or perform other setup tasks. + Defaults to None. hooks: hooks to be added to the agent hook registry Defaults to None. session_manager: Manager for handling agent sessions including conversation history and state. 
@@ -265,6 +272,8 @@ def __init__( self.hooks = HookRegistry() + self._plugin_registry = _PluginRegistry(self) + self._interrupt_state = _InterruptState() # Initialize lock for guarding concurrent invocations @@ -311,6 +320,11 @@ def __init__( if hooks: for hook in hooks: self.hooks.add_hook(hook) + + if plugins: + for plugin in plugins: + self._plugin_registry.add_and_init(plugin) + self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) @property diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 587735cec..5deeb4f7c 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -2619,3 +2619,73 @@ def untyped_callback(event): with pytest.raises(ValueError, match="cannot infer event type"): agent.add_hook(untyped_callback) + + +def test_agent_plugins_sync_initialization(): + """Test that plugins with sync init_plugin are initialized correctly.""" + plugin_mock = unittest.mock.Mock() + plugin_mock.name = "test-plugin" + plugin_mock.init_plugin = unittest.mock.Mock() + + agent = Agent( + model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]), + plugins=[plugin_mock], + ) + + plugin_mock.init_plugin.assert_called_once_with(agent) + + +def test_agent_plugins_async_initialization(): + """Test that plugins with async init_plugin are initialized correctly.""" + plugin_mock = unittest.mock.Mock() + plugin_mock.name = "async-plugin" + plugin_mock.init_plugin = unittest.mock.AsyncMock() + + agent = Agent( + model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]), + plugins=[plugin_mock], + ) + + plugin_mock.init_plugin.assert_called_once_with(agent) + + +def test_agent_plugins_multiple_in_order(): + """Test that multiple plugins are initialized in order.""" + call_order = [] + + plugin1 = unittest.mock.Mock() + plugin1.name = "plugin1" + plugin1.init_plugin = unittest.mock.Mock(side_effect=lambda agent: call_order.append("plugin1")) + + 
plugin2 = unittest.mock.Mock() + plugin2.name = "plugin2" + plugin2.init_plugin = unittest.mock.Mock(side_effect=lambda agent: call_order.append("plugin2")) + + Agent( + model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]), + plugins=[plugin1, plugin2], + ) + + assert call_order == ["plugin1", "plugin2"] + + +def test_agent_plugins_can_register_hooks(): + """Test that plugins can register hooks during initialization.""" + hook_called = [] + + class TestPlugin: + name = "hook-plugin" + + def init_plugin(self, agent): + def hook_callback(event: BeforeModelCallEvent): + hook_called.append(True) + + agent.add_hook(hook_callback) + + agent = Agent( + model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]), + plugins=[TestPlugin()], + ) + + agent("test") + assert len(hook_called) == 1 From 881acc0126d2a3011deaa65a6d9ba55a5f5708b2 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Fri, 20 Feb 2026 11:55:56 -0500 Subject: [PATCH 151/279] refactor(plugins): convert Plugin from Protocol to ABC (#1741) --- src/strands/plugins/__init__.py | 2 +- src/strands/plugins/plugin.py | 21 +++++--- src/strands/plugins/registry.py | 2 +- tests/strands/plugins/test_plugins.py | 70 +++++++++++++-------------- 4 files changed, 50 insertions(+), 45 deletions(-) diff --git a/src/strands/plugins/__init__.py b/src/strands/plugins/__init__.py index 33922e952..9ec9c9357 100644 --- a/src/strands/plugins/__init__.py +++ b/src/strands/plugins/__init__.py @@ -7,7 +7,7 @@ ```python from strands.plugins import Plugin - class LoggingPlugin: + class LoggingPlugin(Plugin): name = "logging" def init_plugin(self, agent: Agent) -> None: diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index b6a8fd1d9..80707616a 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -1,19 +1,19 @@ -"""Plugin protocol for extending agent functionality. +"""Plugin base class for extending agent functionality. 
-This module defines the Plugin Protocol, which provides a composable way to +This module defines the Plugin base class, which provides a composable way to add behavior changes to agents through a standardized initialization pattern. """ +from abc import ABC, abstractmethod from collections.abc import Awaitable -from typing import TYPE_CHECKING, Protocol, runtime_checkable +from typing import TYPE_CHECKING if TYPE_CHECKING: from ..agent import Agent -@runtime_checkable -class Plugin(Protocol): - """Protocol for objects that extend agent functionality. +class Plugin(ABC): + """Base class for objects that extend agent functionality. Plugins provide a composable way to add behavior changes to agents. They are initialized with an agent instance and can register hooks, @@ -24,7 +24,7 @@ class Plugin(Protocol): Example: ```python - class MyPlugin: + class MyPlugin(Plugin): name = "my-plugin" def init_plugin(self, agent: Agent) -> None: @@ -32,8 +32,13 @@ def init_plugin(self, agent: Agent) -> None: ``` """ - name: str + @property + @abstractmethod + def name(self) -> str: + """A stable string identifier for the plugin.""" + ... + @abstractmethod def init_plugin(self, agent: "Agent") -> None | Awaitable[None]: """Initialize the plugin with an agent instance. 
diff --git a/src/strands/plugins/registry.py b/src/strands/plugins/registry.py index ffd73c2f4..34a7a6639 100644 --- a/src/strands/plugins/registry.py +++ b/src/strands/plugins/registry.py @@ -28,7 +28,7 @@ class _PluginRegistry: ```python registry = _PluginRegistry(agent) - class MyPlugin: + class MyPlugin(Plugin): name = "my-plugin" def init_plugin(self, agent: Agent) -> None: diff --git a/tests/strands/plugins/test_plugins.py b/tests/strands/plugins/test_plugins.py index 90f6a2545..9274d2f12 100644 --- a/tests/strands/plugins/test_plugins.py +++ b/tests/strands/plugins/test_plugins.py @@ -10,10 +10,10 @@ # Plugin Protocol Tests -def test_plugin_protocol_is_runtime_checkable(): - """Test that Plugin Protocol is runtime checkable with isinstance.""" +def test_plugin_class_requires_inheritance(): + """Test that Plugin class requires inheritance.""" - class MyPlugin: + class MyPlugin(Plugin): name = "my-plugin" def init_plugin(self, agent): @@ -23,10 +23,10 @@ def init_plugin(self, agent): assert isinstance(plugin, Plugin) -def test_plugin_protocol_sync_implementation(): - """Test Plugin Protocol works with synchronous init_plugin.""" +def test_plugin_class_sync_implementation(): + """Test Plugin class works with synchronous init_plugin.""" - class SyncPlugin: + class SyncPlugin(Plugin): name = "sync-plugin" def init_plugin(self, agent): @@ -35,7 +35,7 @@ def init_plugin(self, agent): plugin = SyncPlugin() mock_agent = unittest.mock.Mock() - # Verify the plugin matches the protocol + # Verify the plugin is an instance of Plugin assert isinstance(plugin, Plugin) assert plugin.name == "sync-plugin" @@ -45,10 +45,10 @@ def init_plugin(self, agent): @pytest.mark.asyncio -async def test_plugin_protocol_async_implementation(): - """Test Plugin Protocol works with asynchronous init_plugin.""" +async def test_plugin_class_async_implementation(): + """Test Plugin class works with asynchronous init_plugin.""" - class AsyncPlugin: + class AsyncPlugin(Plugin): name = 
"async-plugin" async def init_plugin(self, agent): @@ -57,7 +57,7 @@ async def init_plugin(self, agent): plugin = AsyncPlugin() mock_agent = unittest.mock.Mock() - # Verify the plugin matches the protocol + # Verify the plugin is an instance of Plugin assert isinstance(plugin, Plugin) assert plugin.name == "async-plugin" @@ -66,33 +66,33 @@ async def init_plugin(self, agent): assert mock_agent.custom_attribute == "initialized by async plugin" -def test_plugin_protocol_requires_name(): - """Test that Plugin Protocol requires a name property.""" +def test_plugin_class_requires_name(): + """Test that Plugin class requires a name property.""" - class PluginWithoutName: - def init_plugin(self, agent): - pass + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + + class PluginWithoutName(Plugin): + def init_plugin(self, agent): + pass + + PluginWithoutName() - plugin = PluginWithoutName() - # A class without 'name' should not pass isinstance check - assert not isinstance(plugin, Plugin) +def test_plugin_class_requires_init_plugin_method(): + """Test that Plugin class requires an init_plugin method.""" -def test_plugin_protocol_requires_init_plugin_method(): - """Test that Plugin Protocol requires an init_plugin method.""" + with pytest.raises(TypeError, match="Can't instantiate abstract class"): - class PluginWithoutInitPlugin: - name = "incomplete-plugin" + class PluginWithoutInitPlugin(Plugin): + name = "incomplete-plugin" - plugin = PluginWithoutInitPlugin() - # A class without 'init_plugin' should not pass isinstance check - assert not isinstance(plugin, Plugin) + PluginWithoutInitPlugin() -def test_plugin_protocol_with_class_attribute_name(): - """Test Plugin Protocol works when name is a class attribute.""" +def test_plugin_class_with_class_attribute_name(): + """Test Plugin class works when name is a class attribute.""" - class PluginWithClassAttribute: + class PluginWithClassAttribute(Plugin): name: str = "class-attr-plugin" def 
init_plugin(self, agent): @@ -103,10 +103,10 @@ def init_plugin(self, agent): assert plugin.name == "class-attr-plugin" -def test_plugin_protocol_with_property_name(): - """Test Plugin Protocol works when name is a property.""" +def test_plugin_class_with_property_name(): + """Test Plugin class works when name is a property.""" - class PluginWithProperty: + class PluginWithProperty(Plugin): @property def name(self): return "property-plugin" @@ -137,7 +137,7 @@ def registry(mock_agent): def test_plugin_registry_add_and_init_calls_init_plugin(registry, mock_agent): """Test adding a plugin calls its init_plugin method.""" - class TestPlugin: + class TestPlugin(Plugin): name = "test-plugin" def __init__(self): @@ -157,7 +157,7 @@ def init_plugin(self, agent): def test_plugin_registry_add_duplicate_raises_error(registry, mock_agent): """Test that adding a duplicate plugin raises an error.""" - class TestPlugin: + class TestPlugin(Plugin): name = "test-plugin" def init_plugin(self, agent): @@ -175,7 +175,7 @@ def init_plugin(self, agent): def test_plugin_registry_add_and_init_with_async_plugin(registry, mock_agent): """Test that add_and_init handles async plugins using run_async.""" - class AsyncPlugin: + class AsyncPlugin(Plugin): name = "async-plugin" def __init__(self): From d66a54c2bcbe713422622a49214aac6cc0910ff2 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Fri, 20 Feb 2026 12:45:11 -0500 Subject: [PATCH 152/279] feat(steering): migrate SteeringHandler from HookProvider to Plugin (#1738) Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> --- AGENTS.md | 2 +- src/strands/experimental/steering/__init__.py | 2 +- .../experimental/steering/core/handler.py | 26 ++++--- .../steering/core/test_handler.py | 75 +++++++++++-------- tests/strands/hooks/test_registry.py | 1 + tests/strands/plugins/test_plugins.py | 2 +- tests_integ/steering/test_model_steering.py | 10 +-- tests_integ/steering/test_tool_steering.py | 4 +- 8 files changed, 70 
insertions(+), 52 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 6cd2155c1..6a5765a94 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -127,7 +127,7 @@ strands-agents/ │ │ └── registry.py # Hook registration │ │ │ ├── plugins/ # Plugin system -│ │ ├── plugin.py # Plugin Protocol definition +│ │ ├── plugin.py # Plugin definition │ │ └── registry.py # PluginRegistry for tracking plugins │ │ │ ├── handlers/ # Event handlers diff --git a/src/strands/experimental/steering/__init__.py b/src/strands/experimental/steering/__init__.py index be04a9ddb..c928d0c63 100644 --- a/src/strands/experimental/steering/__init__.py +++ b/src/strands/experimental/steering/__init__.py @@ -13,7 +13,7 @@ Usage: handler = LLMSteeringHandler(system_prompt="...") - agent = Agent(tools=[...], hooks=[handler]) + agent = Agent(tools=[...], plugins=[handler]) """ # Core primitives diff --git a/src/strands/experimental/steering/core/handler.py b/src/strands/experimental/steering/core/handler.py index 403a73414..3b869c0eb 100644 --- a/src/strands/experimental/steering/core/handler.py +++ b/src/strands/experimental/steering/core/handler.py @@ -35,11 +35,10 @@ """ import logging -from abc import ABC from typing import TYPE_CHECKING, Any from ....hooks.events import AfterModelCallEvent, BeforeToolCallEvent -from ....hooks.registry import HookProvider, HookRegistry +from ....plugins.plugin import Plugin from ....types.content import Message from ....types.streaming import StopReason from ....types.tools import ToolUse @@ -52,20 +51,21 @@ logger = logging.getLogger(__name__) -class SteeringHandler(HookProvider, ABC): +class SteeringHandler(Plugin): """Base class for steering handlers that provide contextual guidance to agents. Steering handlers maintain local context and register hook callbacks to populate context data as needed for guidance decisions. """ + name: str = "steering" + def __init__(self, context_providers: list[SteeringContextProvider] | None = None): """Initialize the steering handler. 
Args: context_providers: List of context providers for context updates """ - super().__init__() self.steering_context = SteeringContext() self._context_callbacks = [] @@ -75,19 +75,23 @@ def __init__(self, context_providers: list[SteeringContextProvider] | None = Non logger.debug("handler_class=<%s> | initialized", self.__class__.__name__) - def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: - """Register hooks for steering guidance and context updates.""" + def init_plugin(self, agent: "Agent") -> None: + """Initialize the steering handler with an agent. + + Registers hook callbacks for steering guidance and context updates. + + Args: + agent: The agent instance to attach steering to. + """ # Register context update callbacks for callback in self._context_callbacks: - registry.add_callback( - callback.event_type, lambda event, callback=callback: callback(event, self.steering_context) - ) + agent.add_hook(lambda event, callback=callback: callback(event, self.steering_context), callback.event_type) # Register tool steering guidance - registry.add_callback(BeforeToolCallEvent, self._provide_tool_steering_guidance) + agent.add_hook(self._provide_tool_steering_guidance, BeforeToolCallEvent) # Register model steering guidance - registry.add_callback(AfterModelCallEvent, self._provide_model_steering_guidance) + agent.add_hook(self._provide_model_steering_guidance, AfterModelCallEvent) async def _provide_tool_steering_guidance(self, event: BeforeToolCallEvent) -> None: """Provide steering guidance for tool call.""" diff --git a/tests/strands/experimental/steering/core/test_handler.py b/tests/strands/experimental/steering/core/test_handler.py index 04d3a56c1..447780939 100644 --- a/tests/strands/experimental/steering/core/test_handler.py +++ b/tests/strands/experimental/steering/core/test_handler.py @@ -8,7 +8,7 @@ from strands.experimental.steering.core.context import SteeringContext, SteeringContextCallback, SteeringContextProvider from 
strands.experimental.steering.core.handler import SteeringHandler from strands.hooks.events import AfterModelCallEvent, BeforeToolCallEvent -from strands.hooks.registry import HookRegistry +from strands.plugins import Plugin class TestSteeringHandler(SteeringHandler): @@ -24,16 +24,29 @@ def test_steering_handler_initialization(): assert handler is not None -def test_register_hooks(): - """Test hook registration.""" +def test_steering_handler_has_name_attribute(): + """Test SteeringHandler has name attribute for Plugin.""" handler = TestSteeringHandler() - registry = Mock(spec=HookRegistry) + assert hasattr(handler, "name") + assert handler.name == "steering" - handler.register_hooks(registry) + +def test_steering_handler_is_plugin(): + """Test SteeringHandler implements Plugin.""" + handler = TestSteeringHandler() + assert isinstance(handler, Plugin) + + +def test_init_plugin(): + """Test init_plugin registers hooks on agent.""" + handler = TestSteeringHandler() + agent = Mock() + + handler.init_plugin(agent) # Verify hooks were registered (tool and model steering hooks) - assert registry.add_callback.call_count >= 2 - registry.add_callback.assert_any_call(BeforeToolCallEvent, handler._provide_tool_steering_guidance) + assert agent.add_hook.call_count >= 2 + agent.add_hook.assert_any_call(handler._provide_tool_steering_guidance, BeforeToolCallEvent) def test_steering_context_initialization(): @@ -155,24 +168,24 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): await handler._provide_tool_steering_guidance(event) -def test_register_steering_hooks_override(): - """Test that _register_steering_hooks can be overridden.""" +def test_init_plugin_override(): + """Test that init_plugin can be overridden.""" class CustomHandler(SteeringHandler): async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Proceed(reason="Custom") - def register_hooks(self, registry, **kwargs): + def init_plugin(self, agent): # Custom hook registration - don't 
call parent pass handler = CustomHandler() - registry = Mock(spec=HookRegistry) + agent = Mock() - handler.register_hooks(registry) + handler.init_plugin(agent) # Should not register any hooks - assert registry.add_callback.call_count == 0 + assert agent.add_hook.call_count == 0 # Integration tests with context providers @@ -208,16 +221,16 @@ def test_handler_registers_context_provider_hooks(): """Test that handler registers hooks from context callbacks.""" mock_callback = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) - registry = Mock(spec=HookRegistry) + agent = Mock() - handler.register_hooks(registry) + handler.init_plugin(agent) # Should register hooks for context callback and steering guidance - assert registry.add_callback.call_count >= 2 + assert agent.add_hook.call_count >= 2 # Check that BeforeToolCallEvent was registered - call_args = [call[0] for call in registry.add_callback.call_args_list] - event_types = [args[0] for args in call_args] + call_args = [call[0] for call in agent.add_hook.call_args_list] + event_types = [args[1] for args in call_args] assert BeforeToolCallEvent in event_types @@ -226,15 +239,15 @@ def test_context_callbacks_receive_steering_context(): """Test that context callbacks receive the handler's steering context.""" mock_callback = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) - registry = Mock(spec=HookRegistry) + agent = Mock() - handler.register_hooks(registry) + handler.init_plugin(agent) # Get the registered callback for BeforeToolCallEvent before_callback = None - for call in registry.add_callback.call_args_list: - if call[0][0] == BeforeToolCallEvent: - before_callback = call[0][1] + for call in agent.add_hook.call_args_list: + if call[0][1] == BeforeToolCallEvent: + before_callback = call[0][0] break assert before_callback is not None @@ -256,13 +269,13 @@ def test_multiple_context_callbacks_registered(): callback2 = 
MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[callback1, callback2]) - registry = Mock(spec=HookRegistry) + agent = Mock() - handler.register_hooks(registry) + handler.init_plugin(agent) # Should register one callback for each context provider plus tool and model steering guidance expected_calls = 2 + 2 # 2 callbacks + 2 for steering guidance (tool and model) - assert registry.add_callback.call_count >= expected_calls + assert agent.add_hook.call_count >= expected_calls def test_handler_initialization_with_callbacks(): @@ -472,12 +485,12 @@ async def test_default_steer_after_model_returns_proceed(): assert "Default implementation" in result.reason -def test_register_hooks_registers_model_steering(): - """Test that register_hooks registers model steering callback.""" +def test_init_plugin_registers_model_steering(): + """Test that init_plugin registers model steering callback.""" handler = TestSteeringHandler() - registry = Mock(spec=HookRegistry) + agent = Mock() - handler.register_hooks(registry) + handler.init_plugin(agent) # Verify model steering hook was registered - registry.add_callback.assert_any_call(AfterModelCallEvent, handler._provide_model_steering_guidance) + agent.add_hook.assert_any_call(handler._provide_model_steering_guidance, AfterModelCallEvent) diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py index 79829b92b..5b0f3c574 100644 --- a/tests/strands/hooks/test_registry.py +++ b/tests/strands/hooks/test_registry.py @@ -164,6 +164,7 @@ def callback(event: BeforeInvocationEvent) -> None: assert BeforeInvocationEvent in registry._registered_callbacks assert callback in registry._registered_callbacks[BeforeInvocationEvent] + # ========== Tests for union type support ========== diff --git a/tests/strands/plugins/test_plugins.py b/tests/strands/plugins/test_plugins.py index 9274d2f12..7d0f49dc9 100644 --- a/tests/strands/plugins/test_plugins.py +++ 
b/tests/strands/plugins/test_plugins.py @@ -7,7 +7,7 @@ from strands.plugins import Plugin from strands.plugins.registry import _PluginRegistry -# Plugin Protocol Tests +# Plugin Tests def test_plugin_class_requires_inheritance(): diff --git a/tests_integ/steering/test_model_steering.py b/tests_integ/steering/test_model_steering.py index dccb0fa3a..d1948586a 100644 --- a/tests_integ/steering/test_model_steering.py +++ b/tests_integ/steering/test_model_steering.py @@ -39,7 +39,7 @@ async def steer_after_model( def test_model_steering_proceeds_without_intervention(): """Test that model steering can accept responses without modification.""" handler = SimpleModelSteeringHandler(should_guide=False) - agent = Agent(hooks=[handler]) + agent = Agent(plugins=[handler]) response = agent("What is 2+2?") @@ -54,7 +54,7 @@ def test_model_steering_proceeds_without_intervention(): def test_model_steering_guide_triggers_retry(): """Test that Guide action triggers model retry.""" handler = SimpleModelSteeringHandler(should_guide=True, guidance_message="Please provide a more detailed response.") - agent = Agent(hooks=[handler]) + agent = Agent(plugins=[handler]) response = agent("What is the capital of France?") @@ -85,7 +85,7 @@ async def steer_after_model( return Proceed(reason="Response is good now") handler = SpecificGuidanceHandler() - agent = Agent(hooks=[handler]) + agent = Agent(plugins=[handler]) response = agent("What is the capital of France?") @@ -122,7 +122,7 @@ async def steer_after_model( return Proceed(reason="Response is good now") handler = MultiRetryHandler() - agent = Agent(hooks=[handler]) + agent = Agent(plugins=[handler]) response = agent("Explain machine learning.") @@ -195,7 +195,7 @@ async def steer_after_model( return Proceed(reason="Guidance was provided") handler = ForceToolUsageHandler(required_tool="log_activity") - agent = Agent(tools=[log_activity], hooks=[handler]) + agent = Agent(tools=[log_activity], plugins=[handler]) # Ask a question that 
clearly doesn't need the logging tool response = agent("What is 2 + 2?") diff --git a/tests_integ/steering/test_tool_steering.py b/tests_integ/steering/test_tool_steering.py index 5036c759c..e441e71da 100644 --- a/tests_integ/steering/test_tool_steering.py +++ b/tests_integ/steering/test_tool_steering.py @@ -87,7 +87,7 @@ def test_agent_with_tool_steering_e2e(): context_providers=[], # Disable ledger to avoid confusing context ) - agent = Agent(tools=[send_email, send_notification], hooks=[handler]) + agent = Agent(tools=[send_email, send_notification], plugins=[handler]) # This should trigger steering guidance to use send_notification instead response = agent("Send an email to john@example.com saying hello") @@ -132,7 +132,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Proceed(reason="Ledger verified") handler = LedgerCheckingHandler() - agent = Agent(tools=[send_notification], hooks=[handler]) + agent = Agent(tools=[send_notification], plugins=[handler]) agent("Send a notification to alice saying test message") From 42e18b8c0f0dbe34e6b202255d0f86831fa05746 Mon Sep 17 00:00:00 2001 From: Clare Liguori Date: Wed, 25 Feb 2026 10:12:47 -0800 Subject: [PATCH 153/279] chore: switch to Sonnet 4.6 for Anthropic provider integ tests (#1754) --- tests_integ/models/providers.py | 2 +- tests_integ/models/test_model_anthropic.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index 57614b97f..ab8551391 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -66,7 +66,7 @@ def __init__(self): client_args={ "api_key": os.getenv("ANTHROPIC_API_KEY"), }, - model_id="claude-3-7-sonnet-20250219", + model_id="claude-sonnet-4-6", max_tokens=512, ), ) diff --git a/tests_integ/models/test_model_anthropic.py b/tests_integ/models/test_model_anthropic.py index 9a0d19dff..864360139 100644 --- a/tests_integ/models/test_model_anthropic.py +++ 
b/tests_integ/models/test_model_anthropic.py @@ -28,7 +28,7 @@ def model(): client_args={ "api_key": os.getenv("ANTHROPIC_API_KEY"), }, - model_id="claude-3-7-sonnet-20250219", + model_id="claude-sonnet-4-6", max_tokens=512, ) From 37938da7ee8213d63c754a7b846cfaabd0ff619a Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Wed, 25 Feb 2026 13:51:32 -0500 Subject: [PATCH 154/279] fix: rename init_plugin to init_agent (#1765) --- .../experimental/steering/core/handler.py | 2 +- src/strands/plugins/__init__.py | 2 +- src/strands/plugins/plugin.py | 12 +++--- src/strands/plugins/registry.py | 12 +++--- tests/strands/agent/test_agent.py | 18 ++++----- .../steering/core/test_handler.py | 26 ++++++------- tests/strands/plugins/test_plugins.py | 38 +++++++++---------- 7 files changed, 55 insertions(+), 55 deletions(-) diff --git a/src/strands/experimental/steering/core/handler.py b/src/strands/experimental/steering/core/handler.py index 3b869c0eb..9dac9ba74 100644 --- a/src/strands/experimental/steering/core/handler.py +++ b/src/strands/experimental/steering/core/handler.py @@ -75,7 +75,7 @@ def __init__(self, context_providers: list[SteeringContextProvider] | None = Non logger.debug("handler_class=<%s> | initialized", self.__class__.__name__) - def init_plugin(self, agent: "Agent") -> None: + def init_agent(self, agent: "Agent") -> None: """Initialize the steering handler with an agent. Registers hook callbacks for steering guidance and context updates. 
diff --git a/src/strands/plugins/__init__.py b/src/strands/plugins/__init__.py index 9ec9c9357..aa1491545 100644 --- a/src/strands/plugins/__init__.py +++ b/src/strands/plugins/__init__.py @@ -10,7 +10,7 @@ class LoggingPlugin(Plugin): name = "logging" - def init_plugin(self, agent: Agent) -> None: + def init_agent(self, agent: Agent) -> None: agent.add_hook(self.on_model_call, BeforeModelCallEvent) def on_model_call(self, event: BeforeModelCallEvent) -> None: diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index 80707616a..e9f35f112 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -16,8 +16,8 @@ class Plugin(ABC): """Base class for objects that extend agent functionality. Plugins provide a composable way to add behavior changes to agents. - They are initialized with an agent instance and can register hooks, - modify agent attributes, or perform other setup tasks. + They can register hooks, modify agent attributes, or perform other + setup tasks on an agent instance. Attributes: name: A stable string identifier for the plugin @@ -27,7 +27,7 @@ class Plugin(ABC): class MyPlugin(Plugin): name = "my-plugin" - def init_plugin(self, agent: Agent) -> None: + def init_agent(self, agent: Agent) -> None: agent.add_hook(self.on_model_call, BeforeModelCallEvent) ``` """ @@ -39,10 +39,10 @@ def name(self) -> str: ... @abstractmethod - def init_plugin(self, agent: "Agent") -> None | Awaitable[None]: - """Initialize the plugin with an agent instance. + def init_agent(self, agent: "Agent") -> None | Awaitable[None]: + """Initialize the agent instance. Args: - agent: The agent instance to extend. + agent: The agent instance to initialize. """ ... 
diff --git a/src/strands/plugins/registry.py b/src/strands/plugins/registry.py index 34a7a6639..3b8a0a45f 100644 --- a/src/strands/plugins/registry.py +++ b/src/strands/plugins/registry.py @@ -31,7 +31,7 @@ class _PluginRegistry: class MyPlugin(Plugin): name = "my-plugin" - def init_plugin(self, agent: Agent) -> None: + def init_agent(self, agent: Agent) -> None: pass plugin = MyPlugin() @@ -51,8 +51,8 @@ def __init__(self, agent: "Agent") -> None: def add_and_init(self, plugin: Plugin) -> None: """Add and initialize a plugin with the agent. - This method registers the plugin and calls its init_plugin method. - Handles both sync and async init_plugin implementations automatically. + This method registers the plugin and calls its init_agent method. + Handles both sync and async init_agent implementations automatically. Args: plugin: The plugin to add and initialize. @@ -66,8 +66,8 @@ def add_and_init(self, plugin: Plugin) -> None: logger.debug("plugin_name=<%s> | registering and initializing plugin", plugin.name) self._plugins[plugin.name] = plugin - if inspect.iscoroutinefunction(plugin.init_plugin): - async_plugin_init = cast(Callable[..., Awaitable[None]], plugin.init_plugin) + if inspect.iscoroutinefunction(plugin.init_agent): + async_plugin_init = cast(Callable[..., Awaitable[None]], plugin.init_agent) run_async(lambda: async_plugin_init(self._agent)) else: - plugin.init_plugin(self._agent) + plugin.init_agent(self._agent) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 5deeb4f7c..55de68ff1 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -2622,31 +2622,31 @@ def untyped_callback(event): def test_agent_plugins_sync_initialization(): - """Test that plugins with sync init_plugin are initialized correctly.""" + """Test that plugins with sync init_agent are initialized correctly.""" plugin_mock = unittest.mock.Mock() plugin_mock.name = "test-plugin" - plugin_mock.init_plugin = 
unittest.mock.Mock() + plugin_mock.init_agent = unittest.mock.Mock() agent = Agent( model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]), plugins=[plugin_mock], ) - plugin_mock.init_plugin.assert_called_once_with(agent) + plugin_mock.init_agent.assert_called_once_with(agent) def test_agent_plugins_async_initialization(): - """Test that plugins with async init_plugin are initialized correctly.""" + """Test that plugins with async init_agent are initialized correctly.""" plugin_mock = unittest.mock.Mock() plugin_mock.name = "async-plugin" - plugin_mock.init_plugin = unittest.mock.AsyncMock() + plugin_mock.init_agent = unittest.mock.AsyncMock() agent = Agent( model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]), plugins=[plugin_mock], ) - plugin_mock.init_plugin.assert_called_once_with(agent) + plugin_mock.init_agent.assert_called_once_with(agent) def test_agent_plugins_multiple_in_order(): @@ -2655,11 +2655,11 @@ def test_agent_plugins_multiple_in_order(): plugin1 = unittest.mock.Mock() plugin1.name = "plugin1" - plugin1.init_plugin = unittest.mock.Mock(side_effect=lambda agent: call_order.append("plugin1")) + plugin1.init_agent = unittest.mock.Mock(side_effect=lambda agent: call_order.append("plugin1")) plugin2 = unittest.mock.Mock() plugin2.name = "plugin2" - plugin2.init_plugin = unittest.mock.Mock(side_effect=lambda agent: call_order.append("plugin2")) + plugin2.init_agent = unittest.mock.Mock(side_effect=lambda agent: call_order.append("plugin2")) Agent( model=MockedModelProvider([{"role": "assistant", "content": [{"text": "response"}]}]), @@ -2676,7 +2676,7 @@ def test_agent_plugins_can_register_hooks(): class TestPlugin: name = "hook-plugin" - def init_plugin(self, agent): + def init_agent(self, agent): def hook_callback(event: BeforeModelCallEvent): hook_called.append(True) diff --git a/tests/strands/experimental/steering/core/test_handler.py 
b/tests/strands/experimental/steering/core/test_handler.py index 447780939..90064ea98 100644 --- a/tests/strands/experimental/steering/core/test_handler.py +++ b/tests/strands/experimental/steering/core/test_handler.py @@ -37,12 +37,12 @@ def test_steering_handler_is_plugin(): assert isinstance(handler, Plugin) -def test_init_plugin(): - """Test init_plugin registers hooks on agent.""" +def test_init_agent(): + """Test init_agent registers hooks on agent.""" handler = TestSteeringHandler() agent = Mock() - handler.init_plugin(agent) + handler.init_agent(agent) # Verify hooks were registered (tool and model steering hooks) assert agent.add_hook.call_count >= 2 @@ -168,21 +168,21 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): await handler._provide_tool_steering_guidance(event) -def test_init_plugin_override(): - """Test that init_plugin can be overridden.""" +def test_init_agent_override(): + """Test that init_agent can be overridden.""" class CustomHandler(SteeringHandler): async def steer_before_tool(self, *, agent, tool_use, **kwargs): return Proceed(reason="Custom") - def init_plugin(self, agent): + def init_agent(self, agent): # Custom hook registration - don't call parent pass handler = CustomHandler() agent = Mock() - handler.init_plugin(agent) + handler.init_agent(agent) # Should not register any hooks assert agent.add_hook.call_count == 0 @@ -223,7 +223,7 @@ def test_handler_registers_context_provider_hooks(): handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) agent = Mock() - handler.init_plugin(agent) + handler.init_agent(agent) # Should register hooks for context callback and steering guidance assert agent.add_hook.call_count >= 2 @@ -241,7 +241,7 @@ def test_context_callbacks_receive_steering_context(): handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) agent = Mock() - handler.init_plugin(agent) + handler.init_agent(agent) # Get the registered callback for BeforeToolCallEvent 
before_callback = None @@ -271,7 +271,7 @@ def test_multiple_context_callbacks_registered(): handler = TestSteeringHandlerWithProvider(context_callbacks=[callback1, callback2]) agent = Mock() - handler.init_plugin(agent) + handler.init_agent(agent) # Should register one callback for each context provider plus tool and model steering guidance expected_calls = 2 + 2 # 2 callbacks + 2 for steering guidance (tool and model) @@ -485,12 +485,12 @@ async def test_default_steer_after_model_returns_proceed(): assert "Default implementation" in result.reason -def test_init_plugin_registers_model_steering(): - """Test that init_plugin registers model steering callback.""" +def test_init_agent_registers_model_steering(): + """Test that init_agent registers model steering callback.""" handler = TestSteeringHandler() agent = Mock() - handler.init_plugin(agent) + handler.init_agent(agent) # Verify model steering hook was registered agent.add_hook.assert_any_call(handler._provide_model_steering_guidance, AfterModelCallEvent) diff --git a/tests/strands/plugins/test_plugins.py b/tests/strands/plugins/test_plugins.py index 7d0f49dc9..c16cfcf7a 100644 --- a/tests/strands/plugins/test_plugins.py +++ b/tests/strands/plugins/test_plugins.py @@ -16,7 +16,7 @@ def test_plugin_class_requires_inheritance(): class MyPlugin(Plugin): name = "my-plugin" - def init_plugin(self, agent): + def init_agent(self, agent): pass plugin = MyPlugin() @@ -24,12 +24,12 @@ def init_plugin(self, agent): def test_plugin_class_sync_implementation(): - """Test Plugin class works with synchronous init_plugin.""" + """Test Plugin class works with synchronous init_agent.""" class SyncPlugin(Plugin): name = "sync-plugin" - def init_plugin(self, agent): + def init_agent(self, agent): agent.custom_attribute = "initialized by plugin" plugin = SyncPlugin() @@ -39,19 +39,19 @@ def init_plugin(self, agent): assert isinstance(plugin, Plugin) assert plugin.name == "sync-plugin" - # Execute init_plugin synchronously - 
plugin.init_plugin(mock_agent) + # Execute init_agent synchronously + plugin.init_agent(mock_agent) assert mock_agent.custom_attribute == "initialized by plugin" @pytest.mark.asyncio async def test_plugin_class_async_implementation(): - """Test Plugin class works with asynchronous init_plugin.""" + """Test Plugin class works with asynchronous init_agent.""" class AsyncPlugin(Plugin): name = "async-plugin" - async def init_plugin(self, agent): + async def init_agent(self, agent): agent.custom_attribute = "initialized by async plugin" plugin = AsyncPlugin() @@ -61,8 +61,8 @@ async def init_plugin(self, agent): assert isinstance(plugin, Plugin) assert plugin.name == "async-plugin" - # Execute init_plugin asynchronously - await plugin.init_plugin(mock_agent) + # Execute init_agent asynchronously + await plugin.init_agent(mock_agent) assert mock_agent.custom_attribute == "initialized by async plugin" @@ -72,14 +72,14 @@ def test_plugin_class_requires_name(): with pytest.raises(TypeError, match="Can't instantiate abstract class"): class PluginWithoutName(Plugin): - def init_plugin(self, agent): + def init_agent(self, agent): pass PluginWithoutName() -def test_plugin_class_requires_init_plugin_method(): - """Test that Plugin class requires an init_plugin method.""" +def test_plugin_class_requires_init_agent_method(): + """Test that Plugin class requires an init_agent method.""" with pytest.raises(TypeError, match="Can't instantiate abstract class"): @@ -95,7 +95,7 @@ def test_plugin_class_with_class_attribute_name(): class PluginWithClassAttribute(Plugin): name: str = "class-attr-plugin" - def init_plugin(self, agent): + def init_agent(self, agent): pass plugin = PluginWithClassAttribute() @@ -111,7 +111,7 @@ class PluginWithProperty(Plugin): def name(self): return "property-plugin" - def init_plugin(self, agent): + def init_agent(self, agent): pass plugin = PluginWithProperty() @@ -134,8 +134,8 @@ def registry(mock_agent): return _PluginRegistry(mock_agent) -def 
test_plugin_registry_add_and_init_calls_init_plugin(registry, mock_agent): - """Test adding a plugin calls its init_plugin method.""" +def test_plugin_registry_add_and_init_calls_init_agent(registry, mock_agent): + """Test adding a plugin calls its init_agent method.""" class TestPlugin(Plugin): name = "test-plugin" @@ -143,7 +143,7 @@ class TestPlugin(Plugin): def __init__(self): self.initialized = False - def init_plugin(self, agent): + def init_agent(self, agent): self.initialized = True agent.plugin_initialized = True @@ -160,7 +160,7 @@ def test_plugin_registry_add_duplicate_raises_error(registry, mock_agent): class TestPlugin(Plugin): name = "test-plugin" - def init_plugin(self, agent): + def init_agent(self, agent): pass plugin1 = TestPlugin() @@ -181,7 +181,7 @@ class AsyncPlugin(Plugin): def __init__(self): self.initialized = False - async def init_plugin(self, agent): + async def init_agent(self, agent): self.initialized = True agent.async_plugin_initialized = True From 1df2438ab356f0cd98e838f932286ce62f6da08d Mon Sep 17 00:00:00 2001 From: Clare Liguori Date: Fri, 27 Feb 2026 08:51:27 -0800 Subject: [PATCH 155/279] test: pin virtualenv to <21 for hatch bug (#1771) --- .github/workflows/integration-test.yml | 4 +++- .github/workflows/pypi-publish-on-release.yml | 4 +++- .github/workflows/test-lint.yml | 8 ++++++-- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index a40eb0f45..e7cdbe131 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -48,8 +48,10 @@ jobs: with: python-version: '3.10' - name: Install dependencies + # Pin virtualenv until hatch is fixed. 
+ # See https://github.com/pypa/hatch/issues/2193 run: | - pip install --no-cache-dir hatch + pip install --no-cache-dir hatch 'virtualenv<21' - name: Run integration tests env: AWS_REGION: us-east-1 diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index bf2c9f21d..7c96a9789 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -34,9 +34,11 @@ jobs: python-version: '3.10' - name: Install dependencies + # Pin virtualenv until hatch is fixed. + # See https://github.com/pypa/hatch/issues/2193 run: | python -m pip install --upgrade pip - pip install hatch twine + pip install hatch twine 'virtualenv<21' - name: Validate version run: | diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml index 89cc459de..5f5aa6fcd 100644 --- a/.github/workflows/test-lint.yml +++ b/.github/workflows/test-lint.yml @@ -83,8 +83,10 @@ jobs: # Windows typically has audio libraries available by default echo "Windows audio dependencies handled by PyAudio wheels" - name: Install dependencies + # Pin virtualenv until hatch is fixed. + # See https://github.com/pypa/hatch/issues/2193 run: | - pip install --no-cache-dir hatch + pip install --no-cache-dir hatch 'virtualenv<21' - name: Run Unit tests id: tests run: hatch test tests --cover @@ -118,8 +120,10 @@ jobs: sudo apt-get install -y portaudio19-dev libasound2-dev - name: Install dependencies + # Pin virtualenv until hatch is fixed. 
+ # See https://github.com/pypa/hatch/issues/2193 run: | - pip install --no-cache-dir hatch + pip install --no-cache-dir hatch 'virtualenv<21' - name: Run lint id: lint From 2c83216d05639facfcbb9ddcf07dce2c5f15accf Mon Sep 17 00:00:00 2001 From: poshinchen Date: Fri, 27 Feb 2026 15:31:22 -0500 Subject: [PATCH 156/279] fix(telemetry): added latest semantic conventions as span attributes for langfuse (#1768) --- src/strands/telemetry/tracer.py | 55 ++-- tests/strands/telemetry/test_tracer.py | 352 +++++++++++++++++-------- 2 files changed, 274 insertions(+), 133 deletions(-) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 6ab33301a..80fb86c40 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -110,6 +110,17 @@ def _parse_semconv_opt_in(self) -> set[str]: opt_in_env = os.getenv("OTEL_SEMCONV_STABILITY_OPT_IN", "") return {value.strip() for value in opt_in_env.split(",")} + @property + def is_langfuse(self) -> bool: + """Check if Langfuse is configured as the OTLP endpoint. + + Returns: + True if Langfuse is the OTLP endpoint, False otherwise. + """ + return "langfuse" in os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "") or "langfuse" in os.getenv( + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "" + ) + def _start_span( self, span_name: str, @@ -142,23 +153,10 @@ def _start_span( # Add all provided attributes if attributes: - self._set_attributes(span, attributes) + span.set_attributes(attributes) return span - def _set_attributes(self, span: Span, attributes: dict[str, AttributeValue]) -> None: - """Set attributes on a span, handling different value types appropriately. 
- - Args: - span: The span to set attributes on - attributes: Dictionary of attributes to set - """ - if not span: - return - - for key, value in attributes.items(): - span.set_attribute(key, value) - def _add_optional_usage_and_metrics_attributes( self, attributes: dict[str, AttributeValue], usage: Usage, metrics: Metrics ) -> None: @@ -203,7 +201,7 @@ def _end_span( # Add any additional attributes if attributes: - self._set_attributes(span, attributes) + span.set_attributes(attributes) # Handle error if present if error: @@ -236,17 +234,24 @@ def end_span_with_error(self, span: Span, error_message: str, exception: Excepti error = exception or Exception(error_message) self._end_span(span, error=error) - def _add_event(self, span: Span | None, event_name: str, event_attributes: Attributes) -> None: + def _add_event( + self, span: Span | None, event_name: str, event_attributes: Attributes, to_span_attributes: bool = False + ) -> None: """Add an event with attributes to a span. Args: span: The span to add the event to event_name: Name of the event event_attributes: Dictionary of attributes to set on the event + to_span_attributes: Add the attributes to span attributes """ if not span: return + # Add to span attribute since some backend can't read the events + if to_span_attributes and event_attributes: + span.set_attributes(event_attributes) + span.add_event(event_name, attributes=event_attributes) def _get_event_name_for_message(self, message: Message) -> str: @@ -358,6 +363,7 @@ def end_model_invoke_span( ] ), }, + to_span_attributes=self.is_langfuse, ) else: self._add_event( @@ -366,7 +372,7 @@ def end_model_invoke_span( event_attributes={"finish_reason": str(stop_reason), "message": serialize(message["content"])}, ) - self._set_attributes(span, attributes) + span.set_attributes(attributes) def start_tool_call_span( self, @@ -423,6 +429,7 @@ def start_tool_call_span( ] ) }, + to_span_attributes=self.is_langfuse, ) else: self._add_event( @@ -476,6 +483,7 @@ def 
end_tool_call_span(self, span: Span, tool_result: ToolResult | None, error: ] ) }, + to_span_attributes=self.is_langfuse, ) else: self._add_event( @@ -572,6 +580,7 @@ def end_event_loop_cycle_span( ] ) }, + to_span_attributes=self.is_langfuse, ) else: self._add_event(span, "gen_ai.choice", event_attributes=event_attributes) @@ -666,6 +675,7 @@ def end_agent_span( ] ) }, + to_span_attributes=self.is_langfuse, ) else: self._add_event( @@ -675,9 +685,7 @@ def end_agent_span( ) if hasattr(response, "metrics") and hasattr(response.metrics, "accumulated_usage"): - if "langfuse" in os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "") or "langfuse" in os.getenv( - "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "" - ): + if self.is_langfuse: attributes.update({"langfuse.observation.type": "span"}) accumulated_usage = response.metrics.accumulated_usage attributes.update( @@ -736,6 +744,7 @@ def start_multiagent_span( span, "gen_ai.client.inference.operation.details", {"gen_ai.input.messages": serialize([{"role": "user", "parts": parts}])}, + to_span_attributes=self.is_langfuse, ) else: self._add_event( @@ -767,6 +776,7 @@ def end_swarm_span( ] ) }, + to_span_attributes=self.is_langfuse, ) else: self._add_event( @@ -816,7 +826,10 @@ def _add_event_messages(self, span: Span, messages: Messages) -> None: {"role": message["role"], "parts": self._map_content_blocks_to_otel_parts(message["content"])} ) self._add_event( - span, "gen_ai.client.inference.operation.details", {"gen_ai.input.messages": serialize(input_messages)} + span, + "gen_ai.client.inference.operation.details", + {"gen_ai.input.messages": serialize(input_messages)}, + to_span_attributes=self.is_langfuse, ) else: for message in messages: diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 6ea605083..da7f010e2 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -79,22 +79,11 @@ def test_start_span(mock_tracer): span = 
tracer._start_span("test_span", attributes={"key": "value"}) mock_tracer.start_span.assert_called_once_with(name="test_span", context=None, kind=SpanKind.INTERNAL) - mock_span.set_attribute.assert_any_call("key", "value") + # Check that set_attributes was called with the provided attributes + mock_span.set_attributes.assert_called_once_with({"key": "value"}) assert span is not None -def test_set_attributes(mock_span): - """Test setting attributes on a span.""" - tracer = Tracer() - attributes = {"str_attr": "value", "int_attr": 123, "bool_attr": True} - - tracer._set_attributes(mock_span, attributes) - - # Check that set_attribute was called for each attribute - calls = [mock.call(k, v) for k, v in attributes.items()] - mock_span.set_attribute.assert_has_calls(calls, any_order=True) - - def test_end_span_no_span(): """Test ending a span when span is None.""" tracer = Tracer() @@ -109,7 +98,8 @@ def test_end_span(mock_span): tracer._end_span(mock_span, attributes) - mock_span.set_attribute.assert_any_call("key", "value") + # Check that set_attributes was called with the provided attributes + mock_span.set_attributes.assert_called_once_with({"key": "value"}) mock_span.set_status.assert_called_once_with(StatusCode.OK) mock_span.end.assert_called_once() @@ -158,11 +148,14 @@ def test_start_model_invoke_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "chat" assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL - mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents") - mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "chat") - mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id) - mock_span.set_attribute.assert_any_call("custom_key", "custom_value") - mock_span.set_attribute.assert_any_call("user_id", "12345") + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "chat", + "gen_ai.system": 
"strands-agents", + "custom_key": "custom_value", + "user_id": "12345", + "gen_ai.request.model": model_id, + "agent_name": "TestAgent", + }) mock_span.add_event.assert_called_with( "gen_ai.user.message", attributes={"content": json.dumps(messages[0]["content"])} ) @@ -195,9 +188,13 @@ def test_start_model_invoke_span_latest_conventions(mock_tracer, monkeypatch): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "chat" assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL - mock_span.set_attribute.assert_any_call("gen_ai.provider.name", "strands-agents") - mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "chat") - mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id) + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "chat", + "gen_ai.provider.name": "strands-agents", + "gen_ai.request.model": model_id, + "agent_name": "TestAgent", + }) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -235,13 +232,15 @@ def test_end_model_invoke_span(mock_span): tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) - mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 10) - mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 10) - mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 20) - mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 20) - mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 30) - mock_span.set_attribute.assert_any_call("gen_ai.server.request.duration", 20) - mock_span.set_attribute.assert_any_call("gen_ai.server.time_to_first_token", 10) + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 
30, + "gen_ai.server.time_to_first_token": 10, + "gen_ai.server.request.duration": 20, + }) mock_span.add_event.assert_called_with( "gen_ai.choice", attributes={"message": json.dumps(message["content"]), "finish_reason": "end_turn"}, @@ -260,13 +259,15 @@ def test_end_model_invoke_span_latest_conventions(mock_span, monkeypatch): tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) - mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 10) - mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 10) - mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 20) - mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 20) - mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 30) - mock_span.set_attribute.assert_any_call("gen_ai.server.time_to_first_token", 10) - mock_span.set_attribute.assert_any_call("gen_ai.server.request.duration", 20) + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.server.time_to_first_token": 10, + "gen_ai.server.request.duration": 20, + }) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -299,12 +300,15 @@ def test_start_tool_call_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool" - mock_span.set_attribute.assert_any_call("gen_ai.tool.name", "test-tool") - mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents") - mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "execute_tool") - mock_span.set_attribute.assert_any_call("gen_ai.tool.call.id", "123") - mock_span.set_attribute.assert_any_call("session_id", "abc123") - mock_span.set_attribute.assert_any_call("environment", 
"production") + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.tool.name": "test-tool", + "gen_ai.system": "strands-agents", + "gen_ai.operation.name": "execute_tool", + "gen_ai.tool.call.id": "123", + "session_id": "abc123", + "environment": "production", + }) mock_span.add_event.assert_any_call( "gen_ai.tool.message", attributes={"role": "tool", "content": json.dumps({"param": "value"}), "id": "123"} ) @@ -327,10 +331,13 @@ def test_start_tool_call_span_latest_conventions(mock_tracer, monkeypatch): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool" - mock_span.set_attribute.assert_any_call("gen_ai.tool.name", "test-tool") - mock_span.set_attribute.assert_any_call("gen_ai.provider.name", "strands-agents") - mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "execute_tool") - mock_span.set_attribute.assert_any_call("gen_ai.tool.call.id", "123") + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.tool.name": "test-tool", + "gen_ai.provider.name": "strands-agents", + "gen_ai.operation.name": "execute_tool", + "gen_ai.tool.call.id": "123", + }) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -370,11 +377,14 @@ def test_start_swarm_call_span_with_string_task(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm" - mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents") - mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "swarm") - mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "invoke_swarm") - mock_span.set_attribute.assert_any_call("workflow_id", "wf-789") - mock_span.set_attribute.assert_any_call("priority", "high") + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "invoke_swarm", + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": "swarm", + 
"workflow_id": "wf-789", + "priority": "high", + }) mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": "Design foo bar"}) assert span is not None @@ -394,9 +404,12 @@ def test_start_swarm_span_with_contentblock_task(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm" - mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents") - mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "swarm") - mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "invoke_swarm") + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "invoke_swarm", + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": "swarm", + }) mock_span.add_event.assert_any_call( "gen_ai.user.message", attributes={"content": '[{"text": "Original Task: foo bar"}]'} ) @@ -447,9 +460,12 @@ def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer, mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm" - mock_span.set_attribute.assert_any_call("gen_ai.provider.name", "strands-agents") - mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "swarm") - mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "invoke_swarm") + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "invoke_swarm", + "gen_ai.provider.name": "strands-agents", + "gen_ai.agent.name": "swarm", + }) mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", attributes={ @@ -512,10 +528,13 @@ def test_start_graph_call_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool" - mock_span.set_attribute.assert_any_call("gen_ai.tool.name", "test-tool") - mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents") - 
mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "execute_tool") - mock_span.set_attribute.assert_any_call("gen_ai.tool.call.id", "123") + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "execute_tool", + "gen_ai.system": "strands-agents", + "gen_ai.tool.name": "test-tool", + "gen_ai.tool.call.id": "123", + }) mock_span.add_event.assert_any_call( "gen_ai.tool.message", attributes={"role": "tool", "content": json.dumps({"param": "value"}), "id": "123"} ) @@ -529,7 +548,7 @@ def test_end_tool_call_span(mock_span): tracer.end_tool_call_span(mock_span, tool_result) - mock_span.set_attribute.assert_any_call("gen_ai.tool.status", "success") + mock_span.set_attributes.assert_called_once_with({"gen_ai.tool.status": "success"}) mock_span.add_event.assert_called_with( "gen_ai.choice", attributes={"message": json.dumps(tool_result.get("content")), "id": ""}, @@ -546,7 +565,7 @@ def test_end_tool_call_span_latest_conventions(mock_span, monkeypatch): tracer.end_tool_call_span(mock_span, tool_result) - mock_span.set_attribute.assert_any_call("gen_ai.tool.status", "success") + mock_span.set_attributes.assert_called_once_with({"gen_ai.tool.status": "success"}) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -589,9 +608,12 @@ def test_start_event_loop_cycle_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_event_loop_cycle" - mock_span.set_attribute.assert_any_call("event_loop.cycle_id", "cycle-123") - mock_span.set_attribute.assert_any_call("request_id", "req-456") - mock_span.set_attribute.assert_any_call("trace_level", "debug") + + mock_span.set_attributes.assert_called_once_with({ + "event_loop.cycle_id": "cycle-123", + "request_id": "req-456", + "trace_level": "debug", + }) mock_span.add_event.assert_any_call( "gen_ai.user.message", attributes={"content": json.dumps([{"text": "Hello"}])} ) @@ -615,7 
+637,8 @@ def test_start_event_loop_cycle_span_latest_conventions(mock_tracer, monkeypatch mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_event_loop_cycle" - mock_span.set_attribute.assert_any_call("event_loop.cycle_id", "cycle-123") + + mock_span.set_attributes.assert_called_once_with({"event_loop.cycle_id": "cycle-123"}) mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", attributes={ @@ -707,10 +730,15 @@ def test_start_agent_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_agent WeatherAgent" assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL - mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents") - mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "WeatherAgent") - mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id) - mock_span.set_attribute.assert_any_call("custom_attr", "value") + + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "invoke_agent", + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": "WeatherAgent", + "gen_ai.request.model": model_id, + "gen_ai.agent.tools": json.dumps(tools), + "custom_attr": "value", + }) mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": json.dumps(content)}) assert span is not None @@ -740,10 +768,15 @@ def test_start_agent_span_latest_conventions(mock_tracer, monkeypatch): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_agent WeatherAgent" - mock_span.set_attribute.assert_any_call("gen_ai.provider.name", "strands-agents") - mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "WeatherAgent") - mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id) - mock_span.set_attribute.assert_any_call("custom_attr", "value") + + 
mock_span.set_attributes.assert_called_once_with({ + "gen_ai.operation.name": "invoke_agent", + "gen_ai.provider.name": "strands-agents", + "gen_ai.agent.name": "WeatherAgent", + "gen_ai.request.model": model_id, + "gen_ai.agent.tools": json.dumps(tools), + "custom_attr": "value", + }) mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", attributes={ @@ -770,13 +803,17 @@ def test_end_agent_span(mock_span): tracer.end_agent_span(mock_span, mock_response) - mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 50) - mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 50) - mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 100) - mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 100) - mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 0) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 0) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.usage.prompt_tokens": 50, + "gen_ai.usage.input_tokens": 50, + "gen_ai.usage.completion_tokens": 100, + "gen_ai.usage.output_tokens": 100, + "gen_ai.usage.total_tokens": 150, + "gen_ai.usage.cache_read_input_tokens": 0, + "gen_ai.usage.cache_write_input_tokens": 0, + } + ) mock_span.add_event.assert_any_call( "gen_ai.choice", attributes={"message": "Agent response", "finish_reason": "end_turn"}, @@ -800,13 +837,19 @@ def test_end_agent_span_with_langfuse_observation_type(mock_span, monkeypatch): mock_response.__str__ = mock.MagicMock(return_value="Agent response") tracer.end_agent_span(mock_span, mock_response) - mock_span.set_attribute.assert_any_call("langfuse.observation.type", "span") - mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 50) - mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 50) - 
mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 100) - mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 0) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 0) + + mock_span.set_attributes.assert_called_once_with( + { + "langfuse.observation.type": "span", + "gen_ai.usage.prompt_tokens": 50, + "gen_ai.usage.input_tokens": 50, + "gen_ai.usage.completion_tokens": 100, + "gen_ai.usage.output_tokens": 100, + "gen_ai.usage.total_tokens": 150, + "gen_ai.usage.cache_read_input_tokens": 0, + "gen_ai.usage.cache_write_input_tokens": 0, + } + ) mock_span.add_event.assert_any_call( "gen_ai.choice", attributes={"message": "Agent response", "finish_reason": "end_turn"}, @@ -831,13 +874,17 @@ def test_end_agent_span_latest_conventions(mock_span, monkeypatch): tracer.end_agent_span(mock_span, mock_response) - mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 50) - mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 50) - mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 100) - mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 100) - mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 0) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 0) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.usage.prompt_tokens": 50, + "gen_ai.usage.input_tokens": 50, + "gen_ai.usage.completion_tokens": 100, + "gen_ai.usage.output_tokens": 100, + "gen_ai.usage.total_tokens": 150, + "gen_ai.usage.cache_read_input_tokens": 0, + "gen_ai.usage.cache_write_input_tokens": 0, + } + ) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -872,15 +919,17 @@ def 
test_end_model_invoke_span_with_cache_metrics(mock_span): tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) - mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 10) - mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 10) - mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 20) - mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 20) - mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 30) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 5) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 3) - mock_span.set_attribute.assert_any_call("gen_ai.server.request.duration", 10) - mock_span.set_attribute.assert_any_call("gen_ai.server.time_to_first_token", 5) + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.usage.cache_read_input_tokens": 5, + "gen_ai.usage.cache_write_input_tokens": 3, + "gen_ai.server.request.duration": 10, + "gen_ai.server.time_to_first_token": 5, + }) def test_end_agent_span_with_cache_metrics(mock_span): @@ -904,13 +953,15 @@ def test_end_agent_span_with_cache_metrics(mock_span): tracer.end_agent_span(mock_span, mock_response) - mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 50) - mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 50) - mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 100) - mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 100) - mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 25) - mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 
10) + mock_span.set_attributes.assert_called_once_with({ + "gen_ai.usage.prompt_tokens": 50, + "gen_ai.usage.input_tokens": 50, + "gen_ai.usage.completion_tokens": 100, + "gen_ai.usage.output_tokens": 100, + "gen_ai.usage.total_tokens": 150, + "gen_ai.usage.cache_read_input_tokens": 25, + "gen_ai.usage.cache_write_input_tokens": 10, + }) mock_span.set_status.assert_called_once_with(StatusCode.OK) mock_span.end.assert_called_once() @@ -1444,3 +1495,80 @@ def test_start_agent_span_includes_tool_definitions_when_enabled(monkeypatch): ] expected_json = serialize(expected_tool_details) assert attributes["gen_ai.tool.definitions"] == expected_json + + +def test_end_model_invoke_span_langfuse_adds_attributes(mock_span, monkeypatch): + """Test that end_model_invoke_span adds attributes via set_attributes for Langfuse.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "https://us.cloud.langfuse.com") + + tracer = Tracer() + message = {"role": "assistant", "content": [{"text": "Response"}]} + usage = Usage(inputTokens=10, outputTokens=20, totalTokens=30) + metrics = Metrics(latencyMs=20, timeToFirstByteMs=10) + stop_reason: StopReason = "end_turn" + + tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) + + expected_output = serialize( + [ + { + "role": "assistant", + "parts": [{"type": "text", "content": "Response"}], + "finish_reason": "end_turn", + } + ] + ) + + assert mock_span.set_attributes.call_count == 2 + mock_span.set_attributes.assert_any_call({"gen_ai.output.messages": expected_output}) + mock_span.set_attributes.assert_any_call({ + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.server.time_to_first_token": 10, + "gen_ai.server.request.duration": 20, + }) + + mock_span.add_event.assert_called_with( + 
"gen_ai.client.inference.operation.details", + attributes={"gen_ai.output.messages": expected_output}, + ) + + +def test_end_model_invoke_span_non_langfuse_no_extra_attributes(mock_span, monkeypatch): + """Test that end_model_invoke_span doesn't add extra attributes for non-Langfuse endpoints.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "https://api.honeycomb.io") + + tracer = Tracer() + message = {"role": "assistant", "content": [{"text": "Response"}]} + usage = Usage(inputTokens=10, outputTokens=20, totalTokens=30) + metrics = Metrics(latencyMs=20, timeToFirstByteMs=10) + stop_reason: StopReason = "end_turn" + + tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) + + # Verify that set_attribute was NOT called with gen_ai.output.messages + # (it should only be in the event, not as an attribute) + expected_output = serialize( + [ + { + "role": "assistant", + "parts": [{"type": "text", "content": "Response"}], + "finish_reason": "end_turn", + } + ] + ) + + # Check that gen_ai.output.messages was not set as an attribute + set_attribute_calls = [call[0][0] for call in mock_span.set_attribute.call_args_list] + assert "gen_ai.output.messages" not in set_attribute_calls + + # But verify that add_event was still called + mock_span.add_event.assert_called_with( + "gen_ai.client.inference.operation.details", + attributes={"gen_ai.output.messages": expected_output}, + ) From c50457d1396dbeb8268f47d56faa5ad539f98ae7 Mon Sep 17 00:00:00 2001 From: Austin Welch Date: Fri, 27 Feb 2026 16:10:14 -0500 Subject: [PATCH 157/279] fix: preserve guardrail_latest_message wrapping after tool execution (#1658) Co-authored-by: Liz <91279165+lizradway@users.noreply.github.com> --- src/strands/models/bedrock.py | 33 +++- tests/strands/models/test_bedrock.py | 267 +++++++++++++++++++++++++++ 2 files changed, 292 insertions(+), 8 deletions(-) diff --git 
a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index db1878108..4a48d7229 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -363,6 +363,23 @@ def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None: messages[last_assistant_idx]["content"].append({"cachePoint": {"type": "default"}}) logger.debug("msg_idx=<%s> | added cache point to last assistant message", last_assistant_idx) + def _find_last_user_text_message_index(self, messages: Messages) -> int | None: + """Find the index of the last user message containing text or image content. + + This is used for guardrail_latest_message to ensure that guardContent wrapping + targets the correct message even when toolResult messages follow. + + Args: + messages: List of messages to search + + Returns: + Index of the last user message with text/image content, or None if not found + """ + for idx, msg in reversed(list(enumerate(messages))): + if msg["role"] == "user" and any("text" in cb or "image" in cb for cb in msg.get("content", [])): + return idx + return None + def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: """Format messages for Bedrock API compatibility. @@ -391,7 +408,12 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: filtered_unknown_members = False dropped_deepseek_reasoning_content = False - guardrail_latest_message = self.config.get("guardrail_latest_message", False) + # Pre-compute the index of the last user message containing text or image content. + # This ensures guardContent wrapping is maintained across tool execution cycles, where + # the final message in the list is a toolResult (role=user) rather than text/image content. 
+ last_user_text_idx = None + if self.config.get("guardrail_latest_message", False): + last_user_text_idx = self._find_last_user_text_message_index(messages) for idx, message in enumerate(messages): cleaned_content: list[dict[str, Any]] = [] @@ -413,13 +435,8 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: if formatted_content is None: continue - # Wrap text or image content in guardrailContent if this is the last user message - if ( - guardrail_latest_message - and idx == len(messages) - 1 - and message["role"] == "user" - and ("text" in formatted_content or "image" in formatted_content) - ): + # Wrap text or image content in guardContent if this is the last user text/image message + if idx == last_user_text_idx and ("text" in formatted_content or "image" in formatted_content): if "text" in formatted_content: formatted_content = {"guardContent": {"text": {"text": formatted_content["text"]}}} elif "image" in formatted_content: diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 228d6c138..9dae16be7 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2405,6 +2405,183 @@ async def test_format_request_with_guardrail_latest_message(model): assert formatted_messages[2]["content"][1]["guardContent"]["image"]["format"] == "png" +@pytest.mark.asyncio +async def test_format_request_with_guardrail_latest_message_after_tool_use(model): + """Test that guardContent wraps the last user text message even when a toolResult follows it.""" + model.update_config( + guardrail_id="test-guardrail", + guardrail_version="DRAFT", + guardrail_latest_message=True, + ) + + messages = [ + {"role": "user", "content": [{"text": "First message"}]}, + {"role": "assistant", "content": [{"text": "First response"}]}, + {"role": "user", "content": [{"text": "what is the standard deduction?"}]}, + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool-1", + 
"name": "knowledge_base", + "input": {"query": "standard deduction"}, + } + } + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "tool-1", + "content": [{"text": "The standard deduction for 2024 is $14,600."}], + "status": "success", + } + } + ], + }, + ] + + request = model._format_request(messages) + formatted_messages = request["messages"] + + assert len(formatted_messages) == 5 + + # Earlier user message should NOT be wrapped + assert "text" in formatted_messages[0]["content"][0] + assert formatted_messages[0]["content"][0]["text"] == "First message" + + # Last user message with text content should be wrapped, even though a toolResult comes after + assert "guardContent" in formatted_messages[2]["content"][0] + assert formatted_messages[2]["content"][0]["guardContent"]["text"]["text"] == "what is the standard deduction?" + + # toolResult-only user message should NOT be wrapped + assert "toolResult" in formatted_messages[4]["content"][0] + assert "guardContent" not in formatted_messages[4]["content"][0] + + +@pytest.mark.asyncio +async def test_format_request_with_guardrail_latest_message_wraps_final_user_text(model): + """Test that guardContent wraps the last user message when it contains text content.""" + model.update_config( + guardrail_id="test-guardrail", + guardrail_version="DRAFT", + guardrail_latest_message=True, + ) + + messages = [ + {"role": "user", "content": [{"text": "First message"}]}, + {"role": "assistant", "content": [{"text": "First response"}]}, + {"role": "user", "content": [{"text": "Tell me about taxes"}]}, + ] + + request = model._format_request(messages) + formatted_messages = request["messages"] + + assert "guardContent" in formatted_messages[2]["content"][0] + assert formatted_messages[2]["content"][0]["guardContent"]["text"]["text"] == "Tell me about taxes" + + +@pytest.mark.asyncio +async def test_format_request_with_guardrail_multiple_sequential_tool_calls(model): + """Test guardContent with multiple 
tool calls in sequence (no new user input between).""" + model.update_config( + guardrail_id="test-guardrail", + guardrail_version="DRAFT", + guardrail_latest_message=True, + ) + + messages = [ + {"role": "user", "content": [{"text": "First question"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "Result 1"}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t2", "name": "tool2", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "t2", "content": [{"text": "Result 2"}], "status": "success"}}], + }, + ] + + request = model._format_request(messages) + formatted_messages = request["messages"] + + # Should wrap the first user text message, not the toolResults + assert "guardContent" in formatted_messages[0]["content"][0] + assert formatted_messages[0]["content"][0]["guardContent"]["text"]["text"] == "First question" + + # toolResults should not be wrapped + assert "toolResult" in formatted_messages[2]["content"][0] + assert "guardContent" not in formatted_messages[2]["content"][0] + assert "toolResult" in formatted_messages[4]["content"][0] + assert "guardContent" not in formatted_messages[4]["content"][0] + + +@pytest.mark.asyncio +async def test_format_request_with_guardrail_image_before_tool_result(model): + """Test guardContent wraps image content even when toolResult follows.""" + model.update_config( + guardrail_id="test-guardrail", + guardrail_version="DRAFT", + guardrail_latest_message=True, + ) + + messages = [ + {"role": "user", "content": [{"image": {"format": "png", "source": {"bytes": b"fake"}}}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "vision", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "I see a cat"}], "status": "success"}}], + }, 
+ ] + + request = model._format_request(messages) + formatted_messages = request["messages"] + + # Image should be wrapped even though toolResult comes after + assert "guardContent" in formatted_messages[0]["content"][0] + assert "image" in formatted_messages[0]["content"][0]["guardContent"] + + +@pytest.mark.asyncio +async def test_format_request_with_guardrail_multiple_tool_results_same_message(model): + """Test guardContent with multiple parallel tool calls (multiple toolResults in one message).""" + model.update_config( + guardrail_id="test-guardrail", + guardrail_version="DRAFT", + guardrail_latest_message=True, + ) + + messages = [ + {"role": "user", "content": [{"text": "Question requiring multiple tools"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "t1", "name": "tool1", "input": {}}}, + {"toolUse": {"toolUseId": "t2", "name": "tool2", "input": {}}}, + ], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "t1", "content": [{"text": "Result 1"}], "status": "success"}}, + {"toolResult": {"toolUseId": "t2", "content": [{"text": "Result 2"}], "status": "success"}}, + ], + }, + ] + + request = model._format_request(messages) + formatted_messages = request["messages"] + + # Should wrap the question + assert "guardContent" in formatted_messages[0]["content"][0] + assert formatted_messages[0]["content"][0]["guardContent"]["text"]["text"] == "Question requiring multiple tools" + + def test_supports_caching_true_for_claude(bedrock_client): """Test that supports_caching returns True for Claude models.""" model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0") @@ -2514,3 +2691,93 @@ def test_inject_cache_point_strips_existing_cache_points(bedrock_client): # New cache point should be at end of last assistant message assert len(cleaned_messages[3]["content"]) == 2 assert "cachePoint" in cleaned_messages[3]["content"][-1] + + +def test_find_last_user_text_message_index_no_user_messages(bedrock_client): 
+ """Test _find_last_user_text_message_index returns None when no user text messages exist.""" + model = BedrockModel(model_id="test-model") + + messages = [ + {"role": "assistant", "content": [{"text": "hello"}]}, + ] + + assert model._find_last_user_text_message_index(messages) is None + + +def test_find_last_user_text_message_index_only_tool_results(bedrock_client): + """Test _find_last_user_text_message_index returns None when user messages only have toolResult.""" + model = BedrockModel(model_id="test-model") + + messages = [ + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "result"}]}}], + }, + ] + + assert model._find_last_user_text_message_index(messages) is None + + +def test_find_last_user_text_message_index_returns_last_text_message(bedrock_client): + """Test _find_last_user_text_message_index returns the index of the last user message with text.""" + model = BedrockModel(model_id="test-model") + + messages = [ + {"role": "user", "content": [{"text": "First question"}]}, + {"role": "assistant", "content": [{"text": "Response"}]}, + {"role": "user", "content": [{"text": "Second question"}]}, + ] + + assert model._find_last_user_text_message_index(messages) == 2 + + +def test_find_last_user_text_message_index_skips_tool_result_messages(bedrock_client): + """Test _find_last_user_text_message_index skips toolResult-only user messages.""" + model = BedrockModel(model_id="test-model") + + messages = [ + {"role": "user", "content": [{"text": "Question"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "tool", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "Result"}]}}], + }, + ] + + assert model._find_last_user_text_message_index(messages) == 0 + + +def test_find_last_user_text_message_index_finds_image_message(bedrock_client): + """Test _find_last_user_text_message_index finds user messages with image content.""" + model = 
BedrockModel(model_id="test-model") + + messages = [ + {"role": "user", "content": [{"image": {"format": "png", "source": {"bytes": b"fake"}}}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "vision", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "Result"}]}}], + }, + ] + + assert model._find_last_user_text_message_index(messages) == 0 + + +def test_find_last_user_text_message_index_empty_messages(bedrock_client): + """Test _find_last_user_text_message_index returns None for empty message list.""" + model = BedrockModel(model_id="test-model") + + assert model._find_last_user_text_message_index([]) is None + + +def test_guardrail_latest_message_disabled_does_not_wrap(model): + """Test that guardContent wrapping is skipped when guardrail_latest_message is not set.""" + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + request = model._format_request(messages) + formatted = request["messages"][0]["content"][0] + + assert "text" in formatted + assert "guardContent" not in formatted From 1a3b4294e99aab74ca177a82509985faac71fa22 Mon Sep 17 00:00:00 2001 From: Kihyeon Myung <51226101+kevmyung@users.noreply.github.com> Date: Mon, 2 Mar 2026 10:19:38 -0700 Subject: [PATCH 158/279] feat(conversation-manager): improve tool result truncation strategy (#1756) --- .../sliding_window_conversation_manager.py | 135 +++++++---- tests/strands/agent/test_agent.py | 10 +- .../agent/test_conversation_manager.py | 211 ++++++++++++++++-- 3 files changed, 290 insertions(+), 66 deletions(-) diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 709c876e7..b97de0b06 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -7,12 +7,15 @@ 
from ...agent.agent import Agent from ...hooks import BeforeModelCallEvent, HookRegistry -from ...types.content import Messages +from ...types.content import ContentBlock, Messages from ...types.exceptions import ContextWindowOverflowException +from ...types.tools import ToolResultContent from .conversation_manager import ConversationManager logger = logging.getLogger(__name__) +_PRESERVE_CHARS = 200 + class SlidingWindowConversationManager(ConversationManager): """Implements a sliding window strategy for managing conversation history. @@ -20,10 +23,21 @@ class SlidingWindowConversationManager(ConversationManager): This class handles the logic of maintaining a conversation window that preserves tool usage pairs and avoids invalid window states. + When truncation is enabled (the default), large tool results are partially truncated, preserving the first + and last 200 characters, and image blocks inside tool results are replaced with descriptive text placeholders. + Truncation targets the oldest tool results first so the most relevant recent context is preserved as long + as possible. + Supports proactive management during agent loop execution via the per_turn parameter. """ - def __init__(self, window_size: int = 40, should_truncate_results: bool = True, *, per_turn: bool | int = False): + def __init__( + self, + window_size: int = 40, + should_truncate_results: bool = True, + *, + per_turn: bool | int = False, + ): """Initialize the sliding window conversation manager. Args: @@ -44,6 +58,9 @@ def __init__(self, window_size: int = 40, should_truncate_results: bool = True, Raises: ValueError: If per_turn is 0 or a negative integer. 
""" + if isinstance(per_turn, int) and not isinstance(per_turn, bool) and per_turn <= 0: + raise ValueError(f"per_turn must be a positive integer, True, or False, got {per_turn}") + super().__init__() self.window_size = window_size @@ -157,14 +174,14 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A messages = agent.messages # Try to truncate the tool result first - last_message_idx_with_tool_results = self._find_last_message_with_tool_results(messages) - if last_message_idx_with_tool_results is not None and self.should_truncate_results: + oldest_message_idx_with_tool_results = self._find_oldest_message_with_tool_results(messages) + if oldest_message_idx_with_tool_results is not None and self.should_truncate_results: logger.debug( - "message_index=<%s> | found message with tool results at index", last_message_idx_with_tool_results + "message_index=<%s> | found message with tool results at index", oldest_message_idx_with_tool_results ) - results_truncated = self._truncate_tool_results(messages, last_message_idx_with_tool_results) + results_truncated = self._truncate_tool_results(messages, oldest_message_idx_with_tool_results) if results_truncated: - logger.debug("message_index=<%s> | tool results truncated", last_message_idx_with_tool_results) + logger.debug("message_index=<%s> | tool results truncated", oldest_message_idx_with_tool_results) return # Try to trim index id when tool result cannot be truncated anymore @@ -197,10 +214,14 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A messages[:] = messages[trim_index:] def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool: - """Truncate tool results in a message to reduce context size. + """Truncate tool results and replace image blocks in a message to reduce context size. + + For text blocks within tool results, all blocks are partially truncated unless they + have already been truncated. 
The first and last _PRESERVE_CHARS characters are kept, + and the removed middle is replaced with a notice indicating how many characters were + removed. The tool result status is not changed. - When a message contains tool results that are too large for the model's context window, this function - replaces the content of those tool results with a simple error message. + Image blocks nested inside tool result content are replaced with a short descriptive placeholder. Args: messages: The conversation message history. @@ -212,52 +233,82 @@ def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool: if msg_idx >= len(messages) or msg_idx < 0: return False + def _image_placeholder(image_block: Any) -> str: + source: Any = image_block.get("source", {}) + media_type = image_block.get("format", "unknown") + data = source.get("bytes", b"") + return f"[image: {media_type}, {len(data) if data else 0} bytes]" + message = messages[msg_idx] changes_made = False - tool_result_too_large_message = "The tool result was too large!" 
- for i, content in enumerate(message.get("content", [])): - if isinstance(content, dict) and "toolResult" in content: - tool_result_content_text = next( - (item["text"] for item in content["toolResult"]["content"] if "text" in item), - "", - ) - # make the overwriting logic togglable - if ( - message["content"][i]["toolResult"]["status"] == "error" - and tool_result_content_text == tool_result_too_large_message - ): - logger.info("ToolResult has already been updated, skipping overwrite") - return False - # Update status to error with informative message - message["content"][i]["toolResult"]["status"] = "error" - message["content"][i]["toolResult"]["content"] = [{"text": tool_result_too_large_message}] - changes_made = True + new_content: list[ContentBlock] = [] + + for content in message.get("content", []): + if "toolResult" in content: + tool_result: Any = content["toolResult"] + tool_result_items = tool_result.get("content", []) + new_items: list[ToolResultContent] = [] + item_changed = False + + for item in tool_result_items: + # Replace image items nested inside toolResult content + if "image" in item: + new_items.append({"text": _image_placeholder(item["image"])}) + item_changed = True + continue + + # Partially truncate text items that have not already been truncated + if "text" in item: + text = item["text"] + truncation_marker = "... [truncated:" + if truncation_marker not in text and len(text) > 2 * _PRESERVE_CHARS: + prefix = text[:_PRESERVE_CHARS] + suffix = text[-_PRESERVE_CHARS:] + removed = len(text) - 2 * _PRESERVE_CHARS + truncated_text = ( + f"{prefix}...\n\n... 
[truncated: {removed} chars removed] ...\n\n...{suffix}" + ) + new_items.append({"text": truncated_text}) + item_changed = True + continue + + new_items.append(item) + + if item_changed: + updated_tool_result: Any = { + **{k: v for k, v in tool_result.items() if k != "content"}, + "content": new_items, + } + new_content.append({"toolResult": updated_tool_result}) + changes_made = True + else: + new_content.append(content) + continue + + new_content.append(content) + + if changes_made: + message["content"] = new_content return changes_made - def _find_last_message_with_tool_results(self, messages: Messages) -> int | None: - """Find the index of the last message containing tool results. + def _find_oldest_message_with_tool_results(self, messages: Messages) -> int | None: + """Find the index of the oldest message containing tool results. - This is useful for identifying messages that might need to be truncated to reduce context size. + Iterates from oldest to newest so that truncation targets the least-recent + (and therefore least relevant) tool results first. Args: messages: The conversation message history. Returns: - Index of the last message with tool results, or None if no such message exists. + Index of the oldest message with tool results, or None if no such message exists. 
""" - # Iterate backwards through all messages (from newest to oldest) - for idx in range(len(messages) - 1, -1, -1): - # Check if this message has any content with toolResult + # Iterate from oldest to newest + for idx in range(len(messages)): current_message = messages[idx] - has_tool_result = False - for content in current_message.get("content", []): if isinstance(content, dict) and "toolResult" in content: - has_tool_result = True - break - - if has_tool_result: - return idx + return idx return None diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 55de68ff1..6d73fc177 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -621,7 +621,7 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene }, }, }, - {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"random_string": "abcdEfghI123"}'}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"random_string": "' + "X" * 500 + '"}'}}}}, {"contentBlockStop": {}}, {"messageStop": {"stopReason": "tool_use"}}, ] @@ -635,12 +635,14 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene agent("test message") + large_input = "X" * 500 + truncated_text = large_input[:200] + "...\n\n... [truncated: 100 chars removed] ...\n\n..." 
+ large_input[-200:] expected_messages = [ {"role": "user", "content": [{"text": "test message"}]}, { "role": "assistant", "content": [ - {"toolUse": {"toolUseId": "t1", "name": "tool_decorated", "input": {"random_string": "abcdEfghI123"}}} + {"toolUse": {"toolUseId": "t1", "name": "tool_decorated", "input": {"random_string": large_input}}} ], }, { @@ -649,8 +651,8 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene { "toolResult": { "toolUseId": "t1", - "status": "error", - "content": [{"text": "The tool result was too large!"}], + "status": "success", + "content": [{"text": truncated_text}], } } ], diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index 46876d8e5..fd88954e8 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -177,37 +177,25 @@ def test_sliding_window_conversation_manager_with_untrimmable_history_raises_con def test_sliding_window_conversation_manager_with_tool_results_truncated(): + large_text = "A" * 300 + "B" * 300 + "C" * 300 manager = SlidingWindowConversationManager(1) messages = [ {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, { "role": "user", - "content": [ - {"toolResult": {"toolUseId": "789", "content": [{"text": "large input"}], "status": "success"}} - ], + "content": [{"toolResult": {"toolUseId": "789", "content": [{"text": large_text}], "status": "success"}}], }, ] test_agent = Agent(messages=messages) manager.reduce_context(test_agent) - expected_messages = [ - {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, - { - "role": "user", - "content": [ - { - "toolResult": { - "toolUseId": "789", - "content": [{"text": "The tool result was too large!"}], - "status": "error", - } - } - ], - }, - ] - - assert messages == expected_messages + result_text = 
messages[1]["content"][0]["toolResult"]["content"][0]["text"] + assert result_text.startswith("A" * 200) + assert result_text.endswith("C" * 200) + assert "... [truncated:" in result_text + # Status must NOT be changed to error + assert messages[1]["content"][0]["toolResult"]["status"] == "success" def test_null_conversation_manager_reduce_context_raises_context_window_overflow_exception(): @@ -267,6 +255,16 @@ def test_per_turn_parameter_validation(): assert SlidingWindowConversationManager(per_turn=3).per_turn == 3 +def test_per_turn_zero_raises_value_error(): + with pytest.raises(ValueError, match="per_turn"): + SlidingWindowConversationManager(per_turn=0) + + +def test_per_turn_negative_raises_value_error(): + with pytest.raises(ValueError, match="per_turn"): + SlidingWindowConversationManager(per_turn=-5) + + def test_conversation_manager_is_hook_provider(): """Test that ConversationManager implements HookProvider protocol.""" manager = NullConversationManager() @@ -420,3 +418,176 @@ def test_per_turn_backward_compatibility(): agent = Agent(model=model, conversation_manager=manager) result = agent("Hello") assert result is not None + + +# ============================================================================== +# Improved Truncation Strategy Tests +# ============================================================================== + + +def test_truncation_targets_oldest_message_first(): + """Oldest message with tool results is truncated before newer ones.""" + large_text = "X" * 20000 + manager = SlidingWindowConversationManager(window_size=10) + messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "1", "content": [{"text": large_text}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "2", "name": "tool2", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "2", 
"content": [{"text": large_text}], "status": "success"}}], + }, + ] + test_agent = Agent(messages=messages) + + manager.reduce_context(test_agent) + + # The oldest tool result (index 1) must be truncated + oldest_text = messages[1]["content"][0]["toolResult"]["content"][0]["text"] + assert "... [truncated:" in oldest_text + + # The newest tool result (index 3) must remain untouched after the first reduce_context call + newest_text = messages[3]["content"][0]["toolResult"]["content"][0]["text"] + assert "... [truncated:" not in newest_text + + +def test_large_tool_result_partially_truncated_with_context_preserved(): + """Large tool results are truncated in the middle while the beginning and end are preserved.""" + preserve = 200 # matches _PRESERVE_CHARS + # Build text with distinct prefix, middle, and suffix + prefix_text = "P" * preserve + middle_text = "M" * 500 + suffix_text = "S" * preserve + large_text = prefix_text + middle_text + suffix_text + + manager = SlidingWindowConversationManager(window_size=10) + messages = [ + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "1", "content": [{"text": large_text}], "status": "success"}}], + } + ] + + truncated = manager._truncate_tool_results(messages, 0) + + assert truncated + result_text = messages[0]["content"][0]["toolResult"]["content"][0]["text"] + assert result_text.startswith(prefix_text) + assert result_text.endswith(suffix_text) + assert "... [truncated:" in result_text + removed = len(large_text) - 2 * preserve + assert f"... [truncated: {removed} chars removed] ..." 
in result_text + + +def test_truncation_does_not_change_status_to_error(): + """Partial truncation must not change the tool result status.""" + large_text = "Z" * 15000 + manager = SlidingWindowConversationManager(window_size=10) + messages = [ + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "1", "content": [{"text": large_text}], "status": "success"}}], + } + ] + + manager._truncate_tool_results(messages, 0) + + assert messages[0]["content"][0]["toolResult"]["status"] == "success" + + +def test_image_blocks_inside_tool_result_replaced_with_placeholder(): + """Image blocks nested inside toolResult content are replaced with a text placeholder.""" + manager = SlidingWindowConversationManager(window_size=10) + image_data = b"base64encodeddata" + messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "1", + "content": [ + {"text": "some text"}, + { + "image": { + "format": "jpeg", + "source": {"bytes": image_data}, + } + }, + ], + "status": "success", + } + } + ], + } + ] + + changed = manager._truncate_tool_results(messages, 0) + + assert changed + tool_result_items = messages[0]["content"][0]["toolResult"]["content"] + assert not any(isinstance(item, dict) and "image" in item for item in tool_result_items) + expected_placeholder = f"[image: jpeg, {len(image_data)} bytes]" + assert any(isinstance(item, dict) and item.get("text") == expected_placeholder for item in tool_result_items) + + +def test_already_truncated_text_not_truncated_again(): + """A text block that already contains the truncation marker is not truncated a second time.""" + manager = SlidingWindowConversationManager(window_size=10) + already_truncated = "A" * 200 + "...\n\n... [truncated: 990 chars removed] ...\n\n..." 
+ "Z" * 200 + messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "1", + "content": [{"text": already_truncated}], + "status": "success", + } + } + ], + } + ] + + changed = manager._truncate_tool_results(messages, 0) + + assert not changed + assert messages[0]["content"][0]["toolResult"]["content"][0]["text"] == already_truncated + + +def test_short_text_in_tool_result_not_truncated(): + """Text no longer than 2 * _PRESERVE_CHARS must not be modified.""" + manager = SlidingWindowConversationManager(window_size=10) + short_text = "X" * 100 # 100 < 2 * 200 + messages = [ + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "1", "content": [{"text": short_text}], "status": "success"}}], + } + ] + + changed = manager._truncate_tool_results(messages, 0) + + assert not changed + assert messages[0]["content"][0]["toolResult"]["content"][0]["text"] == short_text + + +def test_boundary_text_in_tool_result_not_truncated(): + """Text of exactly 2 * _PRESERVE_CHARS must not be truncated.""" + manager = SlidingWindowConversationManager(window_size=10) + boundary_text = "X" * 400 # exactly 2 * 200 + messages = [ + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "1", "content": [{"text": boundary_text}], "status": "success"}}], + } + ] + + changed = manager._truncate_tool_results(messages, 0) + + assert not changed + assert messages[0]["content"][0]["toolResult"]["content"][0]["text"] == boundary_text From faad5640540f5a86135f9fb4d82b2d0912219648 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Tue, 3 Mar 2026 08:32:11 -0500 Subject: [PATCH 159/279] feat(plugins): improve plugin creation devex with @hook and @tool decorators (#1740) Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> --- AGENTS.md | 6 +- .../experimental/steering/core/handler.py | 15 +- src/strands/hooks/_type_inference.py | 78 +++ src/strands/hooks/registry.py | 79 +-- src/strands/plugins/__init__.py | 18 +- 
src/strands/plugins/decorator.py | 69 +++ src/strands/plugins/plugin.py | 100 +++- src/strands/plugins/registry.py | 60 +- tests/strands/agent/test_agent.py | 12 +- .../steering/core/test_handler.py | 138 +++-- tests/strands/plugins/test_hook_decorator.py | 243 ++++++++ .../strands/plugins/test_plugin_base_class.py | 553 ++++++++++++++++++ tests/strands/plugins/test_plugins.py | 72 +-- 13 files changed, 1249 insertions(+), 194 deletions(-) create mode 100644 src/strands/hooks/_type_inference.py create mode 100644 src/strands/plugins/decorator.py create mode 100644 tests/strands/plugins/test_hook_decorator.py create mode 100644 tests/strands/plugins/test_plugin_base_class.py diff --git a/AGENTS.md b/AGENTS.md index 6a5765a94..10a66fcd7 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -124,10 +124,12 @@ strands-agents/ │ │ │ ├── hooks/ # Event hooks system │ │ ├── events.py # Hook event definitions -│ │ └── registry.py # Hook registration +│ │ ├── registry.py # Hook registration +│ │ └── _type_inference.py # Event type inference from type hints │ │ │ ├── plugins/ # Plugin system -│ │ ├── plugin.py # Plugin definition +│ │ ├── plugin.py # Plugin base class +│ │ ├── decorator.py # @hook decorator │ │ └── registry.py # PluginRegistry for tracking plugins │ │ │ ├── handlers/ # Event handlers diff --git a/src/strands/experimental/steering/core/handler.py b/src/strands/experimental/steering/core/handler.py index 9dac9ba74..214118d4f 100644 --- a/src/strands/experimental/steering/core/handler.py +++ b/src/strands/experimental/steering/core/handler.py @@ -38,7 +38,7 @@ from typing import TYPE_CHECKING, Any from ....hooks.events import AfterModelCallEvent, BeforeToolCallEvent -from ....plugins.plugin import Plugin +from ....plugins import Plugin, hook from ....types.content import Message from ....types.streaming import StopReason from ....types.tools import ToolUse @@ -66,6 +66,7 @@ def __init__(self, context_providers: list[SteeringContextProvider] | None = Non Args: 
context_providers: List of context providers for context updates """ + super().__init__() self.steering_context = SteeringContext() self._context_callbacks = [] @@ -87,13 +88,8 @@ def init_agent(self, agent: "Agent") -> None: for callback in self._context_callbacks: agent.add_hook(lambda event, callback=callback: callback(event, self.steering_context), callback.event_type) - # Register tool steering guidance - agent.add_hook(self._provide_tool_steering_guidance, BeforeToolCallEvent) - - # Register model steering guidance - agent.add_hook(self._provide_model_steering_guidance, AfterModelCallEvent) - - async def _provide_tool_steering_guidance(self, event: BeforeToolCallEvent) -> None: + @hook + async def provide_tool_steering_guidance(self, event: BeforeToolCallEvent) -> None: """Provide steering guidance for tool call.""" tool_name = event.tool_use["name"] logger.debug("tool_name=<%s> | providing tool steering guidance", tool_name) @@ -133,7 +129,8 @@ def _handle_tool_steering_action( else: raise ValueError(f"Unknown steering action type for tool call: {action}") - async def _provide_model_steering_guidance(self, event: AfterModelCallEvent) -> None: + @hook + async def provide_model_steering_guidance(self, event: AfterModelCallEvent) -> None: """Provide steering guidance for model response.""" logger.debug("providing model steering guidance") diff --git a/src/strands/hooks/_type_inference.py b/src/strands/hooks/_type_inference.py new file mode 100644 index 000000000..fbfb34c04 --- /dev/null +++ b/src/strands/hooks/_type_inference.py @@ -0,0 +1,78 @@ +"""Utility for inferring event types from callback type hints.""" + +import inspect +import logging +import types +from typing import TYPE_CHECKING, Union, cast, get_args, get_origin, get_type_hints + +if TYPE_CHECKING: + from .registry import HookCallback, TEvent + +logger = logging.getLogger(__name__) + + +def infer_event_types(callback: "HookCallback[TEvent]") -> "list[type[TEvent]]": + """Infer the event type(s) 
from a callback's type hints. + + Supports both single types and union types (A | B or Union[A, B]). + + Args: + callback: The callback function to inspect. + + Returns: + A list of event types inferred from the callback's first parameter type hint. + + Raises: + ValueError: If the event type cannot be inferred from the callback's type hints, + or if a union contains None or non-BaseHookEvent types. + """ + # Import here to avoid circular dependency + from .registry import BaseHookEvent + + try: + hints = get_type_hints(callback) + except Exception as e: + logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e) + raise ValueError( + "failed to get type hints for callback | cannot infer event type, please provide event_type explicitly" + ) from e + + # Get the first parameter's type hint + sig = inspect.signature(callback) + params = list(sig.parameters.values()) + + if not params: + raise ValueError("callback has no parameters | cannot infer event type, please provide event_type explicitly") + + # Skip 'self' and 'cls' parameters for methods + first_param = params[0] + if first_param.name in ("self", "cls") and len(params) > 1: + first_param = params[1] + + type_hint = hints.get(first_param.name) + + if type_hint is None: + raise ValueError( + f"parameter=<{first_param.name}> has no type hint | " + "cannot infer event type, please provide event_type explicitly" + ) + + # Check if it's a Union type (Union[A, B] or A | B) + origin = get_origin(type_hint) + if origin is Union or origin is types.UnionType: + event_types: list[type[TEvent]] = [] + for arg in get_args(type_hint): + if arg is type(None): + raise ValueError("None is not a valid event type in union") + if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)): + raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent") + event_types.append(cast("type[TEvent]", arg)) + return event_types + + # Handle single type + if isinstance(type_hint, type) 
and issubclass(type_hint, BaseHookEvent): + return [cast("type[TEvent]", type_hint)] + + raise ValueError( + f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" + ) diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 886ea5644..8b284b0c2 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -9,24 +9,12 @@ import inspect import logging -import types from collections.abc import Awaitable, Generator from dataclasses import dataclass -from typing import ( - TYPE_CHECKING, - Any, - Generic, - Protocol, - TypeVar, - Union, - cast, - get_args, - get_origin, - get_type_hints, - runtime_checkable, -) +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable from ..interrupt import Interrupt, InterruptException +from ._type_inference import infer_event_types if TYPE_CHECKING: from ..agent import Agent @@ -225,7 +213,7 @@ def multi_handler(event): resolved_event_types = self._validate_event_type_list(event_type) elif event_type is None: # Infer event type(s) from callback type hints - resolved_event_types = self._infer_event_types(callback) + resolved_event_types = infer_event_types(callback) else: # Single event type provided explicitly resolved_event_types = [event_type] @@ -261,67 +249,6 @@ def _validate_event_type_list(self, event_types: list[type[TEvent]]) -> list[typ validated.append(et) return validated - def _infer_event_types(self, callback: HookCallback[TEvent]) -> list[type[TEvent]]: - """Infer the event type(s) from a callback's type hints. - - Supports both single types and union types (A | B or Union[A, B]). - - Args: - callback: The callback function to inspect. - - Returns: - A list of event types inferred from the callback's first parameter type hint. - - Raises: - ValueError: If the event type cannot be inferred from the callback's type hints, - or if a union contains None or non-BaseHookEvent types. 
- """ - try: - hints = get_type_hints(callback) - except Exception as e: - logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e) - raise ValueError( - "failed to get type hints for callback | cannot infer event type, please provide event_type explicitly" - ) from e - - # Get the first parameter's type hint - sig = inspect.signature(callback) - params = list(sig.parameters.values()) - - if not params: - raise ValueError( - "callback has no parameters | cannot infer event type, please provide event_type explicitly" - ) - - first_param = params[0] - type_hint = hints.get(first_param.name) - - if type_hint is None: - raise ValueError( - f"parameter=<{first_param.name}> has no type hint | " - "cannot infer event type, please provide event_type explicitly" - ) - - # Check if it's a Union type (Union[A, B] or A | B) - origin = get_origin(type_hint) - if origin is Union or origin is types.UnionType: - event_types: list[type[TEvent]] = [] - for arg in get_args(type_hint): - if arg is type(None): - raise ValueError("None is not a valid event type in union") - if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)): - raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent") - event_types.append(cast(type[TEvent], arg)) - return event_types - - # Handle single type - if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): - return [cast(type[TEvent], type_hint)] - - raise ValueError( - f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" - ) - def add_hook(self, hook: HookProvider) -> None: """Register all callbacks from a hook provider. diff --git a/src/strands/plugins/__init__.py b/src/strands/plugins/__init__.py index aa1491545..c4b7c72c7 100644 --- a/src/strands/plugins/__init__.py +++ b/src/strands/plugins/__init__.py @@ -1,25 +1,13 @@ """Plugin system for extending agent functionality. 
This module provides a composable mechanism for building objects that can -extend agent behavior through a standardized initialization pattern. - -Example Usage: - ```python - from strands.plugins import Plugin - - class LoggingPlugin(Plugin): - name = "logging" - - def init_agent(self, agent: Agent) -> None: - agent.add_hook(self.on_model_call, BeforeModelCallEvent) - - def on_model_call(self, event: BeforeModelCallEvent) -> None: - print(f"Model called for {event.agent.name}") - ``` +extend agent behavior through automatic hook and tool registration. """ +from .decorator import hook from .plugin import Plugin __all__ = [ "Plugin", + "hook", ] diff --git a/src/strands/plugins/decorator.py b/src/strands/plugins/decorator.py new file mode 100644 index 000000000..fc6f75e5b --- /dev/null +++ b/src/strands/plugins/decorator.py @@ -0,0 +1,69 @@ +"""Hook decorator for Plugin methods. + +Marks methods as hook callbacks for automatic registration when the plugin +is attached to an agent. Infers event types from type hints and supports +union types for multiple events. + +Example: + ```python + class MyPlugin(Plugin): + @hook + def on_model_call(self, event: BeforeModelCallEvent): + print(event) + ``` +""" + +from collections.abc import Callable +from typing import Generic, cast, overload + +from ..hooks._type_inference import infer_event_types +from ..hooks.registry import HookCallback, TEvent + + +class _WrappedHookCallable(HookCallback, Generic[TEvent]): + """Wrapped version of HookCallback that includes a `_hook_event_types` attribute.""" + + _hook_event_types: list[type[TEvent]] + + +# Handle @hook +@overload +def hook(__func: HookCallback) -> _WrappedHookCallable: ... + + +# Handle @hook() +@overload +def hook() -> Callable[[HookCallback], _WrappedHookCallable]: ... + + +def hook( + func: HookCallback | None = None, +) -> _WrappedHookCallable | Callable[[HookCallback], _WrappedHookCallable]: + """Mark a method as a hook callback for automatic registration. 
+ + Infers event type from the callback's type hint. Supports union types + for multiple events. Can be used as @hook or @hook(). + + Args: + func: The function to decorate. + + Returns: + The decorated function with hook metadata. + + Raises: + ValueError: If event type cannot be inferred from type hints. + """ + + def decorator(f: HookCallback[TEvent]) -> _WrappedHookCallable[TEvent]: + # Infer event types from type hints + event_types: list[type[TEvent]] = infer_event_types(f) + + # Store hook metadata on the function + f_wrapped = cast(_WrappedHookCallable, f) + f_wrapped._hook_event_types = event_types + + return f_wrapped + + if func is None: + return decorator + return decorator(func) diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index e9f35f112..b670de297 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -1,34 +1,70 @@ """Plugin base class for extending agent functionality. This module defines the Plugin base class, which provides a composable way to -add behavior changes to agents through a standardized initialization pattern. +add behavior changes to agents through automatic hook and tool registration. """ +import logging from abc import ABC, abstractmethod from collections.abc import Awaitable from typing import TYPE_CHECKING +from ..hooks.registry import HookCallback +from ..tools.decorator import DecoratedFunctionTool + if TYPE_CHECKING: from ..agent import Agent +logger = logging.getLogger(__name__) + class Plugin(ABC): """Base class for objects that extend agent functionality. Plugins provide a composable way to add behavior changes to agents. - They can register hooks, modify agent attributes, or perform other - setup tasks on an agent instance. + They support automatic discovery and registration of methods decorated + with @hook and @tool decorators. 
Attributes: - name: A stable string identifier for the plugin + name: A stable string identifier for the plugin (must be provided by subclass) + hooks: Hooks attached to the agent, auto-discovered from @hook decorated methods during __init__ + tools: Tools attached to the agent, auto-discovered from @tool decorated methods during __init__ - Example: + Example using decorators (recommended): + ```python + from strands.plugins import Plugin, hook + from strands.hooks import BeforeModelCallEvent + from strands import tool + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_model_call(self, event: BeforeModelCallEvent): + print(f"Model called: {event}") + + @tool + def my_tool(self, param: str) -> str: + '''A tool that does something.''' + return f"Result: {param}" + ``` + + Note: Decorated methods are registered in declaration order, with parent + class methods registered before child class methods. If a child overrides + a parent's decorated method, only the child's version is registered. + + Example with custom initialization: ```python class MyPlugin(Plugin): name = "my-plugin" def init_agent(self, agent: Agent) -> None: - agent.add_hook(self.on_model_call, BeforeModelCallEvent) + # Custom initialization logic - no super() needed + # Decorated hooks/tools are auto-registered by the plugin registry + agent.add_hook(self.custom_hook) + + def custom_hook(self, event: BeforeModelCallEvent): + print(event) ``` """ @@ -38,11 +74,59 @@ def name(self) -> str: """A stable string identifier for the plugin.""" ... - @abstractmethod + def __init__(self) -> None: + """Initialize the plugin and discover decorated methods. + + Scans the class for methods decorated with @hook and @tool and stores + references for later registration when the plugin is attached to an agent. 
+ """ + self._hooks: list[HookCallback] = [] + self._tools: list[DecoratedFunctionTool] = [] + self._discover_decorated_methods() + + @property + def hooks(self) -> list[HookCallback]: + """List of hooks the plugin provides, auto-discovered from @hook decorated methods.""" + return self._hooks + + @property + def tools(self) -> list[DecoratedFunctionTool]: + """List of tools the plugin provides, auto-discovered from @tool decorated methods.""" + return self._tools + + def _discover_decorated_methods(self) -> None: + """Scan class for @hook and @tool decorated methods in declaration order.""" + seen: set[str] = set() + # Walk MRO so parent class hooks come first, child overrides win + for cls in reversed(type(self).__mro__): + for name in cls.__dict__: + if name in seen: + continue + seen.add(name) + + # Get the bound method from self + try: + bound = getattr(self, name) + except Exception: + continue + + # Check for @hook decorated methods + if hasattr(bound, "_hook_event_types") and callable(bound): + self._hooks.append(bound) + logger.debug("plugin=<%s>, hook=<%s> | discovered hook method", self.name, name) + + # Check for @tool decorated methods (DecoratedFunctionTool instances) + if isinstance(bound, DecoratedFunctionTool): + self._tools.append(bound) + logger.debug("plugin=<%s>, tool=<%s> | discovered tool method", self.name, name) + def init_agent(self, agent: "Agent") -> None | Awaitable[None]: """Initialize the agent instance. + Override this method to add custom initialization logic. Decorated + hooks and tools are automatically registered by the plugin registry. + Args: agent: The agent instance to initialize. """ - ... 
+ return None diff --git a/src/strands/plugins/registry.py b/src/strands/plugins/registry.py index 3b8a0a45f..a75858680 100644 --- a/src/strands/plugins/registry.py +++ b/src/strands/plugins/registry.py @@ -24,6 +24,11 @@ class _PluginRegistry: The _PluginRegistry tracks plugins that have been initialized with an agent, providing methods to add plugins and invoke their initialization. + The registry handles: + 1. Calling the plugin's init_agent() method for custom initialization + 2. Auto-registering discovered @hook decorated methods with the agent + 3. Auto-registering discovered @tool decorated methods with the agent + Example: ```python registry = _PluginRegistry(agent) @@ -31,7 +36,12 @@ class _PluginRegistry: class MyPlugin(Plugin): name = "my-plugin" + @hook + def on_event(self, event: BeforeModelCallEvent): + pass # Auto-registered by registry + def init_agent(self, agent: Agent) -> None: + # Custom logic only - no super() needed pass plugin = MyPlugin() @@ -51,7 +61,12 @@ def __init__(self, agent: "Agent") -> None: def add_and_init(self, plugin: Plugin) -> None: """Add and initialize a plugin with the agent. - This method registers the plugin and calls its init_agent method. + This method: + 1. Registers the plugin in the registry + 2. Calls the plugin's init_agent method for custom initialization + 3. Auto-registers all discovered @hook methods with the agent's hook registry + 4. Auto-registers all discovered @tool methods with the agent's tool registry + Handles both sync and async init_agent implementations automatically. 
Args: @@ -66,8 +81,51 @@ def add_and_init(self, plugin: Plugin) -> None: logger.debug("plugin_name=<%s> | registering and initializing plugin", plugin.name) self._plugins[plugin.name] = plugin + # Call user's init_agent for custom initialization if inspect.iscoroutinefunction(plugin.init_agent): async_plugin_init = cast(Callable[..., Awaitable[None]], plugin.init_agent) run_async(lambda: async_plugin_init(self._agent)) else: plugin.init_agent(self._agent) + + # Auto-register discovered hooks with the agent's hook registry + self._register_hooks(plugin) + + # Auto-register discovered tools with the agent's tool registry + self._register_tools(plugin) + + def _register_hooks(self, plugin: Plugin) -> None: + """Register all discovered hooks from the plugin with the agent. + + Warns if a hook callback is already registered for an event type, + which can happen when init_agent() manually registers a hook that + is also decorated with @hook. + + Args: + plugin: The plugin whose hooks should be registered. + """ + for hook_callback in plugin.hooks: + event_types = getattr(hook_callback, "_hook_event_types", []) + for event_type in event_types: + self._agent.add_hook(hook_callback, event_type) + logger.debug( + "plugin=<%s>, hook=<%s>, event_type=<%s> | registered hook", + plugin.name, + getattr(hook_callback, "__name__", repr(hook_callback)), + event_type.__name__, + ) + + def _register_tools(self, plugin: Plugin) -> None: + """Register all discovered tools from the plugin with the agent. + + Args: + plugin: The plugin whose tools should be registered. 
+ """ + if plugin.tools: + self._agent.tool_registry.process_tools(list(plugin.tools)) + for tool in plugin.tools: + logger.debug( + "plugin=<%s>, tool=<%s> | registered tool", + plugin.name, + tool.tool_name, + ) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 6d73fc177..967a0dafb 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -14,7 +14,7 @@ from pydantic import BaseModel import strands -from strands import Agent, ToolContext +from strands import Agent, Plugin, ToolContext from strands.agent import AgentResult from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager @@ -2627,6 +2627,8 @@ def test_agent_plugins_sync_initialization(): """Test that plugins with sync init_agent are initialized correctly.""" plugin_mock = unittest.mock.Mock() plugin_mock.name = "test-plugin" + plugin_mock.hooks = [] + plugin_mock.tools = [] plugin_mock.init_agent = unittest.mock.Mock() agent = Agent( @@ -2641,6 +2643,8 @@ def test_agent_plugins_async_initialization(): """Test that plugins with async init_agent are initialized correctly.""" plugin_mock = unittest.mock.Mock() plugin_mock.name = "async-plugin" + plugin_mock.hooks = [] + plugin_mock.tools = [] plugin_mock.init_agent = unittest.mock.AsyncMock() agent = Agent( @@ -2657,10 +2661,14 @@ def test_agent_plugins_multiple_in_order(): plugin1 = unittest.mock.Mock() plugin1.name = "plugin1" + plugin1.hooks = [] + plugin1.tools = [] plugin1.init_agent = unittest.mock.Mock(side_effect=lambda agent: call_order.append("plugin1")) plugin2 = unittest.mock.Mock() plugin2.name = "plugin2" + plugin2.hooks = [] + plugin2.tools = [] plugin2.init_agent = unittest.mock.Mock(side_effect=lambda agent: call_order.append("plugin2")) Agent( @@ -2675,7 +2683,7 @@ def test_agent_plugins_can_register_hooks(): """Test 
that plugins can register hooks during initialization.""" hook_called = [] - class TestPlugin: + class TestPlugin(Plugin): name = "hook-plugin" def init_agent(self, agent): diff --git a/tests/strands/experimental/steering/core/test_handler.py b/tests/strands/experimental/steering/core/test_handler.py index 90064ea98..1f247120a 100644 --- a/tests/strands/experimental/steering/core/test_handler.py +++ b/tests/strands/experimental/steering/core/test_handler.py @@ -1,5 +1,6 @@ """Unit tests for steering handler base class.""" +import inspect from unittest.mock import AsyncMock, Mock import pytest @@ -8,6 +9,7 @@ from strands.experimental.steering.core.context import SteeringContext, SteeringContextCallback, SteeringContextProvider from strands.experimental.steering.core.handler import SteeringHandler from strands.hooks.events import AfterModelCallEvent, BeforeToolCallEvent +from strands.hooks.registry import HookRegistry from strands.plugins import Plugin @@ -38,15 +40,24 @@ def test_steering_handler_is_plugin(): def test_init_agent(): - """Test init_agent registers hooks on agent.""" + """Test init_agent with plugin registry registers hooks on agent.""" + from strands.plugins.registry import _PluginRegistry + handler = TestSteeringHandler() agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) - handler.init_agent(agent) + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) - # Verify hooks were registered (tool and model steering hooks) + # Verify hooks were registered (tool and model steering hooks via @hook decorator) assert agent.add_hook.call_count >= 2 - agent.add_hook.assert_any_call(handler._provide_tool_steering_guidance, BeforeToolCallEvent) + # Check that the decorated hook methods were registered + assert BeforeToolCallEvent in 
agent.hooks._registered_callbacks + assert AfterModelCallEvent in agent.hooks._registered_callbacks def test_steering_context_initialization(): @@ -86,7 +97,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): tool_use = {"name": "test_tool"} event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) # Should not modify event for Proceed assert not event.cancel_tool @@ -105,7 +116,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): tool_use = {"name": "test_tool"} event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) # Should set cancel_tool with guidance message expected_message = "Tool call cancelled. Test guidance You MUST follow this guidance immediately." 
@@ -126,7 +137,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): event.tool_use = tool_use event.interrupt = Mock(return_value=True) # Approved - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) event.interrupt.assert_called_once() @@ -145,7 +156,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): event.tool_use = tool_use event.interrupt = Mock(return_value=False) # Denied - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) event.interrupt.assert_called_once() assert event.cancel_tool.startswith("Manual approval denied:") @@ -165,7 +176,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) with pytest.raises(ValueError, match="Unknown steering action type"): - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) def test_init_agent_override(): @@ -218,62 +229,86 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): def test_handler_registers_context_provider_hooks(): - """Test that handler registers hooks from context callbacks.""" + """Test that handler registers hooks from context callbacks via registry.""" + from strands.plugins.registry import _PluginRegistry + mock_callback = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) - handler.init_agent(agent) + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) - # Should register hooks for context callback and steering guidance + # Should register hooks for context callback 
(via init_agent) and steering guidance (via @hook) + # init_agent registers context callbacks manually, @hook decorated methods are auto-registered assert agent.add_hook.call_count >= 2 - # Check that BeforeToolCallEvent was registered - call_args = [call[0] for call in agent.add_hook.call_args_list] - event_types = [args[1] for args in call_args] + # Check that BeforeToolCallEvent was registered (both context callback and steering guidance) + assert BeforeToolCallEvent in agent.hooks._registered_callbacks - assert BeforeToolCallEvent in event_types - -def test_context_callbacks_receive_steering_context(): +@pytest.mark.asyncio +async def test_context_callbacks_receive_steering_context(): """Test that context callbacks receive the handler's steering context.""" + from strands.plugins.registry import _PluginRegistry + mock_callback = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) - handler.init_agent(agent) + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) - # Get the registered callback for BeforeToolCallEvent - before_callback = None - for call in agent.add_hook.call_args_list: - if call[0][1] == BeforeToolCallEvent: - before_callback = call[0][0] - break + # Get the registered callbacks for BeforeToolCallEvent + callbacks = agent.hooks._registered_callbacks.get(BeforeToolCallEvent, []) + assert len(callbacks) > 0 - assert before_callback is not None - - # Create a mock event and call the callback + # The context callback is wrapped in a lambda, so we just call all callbacks + # and check if the steering context was updated event = Mock(spec=BeforeToolCallEvent) event.tool_use = {"name": "test_tool", "input": {}} - # The callback should execute 
without error and update the steering context - before_callback(event) + # Call all callbacks, handling both sync and async + for cb in callbacks: + try: + result = await cb(event) + if inspect.iscoroutine(result): + await result + except Exception: + pass # Some callbacks might be async or have other requirements - # Verify the steering context was updated + # Verify the steering context was updated by at least one callback assert handler.steering_context.data.get("test_key") == "test_value" def test_multiple_context_callbacks_registered(): - """Test that multiple context callbacks are registered.""" + """Test that multiple context callbacks are registered via registry.""" + from strands.plugins.registry import _PluginRegistry + callback1 = MockContextCallback() callback2 = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[callback1, callback2]) agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) - handler.init_agent(agent) + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) - # Should register one callback for each context provider plus tool and model steering guidance + # Should register: + # - 2 callbacks for context providers (via init_agent manual registration) + # - 2 for steering guidance (via @hook decorator auto-registration) expected_calls = 2 + 2 # 2 callbacks + 2 for steering guidance (tool and model) assert agent.add_hook.call_count >= expected_calls @@ -310,7 +345,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.stop_response = stop_response event.retry = False - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) # Should not set retry for Proceed assert event.retry is False @@ -334,7 +369,7 @@ async def 
steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.stop_response = stop_response event.retry = False - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) # Should set retry flag assert event.retry is True @@ -362,7 +397,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event = Mock(spec=AfterModelCallEvent) event.stop_response = None - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) # steer_after_model should not have been called assert handler.steer_called is False @@ -386,7 +421,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.stop_response = stop_response with pytest.raises(ValueError, match="Unknown steering action type for model response"): - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) @pytest.mark.asyncio @@ -407,7 +442,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.stop_response = stop_response with pytest.raises(ValueError, match="Unknown steering action type for model response"): - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) @pytest.mark.asyncio @@ -429,7 +464,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.retry = False # Should not raise, just return early - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) # retry should not be set since exception occurred assert event.retry is False @@ -449,7 +484,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) # Should not raise, just return early - await handler._provide_tool_steering_guidance(event) + await 
handler.provide_tool_steering_guidance(event) # cancel_tool should not be set since exception occurred assert not event.cancel_tool @@ -486,11 +521,20 @@ async def test_default_steer_after_model_returns_proceed(): def test_init_agent_registers_model_steering(): - """Test that init_agent registers model steering callback.""" + """Test that model steering hook is registered via plugin registry.""" + from strands.plugins.registry import _PluginRegistry + handler = TestSteeringHandler() agent = Mock() - - handler.init_agent(agent) - - # Verify model steering hook was registered - agent.add_hook.assert_any_call(handler._provide_model_steering_guidance, AfterModelCallEvent) + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) + + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) + + # Verify model steering hook was registered via @hook decorator + assert AfterModelCallEvent in agent.hooks._registered_callbacks + callbacks = agent.hooks._registered_callbacks[AfterModelCallEvent] + assert len(callbacks) == 1 diff --git a/tests/strands/plugins/test_hook_decorator.py b/tests/strands/plugins/test_hook_decorator.py new file mode 100644 index 000000000..d05e79edb --- /dev/null +++ b/tests/strands/plugins/test_hook_decorator.py @@ -0,0 +1,243 @@ +"""Tests for the @hook decorator.""" + +import unittest.mock + +import pytest + +from strands.hooks import ( + AfterInvocationEvent, + AfterModelCallEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, +) +from strands.plugins.decorator import hook + + +class TestHookDecoratorBasic: + """Tests for basic @hook decorator functionality.""" + + def test_hook_decorator_marks_method(self): + """Test that @hook marks a method with hook metadata.""" + + @hook + def on_before_model_call(event: BeforeModelCallEvent): + pass + + assert 
hasattr(on_before_model_call, "_hook_event_types") + assert BeforeModelCallEvent in on_before_model_call._hook_event_types + + def test_hook_decorator_with_parentheses(self): + """Test that @hook() syntax also works.""" + + @hook() + def on_before_model_call(event: BeforeModelCallEvent): + pass + + assert hasattr(on_before_model_call, "_hook_event_types") + assert BeforeModelCallEvent in on_before_model_call._hook_event_types + + def test_hook_decorator_preserves_function_metadata(self): + """Test that @hook preserves the original function's metadata.""" + + @hook + def on_before_model_call(event: BeforeModelCallEvent): + """Docstring for the hook.""" + pass + + assert on_before_model_call.__name__ == "on_before_model_call" + assert on_before_model_call.__doc__ == "Docstring for the hook." + + def test_hook_decorator_function_still_callable(self): + """Test that decorated function can still be called normally.""" + call_count = 0 + + @hook + def on_before_model_call(event: BeforeModelCallEvent): + nonlocal call_count + call_count += 1 + + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + on_before_model_call(mock_event) + assert call_count == 1 + + +class TestHookDecoratorEventTypeInference: + """Tests for event type inference from type hints.""" + + def test_hook_infers_event_type_from_type_hint(self): + """Test that @hook infers event type from the first parameter's type hint.""" + + @hook + def handler(event: BeforeInvocationEvent): + pass + + assert BeforeInvocationEvent in handler._hook_event_types + + def test_hook_infers_different_event_types(self): + """Test that different event types are correctly inferred.""" + + @hook + def handler1(event: BeforeModelCallEvent): + pass + + @hook + def handler2(event: AfterModelCallEvent): + pass + + @hook + def handler3(event: AfterInvocationEvent): + pass + + assert BeforeModelCallEvent in handler1._hook_event_types + assert AfterModelCallEvent in handler2._hook_event_types + assert AfterInvocationEvent 
in handler3._hook_event_types + + def test_hook_skips_cls_parameter(self): + """Test that @hook skips 'cls' parameter for classmethods.""" + + class MyClass: + @classmethod + @hook + def handler(cls, event: BeforeModelCallEvent): + pass + + assert BeforeModelCallEvent in MyClass.handler._hook_event_types + + +class TestHookDecoratorUnionTypes: + """Tests for union type support in @hook decorator.""" + + def test_hook_supports_union_types_with_pipe(self): + """Test that @hook supports union types using | syntax.""" + + @hook + def handler(event: BeforeModelCallEvent | AfterModelCallEvent): + pass + + assert BeforeModelCallEvent in handler._hook_event_types + assert AfterModelCallEvent in handler._hook_event_types + + def test_hook_supports_union_types_with_typing_union(self): + """Test that @hook supports Union[] syntax.""" + + @hook + def handler(event: BeforeModelCallEvent | AfterModelCallEvent): + pass + + assert BeforeModelCallEvent in handler._hook_event_types + assert AfterModelCallEvent in handler._hook_event_types + + def test_hook_supports_multiple_union_types(self): + """Test that @hook supports unions with more than two types.""" + + @hook + def handler(event: BeforeModelCallEvent | AfterModelCallEvent | BeforeInvocationEvent): + pass + + assert BeforeModelCallEvent in handler._hook_event_types + assert AfterModelCallEvent in handler._hook_event_types + assert BeforeInvocationEvent in handler._hook_event_types + + +class TestHookDecoratorErrorHandling: + """Tests for error handling in @hook decorator.""" + + def test_hook_raises_error_without_type_hint(self): + """Test that @hook raises error when no type hint is provided.""" + with pytest.raises(ValueError, match="cannot infer event type"): + + @hook + def handler(event): + pass + + def test_hook_raises_error_with_non_hook_event_type(self): + """Test that @hook raises error when type hint is not a HookEvent subclass.""" + with pytest.raises(ValueError, match="must be a subclass of BaseHookEvent"): + + 
@hook + def handler(event: str): + pass + + def test_hook_raises_error_with_none_in_union(self): + """Test that @hook raises error when union contains None.""" + with pytest.raises(ValueError, match="None is not a valid event type"): + + @hook + def handler(event: BeforeModelCallEvent | None): + pass + + +class TestHookDecoratorWithMethods: + """Tests for @hook decorator on class methods.""" + + def test_hook_works_on_instance_method(self): + """Test that @hook works correctly on instance methods.""" + + class MyClass: + @hook + def handler(self, event: BeforeModelCallEvent): + pass + + instance = MyClass() + assert hasattr(instance.handler, "_hook_event_types") + assert BeforeModelCallEvent in instance.handler._hook_event_types + + def test_hook_instance_method_is_callable(self): + """Test that decorated instance method can be called.""" + call_count = 0 + + class MyClass: + @hook + def handler(self, event: BeforeModelCallEvent): + nonlocal call_count + call_count += 1 + + instance = MyClass() + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + instance.handler(mock_event) + assert call_count == 1 + + def test_hook_method_accesses_self(self): + """Test that decorated method can access self.""" + + class MyClass: + def __init__(self): + self.events_received = [] + + @hook + def handler(self, event: BeforeModelCallEvent): + self.events_received.append(event) + + instance = MyClass() + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + instance.handler(mock_event) + assert len(instance.events_received) == 1 + assert instance.events_received[0] is mock_event + + +class TestHookDecoratorAsync: + """Tests for async functions with @hook decorator.""" + + def test_hook_works_on_async_function(self): + """Test that @hook works on async functions.""" + + @hook + async def handler(event: BeforeModelCallEvent): + pass + + assert hasattr(handler, "_hook_event_types") + assert BeforeModelCallEvent in handler._hook_event_types + + 
@pytest.mark.asyncio + async def test_hook_async_function_is_callable(self): + """Test that decorated async function can be awaited.""" + call_count = 0 + + @hook + async def handler(event: BeforeModelCallEvent): + nonlocal call_count + call_count += 1 + + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + await handler(mock_event) + assert call_count == 1 diff --git a/tests/strands/plugins/test_plugin_base_class.py b/tests/strands/plugins/test_plugin_base_class.py new file mode 100644 index 000000000..dab3e7210 --- /dev/null +++ b/tests/strands/plugins/test_plugin_base_class.py @@ -0,0 +1,553 @@ +"""Tests for the Plugin base class with auto-discovery.""" + +import unittest.mock + +import pytest + +from strands.hooks import BeforeInvocationEvent, BeforeModelCallEvent, HookRegistry +from strands.plugins import Plugin, hook +from strands.plugins.registry import _PluginRegistry +from strands.tools.decorator import tool + + +def _configure_mock_agent_with_hooks(): + """Helper to create a mock agent with working add_hook.""" + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + mock_agent.add_hook.side_effect = lambda callback, event_type=None: mock_agent.hooks.add_callback( + event_type, callback + ) + return mock_agent + + +class TestPluginBaseClass: + """Tests for Plugin base class basics.""" + + def test_plugin_is_class_not_protocol(self): + """Test that Plugin is now a class, not a Protocol.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + plugin = MyPlugin() + assert isinstance(plugin, Plugin) + + def test_plugin_requires_name_attribute(self): + """Test that Plugin subclass must have name attribute.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + plugin = MyPlugin() + assert plugin.name == "my-plugin" + + def test_plugin_name_as_property(self): + """Test that Plugin name can be a property.""" + + class MyPlugin(Plugin): + @property + def name(self) -> str: + return "property-plugin" + + plugin = MyPlugin() + 
assert plugin.name == "property-plugin" + + +class TestPluginAutoDiscovery: + """Tests for automatic discovery of decorated methods.""" + + def test_plugin_discovers_hook_decorated_methods(self): + """Test that Plugin.__init__ discovers @hook decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "on_before_model" + + def test_plugin_discovers_multiple_hooks(self): + """Test that Plugin discovers multiple @hook decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def hook1(self, event: BeforeModelCallEvent): + pass + + @hook + def hook2(self, event: BeforeInvocationEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 2 + hook_names = {h.__name__ for h in plugin.hooks} + assert "hook1" in hook_names + assert "hook2" in hook_names + + def test_hooks_preserve_definition_order(self): + """Test that hooks are discovered in definition order, not alphabetical.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def z_last_alphabetically(self, event: BeforeModelCallEvent): + pass + + @hook + def a_first_alphabetically(self, event: BeforeModelCallEvent): + pass + + @hook + def m_middle_alphabetically(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 3 + # Should be in definition order, not alphabetical + assert plugin.hooks[0].__name__ == "z_last_alphabetically" + assert plugin.hooks[1].__name__ == "a_first_alphabetically" + assert plugin.hooks[2].__name__ == "m_middle_alphabetically" + + def test_plugin_discovers_tool_decorated_methods(self): + """Test that Plugin.__init__ discovers @tool decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + assert 
len(plugin.tools) == 1 + assert plugin.tools[0].tool_name == "my_tool" + + def test_plugin_discovers_both_hooks_and_tools(self): + """Test that Plugin discovers both @hook and @tool decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def my_hook(self, event: BeforeModelCallEvent): + pass + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + assert len(plugin.hooks) == 1 + assert len(plugin.tools) == 1 + + def test_plugin_ignores_non_decorated_methods(self): + """Test that Plugin doesn't discover non-decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + def regular_method(self): + pass + + @hook + def decorated_hook(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "decorated_hook" + + def test_hooks_property_returns_list(self): + """Test that hooks property returns a mutable list.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def my_hook(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + assert isinstance(plugin.hooks, list) + + def test_tools_property_returns_list(self): + """Test that tools property returns a mutable list.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + assert isinstance(plugin.tools, list) + + def test_hooks_can_be_filtered(self): + """Test that hooks list can be modified before registration.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def hook1(self, event: BeforeModelCallEvent): + pass + + @hook + def hook2(self, event: BeforeInvocationEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 2 + + # Filter out hook1 + plugin.hooks[:] = [h for h in plugin.hooks if h.__name__ != "hook1"] + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "hook2" + + def 
test_tools_can_be_filtered(self): + """Test that tools list can be modified before registration.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @tool + def tool1(self, param: str) -> str: + """Tool 1.""" + return param + + @tool + def tool2(self, param: str) -> str: + """Tool 2.""" + return param + + plugin = MyPlugin() + assert len(plugin.tools) == 2 + + # Filter out tool1 + plugin.tools[:] = [t for t in plugin.tools if t.tool_name != "tool1"] + assert len(plugin.tools) == 1 + assert plugin.tools[0].tool_name == "tool2" + + +class TestPluginRegistryAutoRegistration: + """Tests for auto-registration via _PluginRegistry.""" + + def test_registry_registers_hooks_with_agent(self): + """Test that _PluginRegistry registers discovered hooks with agent.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Verify hook was registered + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + def test_registry_registers_tools_with_agent(self): + """Test that _PluginRegistry adds discovered tools to agent's tools.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Verify tool was added to agent + mock_agent.tool_registry.process_tools.assert_called_once() + + def test_registry_registers_both_hooks_and_tools(self): + """Test that _PluginRegistry registers both hooks and tools.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def my_hook(self, event: BeforeModelCallEvent): + pass 
+ + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + mock_agent.tool_registry = unittest.mock.MagicMock() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Verify both registered + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + mock_agent.tool_registry.process_tools.assert_called_once() + + def test_registry_calls_init_agent_before_registration(self): + """Test that _PluginRegistry calls init_agent for custom logic.""" + init_called = False + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def my_hook(self, event: BeforeModelCallEvent): + pass + + def init_agent(self, agent): + nonlocal init_called + init_called = True + # Custom logic - no super() needed + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + assert init_called + # Verify auto-registration still happened + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + +class TestPluginHookWithUnionTypes: + """Tests for Plugin hooks with union types.""" + + def test_registry_registers_hook_for_union_types(self): + """Test that hooks with union types are registered for all event types.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_model_events(self, event: BeforeModelCallEvent | BeforeInvocationEvent): + pass + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Verify hook was registered for both event types + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_agent.hooks._registered_callbacks.get(BeforeInvocationEvent, [])) == 1 + + +class TestPluginMultipleAgents: + """Tests for plugin reuse with multiple 
agents.""" + + def test_plugin_can_be_attached_to_multiple_agents(self): + """Test that the same plugin instance can be used with multiple agents.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + + mock_agent1 = _configure_mock_agent_with_hooks() + mock_agent2 = _configure_mock_agent_with_hooks() + + # Note: In practice, different registries would be used for each agent + # Here we simulate attaching to multiple agents directly + registry1 = _PluginRegistry(mock_agent1) + registry1.add_and_init(plugin) + + # Create new plugin instance for second agent (same class) + plugin2 = MyPlugin() + registry2 = _PluginRegistry(mock_agent2) + registry2.add_and_init(plugin2) + + # Verify both agents have the hook registered + assert len(mock_agent1.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_agent2.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + +class TestPluginSubclassOverride: + """Tests for subclass overriding init_agent.""" + + def test_subclass_can_override_init_agent_without_super(self): + """Test that subclass can override init_agent without calling super().""" + custom_init_called = False + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + def init_agent(self, agent): + nonlocal custom_init_called + custom_init_called = True + # No super() needed - registry handles auto-registration + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + assert custom_init_called + # Verify auto-registration still happened via registry + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + def test_subclass_can_add_manual_hooks(self): + """Test that subclass can manually add hooks in addition to decorated ones.""" + 
manual_hook_added = False + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def auto_hook(self, event: BeforeModelCallEvent): + pass + + def manual_hook(self, event: BeforeInvocationEvent): + pass + + def init_agent(self, agent): + nonlocal manual_hook_added + # Add manual hook - no super() needed + agent.hooks.add_callback(BeforeInvocationEvent, self.manual_hook) + manual_hook_added = True + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + assert manual_hook_added + # Verify both hooks registered (1 manual + 1 auto) + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_agent.hooks._registered_callbacks.get(BeforeInvocationEvent, [])) == 1 + + +class TestPluginAsyncInitPlugin: + """Tests for async init_agent support.""" + + @pytest.mark.asyncio + async def test_async_init_agent_supported(self): + """Test that async init_agent is supported.""" + async_init_called = False + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + async def init_agent(self, agent): + nonlocal async_init_called + async_init_called = True + # No super() needed - registry handles auto-registration + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Verify async init was called (run_async handles it) + assert async_init_called + # Verify hook was registered + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + +class TestPluginBoundMethods: + """Tests for bound method registration.""" + + def test_hooks_are_bound_to_instance(self): + """Test that registered hooks are bound to the plugin instance.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + def __init__(self): + super().__init__() + self.events_received = [] + 
+ @hook + def on_before_model(self, event: BeforeModelCallEvent): + self.events_received.append(event) + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Call the registered hook and verify it accesses the correct instance + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + callbacks = list(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) + callbacks[0](mock_event) + + assert len(plugin.events_received) == 1 + assert plugin.events_received[0] is mock_event + + def test_tools_are_bound_to_instance(self): + """Test that registered tools are bound to the plugin instance.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + def __init__(self): + super().__init__() + self.tool_called = False + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + self.tool_called = True + return param + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + # Get the tool that was registered and call it + call_args = mock_agent.tool_registry.process_tools.call_args + registered_tools = call_args[0][0] + assert len(registered_tools) == 1 + + # Call the tool - it should be bound to the instance + result = registered_tools[0]("test") + assert plugin.tool_called + assert result == "test" diff --git a/tests/strands/plugins/test_plugins.py b/tests/strands/plugins/test_plugins.py index c16cfcf7a..04b39718b 100644 --- a/tests/strands/plugins/test_plugins.py +++ b/tests/strands/plugins/test_plugins.py @@ -4,38 +4,39 @@ import pytest +from strands.hooks import HookRegistry from strands.plugins import Plugin from strands.plugins.registry import _PluginRegistry -# Plugin Tests +# Plugin Base Class Tests -def test_plugin_class_requires_inheritance(): - """Test that 
Plugin class requires inheritance.""" +def test_plugin_base_class_isinstance_check(): + """Test that Plugin subclass passes isinstance check.""" class MyPlugin(Plugin): name = "my-plugin" - def init_agent(self, agent): - pass - plugin = MyPlugin() assert isinstance(plugin, Plugin) -def test_plugin_class_sync_implementation(): - """Test Plugin class works with synchronous init_agent.""" +def test_plugin_base_class_sync_implementation(): + """Test Plugin base class works with synchronous init_agent.""" class SyncPlugin(Plugin): name = "sync-plugin" def init_agent(self, agent): + # No super() needed - registry handles auto-registration agent.custom_attribute = "initialized by plugin" plugin = SyncPlugin() mock_agent = unittest.mock.Mock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() - # Verify the plugin is an instance of Plugin + # Verify the plugin is an instance assert isinstance(plugin, Plugin) assert plugin.name == "sync-plugin" @@ -45,19 +46,22 @@ def init_agent(self, agent): @pytest.mark.asyncio -async def test_plugin_class_async_implementation(): - """Test Plugin class works with asynchronous init_agent.""" +async def test_plugin_base_class_async_implementation(): + """Test Plugin base class works with asynchronous init_agent.""" class AsyncPlugin(Plugin): name = "async-plugin" async def init_agent(self, agent): + # No super() needed - registry handles auto-registration agent.custom_attribute = "initialized by async plugin" plugin = AsyncPlugin() mock_agent = unittest.mock.Mock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() - # Verify the plugin is an instance of Plugin + # Verify the plugin is an instance assert isinstance(plugin, Plugin) assert plugin.name == "async-plugin" @@ -78,42 +82,37 @@ def init_agent(self, agent): PluginWithoutName() -def test_plugin_class_requires_init_agent_method(): - """Test that Plugin class requires an init_agent method.""" +def 
test_plugin_base_class_requires_init_agent_method(): + """Test that Plugin base class provides default init_agent.""" - with pytest.raises(TypeError, match="Can't instantiate abstract class"): + class PluginWithoutOverride(Plugin): + name = "no-override-plugin" - class PluginWithoutInitPlugin(Plugin): - name = "incomplete-plugin" + plugin = PluginWithoutOverride() + # Plugin base class provides default init_agent + assert hasattr(plugin, "init_agent") + assert callable(plugin.init_agent) - PluginWithoutInitPlugin() - -def test_plugin_class_with_class_attribute_name(): - """Test Plugin class works when name is a class attribute.""" +def test_plugin_base_class_with_class_attribute_name(): + """Test Plugin base class works when name is a class attribute.""" class PluginWithClassAttribute(Plugin): name: str = "class-attr-plugin" - def init_agent(self, agent): - pass - plugin = PluginWithClassAttribute() assert isinstance(plugin, Plugin) assert plugin.name == "class-attr-plugin" -def test_plugin_class_with_property_name(): - """Test Plugin class works when name is a property.""" +def test_plugin_base_class_with_property_name(): + """Test Plugin base class works when name is a property.""" class PluginWithProperty(Plugin): @property - def name(self): + def name(self) -> str: return "property-plugin" - def init_agent(self, agent): - pass - plugin = PluginWithProperty() assert isinstance(plugin, Plugin) assert plugin.name == "property-plugin" @@ -125,7 +124,11 @@ def init_agent(self, agent): @pytest.fixture def mock_agent(): """Create a mock agent for testing.""" - return unittest.mock.Mock() + agent = unittest.mock.Mock() + agent.hooks = HookRegistry() + agent.tool_registry = unittest.mock.MagicMock() + agent.add_hook = unittest.mock.Mock() + return agent @pytest.fixture @@ -141,9 +144,11 @@ class TestPlugin(Plugin): name = "test-plugin" def __init__(self): + super().__init__() self.initialized = False def init_agent(self, agent): + # No super() needed - registry handles 
auto-registration self.initialized = True agent.plugin_initialized = True @@ -160,9 +165,6 @@ def test_plugin_registry_add_duplicate_raises_error(registry, mock_agent): class TestPlugin(Plugin): name = "test-plugin" - def init_agent(self, agent): - pass - plugin1 = TestPlugin() plugin2 = TestPlugin() @@ -179,9 +181,11 @@ class AsyncPlugin(Plugin): name = "async-plugin" def __init__(self): + super().__init__() self.initialized = False async def init_agent(self, agent): + # No super() needed - registry handles auto-registration self.initialized = True agent.async_plugin_initialized = True From 9143e23f2da7c49f2463dc61620559c38180f9d5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 3 Mar 2026 12:39:08 -0500 Subject: [PATCH 160/279] ci: bump actions/upload-artifact from 6 to 7 (#1777) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/integration-test.yml | 2 +- .github/workflows/pypi-publish-on-release.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index e7cdbe131..7d0acc0ec 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -63,7 +63,7 @@ jobs: - name: Upload test results if: always() - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@v7 with: name: test-results path: ./build/test-results.xml diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index 7c96a9789..e3ca847a6 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -56,7 +56,7 @@ jobs: hatch build - name: Store the distribution packages - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@v7 with: name: python-package-distributions path: dist/ From 
4cd7eebba44cdda9387e3ed11c3782257c10036b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 3 Mar 2026 17:41:15 +0000 Subject: [PATCH 161/279] ci: bump actions/download-artifact from 7 to 8 (#1776) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/integration-test.yml | 2 +- .github/workflows/pypi-publish-on-release.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 7d0acc0ec..5f7dd20d9 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -92,7 +92,7 @@ jobs: persist-credentials: false - name: Download test results - uses: actions/download-artifact@v7 + uses: actions/download-artifact@v8 with: name: test-results diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index e3ca847a6..4601d4069 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -78,7 +78,7 @@ jobs: steps: - name: Download all the dists - uses: actions/download-artifact@v7 + uses: actions/download-artifact@v8 with: name: python-package-distributions path: dist/ From 3625d7d9777e39cb1dc82bfbf8f62549a82ac423 Mon Sep 17 00:00:00 2001 From: Charles Duffy Date: Tue, 3 Mar 2026 15:34:38 -0600 Subject: [PATCH 162/279] fix: throw exceptions from ConcurrentToolExecutor (#1797) Co-authored-by: Patrick Gray --- src/strands/tools/executors/concurrent.py | 62 +++++++++++-------- tests/strands/tools/executors/conftest.py | 14 ++++- .../tools/executors/test_concurrent.py | 29 ++++++++- 3 files changed, 78 insertions(+), 27 deletions(-) diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py index 7fa34eff0..835e5abff 100644 --- 
a/src/strands/tools/executors/concurrent.py +++ b/src/strands/tools/executors/concurrent.py @@ -48,34 +48,43 @@ async def _execute( task_events = [asyncio.Event() for _ in tool_uses] stop_event = object() - tasks = [ - asyncio.create_task( - self._task( - agent, - tool_use, - tool_results, - cycle_trace, - cycle_span, - invocation_state, - task_id, - task_queue, - task_events[task_id], - stop_event, - structured_output_context, + tasks = [] + try: + for task_id, tool_use in enumerate(tool_uses): + tasks.append( + asyncio.create_task( + self._task( + agent, + tool_use, + tool_results, + cycle_trace, + cycle_span, + invocation_state, + task_id, + task_queue, + task_events[task_id], + stop_event, + structured_output_context, + ) + ) ) - ) - for task_id, tool_use in enumerate(tool_uses) - ] - task_count = len(tasks) - while task_count: - task_id, event = await task_queue.get() - if event is stop_event: - task_count -= 1 - continue + task_count = len(tasks) + while task_count: + task_id, event = await task_queue.get() + if event is stop_event: + task_count -= 1 + continue - yield event - task_events[task_id].set() + if isinstance(event, Exception): + raise event + + yield event + task_events[task_id].set() + finally: + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) async def _task( self, @@ -115,5 +124,8 @@ async def _task( await task_event.wait() task_event.clear() + except Exception as e: + task_queue.put_nowait((task_id, e)) + finally: task_queue.put_nowait((task_id, stop_event)) diff --git a/tests/strands/tools/executors/conftest.py b/tests/strands/tools/executors/conftest.py index ad92ba603..8ecbe2f88 100644 --- a/tests/strands/tools/executors/conftest.py +++ b/tests/strands/tools/executors/conftest.py @@ -1,3 +1,4 @@ +import asyncio import threading import unittest.mock @@ -90,13 +91,24 @@ def func(tool_context: ToolContext) -> str: @pytest.fixture -def tool_registry(weather_tool, temperature_tool, exception_tool, 
thread_tool, interrupt_tool): +def slow_tool(): + @strands.tool(name="slow_tool") + async def func(): + """A tool that blocks until cancelled.""" + await asyncio.sleep(3) + + return func + + +@pytest.fixture +def tool_registry(weather_tool, temperature_tool, exception_tool, thread_tool, interrupt_tool, slow_tool): registry = ToolRegistry() registry.register_tool(weather_tool) registry.register_tool(temperature_tool) registry.register_tool(exception_tool) registry.register_tool(thread_tool) registry.register_tool(interrupt_tool) + registry.register_tool(slow_tool) return registry diff --git a/tests/strands/tools/executors/test_concurrent.py b/tests/strands/tools/executors/test_concurrent.py index ce07ee4ce..a8ac05830 100644 --- a/tests/strands/tools/executors/test_concurrent.py +++ b/tests/strands/tools/executors/test_concurrent.py @@ -1,6 +1,6 @@ import pytest -from strands.hooks import BeforeToolCallEvent +from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent from strands.interrupt import Interrupt from strands.tools.executors import ConcurrentToolExecutor from strands.tools.structured_output._structured_output_context import StructuredOutputContext @@ -76,3 +76,30 @@ def interrupt_callback(event): tru_results = tool_results exp_results = [exp_events[1].tool_result] assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_concurrent_executor_reraises_exceptions( + executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context, alist +): + """Test that hook re-raised exceptions propagate and cancel remaining tasks.""" + + def reraise_callback(event): + if event.exception is not None: + raise event.exception + + agent.hooks.add_callback(AfterToolCallEvent, reraise_callback) + + tool_uses = [ + {"name": "exception_tool", "toolUseId": "1", "input": {}}, + {"name": "slow_tool", "toolUseId": "2", "input": {}}, + ] + + stream = executor._execute( + agent, tool_uses, tool_results, cycle_trace, cycle_span, 
invocation_state, structured_output_context + ) + + with pytest.raises(RuntimeError, match="Tool error"): + await alist(stream) + + assert tool_results == [] From 31f1e649320b71686eb86fb8ee61aaf534bab06b Mon Sep 17 00:00:00 2001 From: Gitika <53349492+notgitika@users.noreply.github.com> Date: Wed, 4 Mar 2026 09:13:09 -0500 Subject: [PATCH 163/279] feat: add OpenAI Responses API model implementation (#975) Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> --- README.md | 1 + src/strands/models/__init__.py | 4 + src/strands/models/openai_responses.py | 691 ++++++++++++++ tests/strands/models/conftest.py | 25 + tests/strands/models/test_openai_responses.py | 885 ++++++++++++++++++ tests_integ/models/providers.py | 43 +- tests_integ/models/test_model_openai.py | 57 +- 7 files changed, 1685 insertions(+), 21 deletions(-) create mode 100644 src/strands/models/openai_responses.py create mode 100644 tests/strands/models/conftest.py create mode 100644 tests/strands/models/test_openai_responses.py diff --git a/README.md b/README.md index 9ee7f6c56..fdb309f99 100644 --- a/README.md +++ b/README.md @@ -179,6 +179,7 @@ Built-in providers: - [MistralAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/mistral/) - [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/) - [OpenAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/) + - [OpenAI Responses API](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/) - [SageMaker](https://strandsagents.com/latest/user-guide/concepts/model-providers/sagemaker/) - [Writer](https://strandsagents.com/latest/user-guide/concepts/model-providers/writer/) diff --git a/src/strands/models/__init__.py b/src/strands/models/__init__.py index be6a96549..2c582d116 100644 --- a/src/strands/models/__init__.py +++ b/src/strands/models/__init__.py @@ -55,6 +55,10 @@ def __getattr__(name: str) -> Any: from .openai import 
OpenAIModel return OpenAIModel + if name == "OpenAIResponsesModel": + from .openai_responses import OpenAIResponsesModel + + return OpenAIResponsesModel if name == "SageMakerAIModel": from .sagemaker import SageMakerAIModel diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py new file mode 100644 index 000000000..96d4bee59 --- /dev/null +++ b/src/strands/models/openai_responses.py @@ -0,0 +1,691 @@ +"""OpenAI model provider using the Responses API. + +The Responses API is OpenAI's newer API that differs from the Chat Completions API in several key ways: + +1. The Responses API can maintain conversation state server-side through "previous_response_id", + while Chat Completions is stateless and requires sending full conversation history each time. + Note: This implementation currently only implements the stateless approach. + +2. Responses API uses "input" (list of items) instead of "messages", and system + prompts are passed as "instructions" rather than a system role message. + +3. Responses API supports built-in tools (web search, code interpreter, file search) + Note: These are not yet implemented in this provider. 
+ +- Docs: https://platform.openai.com/docs/api-reference/responses +""" + +import base64 +import json +import logging +import mimetypes +from collections.abc import AsyncGenerator +from importlib.metadata import version as get_package_version +from types import SimpleNamespace +from typing import Any, Protocol, TypedDict, TypeVar, cast + +from packaging.version import Version +from pydantic import BaseModel +from typing_extensions import Unpack, override + +# Validate OpenAI SDK version at import time - Responses API requires v2.0.0+ +# A major version bump is proposed in https://github.com/strands-agents/sdk-python/pull/1370 +_MIN_OPENAI_VERSION = Version("2.0.0") + +try: + _openai_version = Version(get_package_version("openai")) + if _openai_version < _MIN_OPENAI_VERSION: + raise ImportError( + f"OpenAIResponsesModel requires openai>={_MIN_OPENAI_VERSION} (found {_openai_version}). " + "Install/upgrade with: pip install -U openai. " + "For older SDKs, use OpenAIModel (Chat Completions)." + ) +except ImportError: + # Re-raise ImportError as-is (covers both our explicit raise above and missing openai package) + raise +except Exception as e: + raise ImportError( + f"OpenAIResponsesModel requires openai>={_MIN_OPENAI_VERSION}. 
Install with: pip install -U openai" + ) from e + +import openai # noqa: E402 - must import after version check + +from ..types.content import ContentBlock, Messages # noqa: E402 +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException # noqa: E402 +from ..types.streaming import StreamEvent # noqa: E402 +from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse # noqa: E402 +from ._validation import validate_config_keys # noqa: E402 +from .model import Model # noqa: E402 + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + +# Maximum file size for media content in tool results (20MB) +_MAX_MEDIA_SIZE_BYTES = 20 * 1024 * 1024 +_MAX_MEDIA_SIZE_LABEL = "20MB" +_DEFAULT_MIME_TYPE = "application/octet-stream" +_CONTEXT_WINDOW_OVERFLOW_MSG = "OpenAI Responses API threw context window overflow error" +_RATE_LIMIT_MSG = "OpenAI Responses API threw rate limit error" + + +def _encode_media_to_data_url(data: bytes, format_ext: str, media_type: str = "image") -> str: + """Encode media bytes to a base64 data URL with size validation. + + Args: + data: Raw bytes of the media content. + format_ext: File format extension (e.g., "png", "pdf"). + media_type: Type of media for error messages ("image" or "document"). + + Returns: + Base64-encoded data URL string. + + Raises: + ValueError: If the media size exceeds the maximum allowed size. 
+ """ + if len(data) > _MAX_MEDIA_SIZE_BYTES: + raise ValueError( + f"{media_type.capitalize()} size {len(data)} bytes exceeds maximum of" + f" {_MAX_MEDIA_SIZE_BYTES} bytes ({_MAX_MEDIA_SIZE_LABEL})" + ) + mime_type = mimetypes.types_map.get(f".{format_ext}", _DEFAULT_MIME_TYPE) + encoded_data = base64.b64encode(data).decode("utf-8") + return f"data:{mime_type};base64,{encoded_data}" + + +class _ToolCallInfo(TypedDict): + """Internal type for tracking tool call information during streaming.""" + + name: str + arguments: str + call_id: str + item_id: str + + +class Client(Protocol): + """Protocol defining the OpenAI Responses API interface for the underlying provider client.""" + + @property + # pragma: no cover + def responses(self) -> Any: + """Responses interface.""" + ... + + +class OpenAIResponsesModel(Model): + """OpenAI Responses API model provider implementation. + + Note: + This implementation currently only supports function tools (custom tools defined via tool_specs). + OpenAI's built-in system tools are not yet supported. + """ + + client: Client + client_args: dict[str, Any] + + class OpenAIResponsesConfig(TypedDict, total=False): + """Configuration options for OpenAI Responses API models. + + Attributes: + model_id: Model ID (e.g., "gpt-4o"). + For a complete list of supported models, see https://platform.openai.com/docs/models. + params: Model parameters (e.g., max_output_tokens, temperature, etc.). + For a complete list of supported parameters, see + https://platform.openai.com/docs/api-reference/responses/create. + """ + + model_id: str + params: dict[str, Any] | None + + def __init__( + self, client_args: dict[str, Any] | None = None, **model_config: Unpack[OpenAIResponsesConfig] + ) -> None: + """Initialize provider instance. + + Args: + client_args: Arguments for the OpenAI client. + For a complete list of supported arguments, see https://pypi.org/project/openai/. + **model_config: Configuration options for the OpenAI Responses API model. 
+ """ + validate_config_keys(model_config, self.OpenAIResponsesConfig) + self.config = dict(model_config) + self.client_args = client_args or {} + + logger.debug("config=<%s> | initializing", self.config) + + @override + def update_config(self, **model_config: Unpack[OpenAIResponsesConfig]) -> None: # type: ignore[override] + """Update the OpenAI Responses API model configuration with the provided arguments. + + Args: + **model_config: Configuration overrides. + """ + validate_config_keys(model_config, self.OpenAIResponsesConfig) + self.config.update(model_config) + + @override + def get_config(self) -> OpenAIResponsesConfig: + """Get the OpenAI Responses API model configuration. + + Returns: + The OpenAI Responses API model configuration. + """ + return cast(OpenAIResponsesModel.OpenAIResponsesConfig, self.config) + + @override + async def stream( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + *, + tool_choice: ToolChoice | None = None, + **kwargs: Any, + ) -> AsyncGenerator[StreamEvent, None]: + """Stream conversation with the OpenAI Responses API model. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Formatted message chunks from the model. + + Raises: + ContextWindowOverflowException: If the input exceeds the model's context window. + ModelThrottledException: If the request is throttled by OpenAI (rate limits). 
+ """ + logger.debug("formatting request for OpenAI Responses API") + request = self._format_request(messages, tool_specs, system_prompt, tool_choice) + logger.debug("formatted request=<%s>", request) + + logger.debug("invoking OpenAI Responses API model") + + async with openai.AsyncOpenAI(**self.client_args) as client: + try: + response = await client.responses.create(**request) + + logger.debug("streaming response from OpenAI Responses API model") + + yield self._format_chunk({"chunk_type": "message_start"}) + + tool_calls: dict[str, _ToolCallInfo] = {} + final_usage = None + data_type: str | None = None + stop_reason: str | None = None + + async for event in response: + if hasattr(event, "type"): + if event.type == "response.reasoning_text.delta": + # Reasoning content streaming (for o1/o3 reasoning models) + chunks, data_type = self._stream_switch_content("reasoning_content", data_type) + for chunk in chunks: + yield chunk + if hasattr(event, "delta") and isinstance(event.delta, str): + yield self._format_chunk( + { + "chunk_type": "content_delta", + "data_type": "reasoning_content", + "data": event.delta, + } + ) + + elif event.type == "response.output_text.delta": + # Text content streaming + chunks, data_type = self._stream_switch_content("text", data_type) + for chunk in chunks: + yield chunk + if hasattr(event, "delta") and isinstance(event.delta, str): + yield self._format_chunk( + {"chunk_type": "content_delta", "data_type": "text", "data": event.delta} + ) + + elif event.type == "response.output_item.added": + # Tool call started + if ( + hasattr(event, "item") + and hasattr(event.item, "type") + and event.item.type == "function_call" + ): + call_id = getattr(event.item, "call_id", "unknown") + tool_calls[call_id] = { + "name": getattr(event.item, "name", ""), + "arguments": "", + "call_id": call_id, + "item_id": getattr(event.item, "id", ""), + } + + elif event.type == "response.function_call_arguments.delta": + # Tool arguments streaming - accumulate 
deltas by item_id + if hasattr(event, "delta") and hasattr(event, "item_id"): + for _call_id, call_info in tool_calls.items(): + if call_info["item_id"] == event.item_id: + call_info["arguments"] += event.delta + break + + elif event.type == "response.function_call_arguments.done": + # Tool arguments complete - use final arguments as source of truth + if hasattr(event, "arguments") and hasattr(event, "item_id"): + for _call_id, call_info in tool_calls.items(): + if call_info["item_id"] == event.item_id: + call_info["arguments"] = event.arguments + break + + elif event.type == "response.incomplete": + # Response stopped early (e.g., max tokens reached) + if hasattr(event, "response"): + if hasattr(event.response, "usage"): + final_usage = event.response.usage + # Check if stopped due to max_output_tokens + if ( + hasattr(event.response, "incomplete_details") + and event.response.incomplete_details + and getattr(event.response.incomplete_details, "reason", None) + == "max_output_tokens" + ): + stop_reason = "length" + break + + elif event.type == "response.completed": + # Response complete + if hasattr(event, "response") and hasattr(event.response, "usage"): + final_usage = event.response.usage + break + except openai.BadRequestError as e: + if hasattr(e, "code") and e.code == "context_length_exceeded": + logger.warning(_CONTEXT_WINDOW_OVERFLOW_MSG) + raise ContextWindowOverflowException(str(e)) from e + raise + except openai.RateLimitError as e: + logger.warning(_RATE_LIMIT_MSG) + raise ModelThrottledException(str(e)) from e + + # Close current content block if we had any + if data_type: + yield self._format_chunk({"chunk_type": "content_stop", "data_type": data_type}) + + # Emit tool calls with complete arguments. + # We emit a single delta per tool containing the full arguments rather than streaming + # incremental argument deltas. 
The Responses API streams argument chunks via separate + # events (response.function_call_arguments.delta) which we accumulate above, then use + # the final arguments from response.function_call_arguments.done. This approach ensures + # we emit valid, complete JSON arguments rather than partial fragments. + for call_info in tool_calls.values(): + tool_call = SimpleNamespace( + function=SimpleNamespace(name=call_info["name"], arguments=call_info["arguments"]), + id=call_info["call_id"], + ) + + yield self._format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_call}) + yield self._format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_call}) + yield self._format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) + + # Determine finish reason: tool_calls > max_tokens (length) > normal stop + if tool_calls: + finish_reason = "tool_calls" + elif stop_reason == "length": + finish_reason = "length" + else: + finish_reason = "stop" + yield self._format_chunk({"chunk_type": "message_stop", "data": finish_reason}) + + if final_usage: + yield self._format_chunk({"chunk_type": "metadata", "data": final_usage}) + + logger.debug("finished streaming response from OpenAI Responses API model") + + @override + async def structured_output( + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: + """Get structured output from the OpenAI Responses API model. + + Args: + output_model: The output model to use for the agent. + prompt: The prompt messages to use for the agent. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Model events with the last being the structured output. + + Raises: + ContextWindowOverflowException: If the input exceeds the model's context window. 
+ ModelThrottledException: If the request is throttled by OpenAI (rate limits). + """ + async with openai.AsyncOpenAI(**self.client_args) as client: + try: + response = await client.responses.parse( + model=self.get_config()["model_id"], + input=self._format_request(prompt, system_prompt=system_prompt)["input"], + text_format=output_model, + ) + except openai.BadRequestError as e: + if hasattr(e, "code") and e.code == "context_length_exceeded": + logger.warning(_CONTEXT_WINDOW_OVERFLOW_MSG) + raise ContextWindowOverflowException(str(e)) from e + raise + except openai.RateLimitError as e: + logger.warning(_RATE_LIMIT_MSG) + raise ModelThrottledException(str(e)) from e + + if response.output_parsed: + yield {"output": response.output_parsed} + else: + raise ValueError("No valid parsed output found in the OpenAI Responses API response.") + + def _format_request( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + tool_choice: ToolChoice | None = None, + ) -> dict[str, Any]: + """Format an OpenAI Responses API compatible response streaming request. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. + + Returns: + An OpenAI Responses API compatible response streaming request. + + Raises: + TypeError: If a message contains a content block type that cannot be converted to an OpenAI-compatible + format. 
+ """ + input_items = self._format_request_messages(messages) + request = { + "model": self.config["model_id"], + "input": input_items, + "stream": True, + **cast(dict[str, Any], self.config.get("params", {})), + } + + if system_prompt: + request["instructions"] = system_prompt + + # Add tools if provided + if tool_specs: + request["tools"] = [ + { + "type": "function", + "name": tool_spec["name"], + "description": tool_spec.get("description", ""), + "parameters": tool_spec["inputSchema"]["json"], + } + for tool_spec in tool_specs + ] + # Add tool_choice if provided + request.update(self._format_request_tool_choice(tool_choice)) + + return request + + @classmethod + def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str, Any]: + """Format a tool choice for OpenAI Responses API compatibility. + + Args: + tool_choice: Tool choice configuration. + + Returns: + OpenAI Responses API compatible tool choice format. + """ + if not tool_choice: + return {} + + match tool_choice: + case {"auto": _}: + return {"tool_choice": "auto"} + case {"any": _}: + return {"tool_choice": "required"} + case {"tool": {"name": tool_name}}: + return {"tool_choice": {"type": "function", "name": tool_name}} + case _: + # Default to auto for unknown formats + return {"tool_choice": "auto"} + + @classmethod + def _format_request_messages(cls, messages: Messages) -> list[dict[str, Any]]: + """Format an OpenAI compatible messages array. + + Args: + messages: List of message objects to be processed by the model. + + Returns: + An OpenAI compatible messages array. 
+ """ + formatted_messages: list[dict[str, Any]] = [] + + for message in messages: + role = message["role"] + contents = message["content"] + + formatted_contents = [ + cls._format_request_message_content(content) + for content in contents + if not any(block_type in content for block_type in ["toolResult", "toolUse"]) + ] + + formatted_tool_calls = [ + cls._format_request_message_tool_call(content["toolUse"]) + for content in contents + if "toolUse" in content + ] + + formatted_tool_messages = [ + cls._format_request_tool_message(content["toolResult"]) + for content in contents + if "toolResult" in content + ] + + if formatted_contents: + formatted_messages.append( + { + "role": role, # "user" | "assistant" + "content": formatted_contents, + } + ) + + formatted_messages.extend(formatted_tool_calls) + formatted_messages.extend(formatted_tool_messages) + + return [ + message + for message in formatted_messages + if message.get("content") or message.get("type") in ["function_call", "function_call_output"] + ] + + @classmethod + def _format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + """Format an OpenAI compatible content block. + + Args: + content: Message content. + + Returns: + OpenAI compatible content block. + + Raises: + TypeError: If the content block type cannot be converted to an OpenAI-compatible format. + ValueError: If the image or document size exceeds the maximum allowed size (20MB). 
+ """ + if "document" in content: + doc = content["document"] + data_url = _encode_media_to_data_url(doc["source"]["bytes"], doc["format"], "document") + return {"type": "input_file", "file_url": data_url} + + if "image" in content: + img = content["image"] + data_url = _encode_media_to_data_url(img["source"]["bytes"], img["format"], "image") + return {"type": "input_image", "image_url": data_url} + + if "text" in content: + return {"type": "input_text", "text": content["text"]} + + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + + @classmethod + def _format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: + """Format an OpenAI compatible tool call. + + Args: + tool_use: Tool use requested by the model. + + Returns: + OpenAI compatible tool call. + """ + return { + "type": "function_call", + "call_id": tool_use["toolUseId"], + "name": tool_use["name"], + "arguments": json.dumps(tool_use["input"]), + } + + @classmethod + def _format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: + """Format an OpenAI compatible tool message. + + Args: + tool_result: Tool result collected from a tool execution. + + Returns: + OpenAI compatible tool message. + + Raises: + ValueError: If the image or document size exceeds the maximum allowed size (20MB). + + Note: + The Responses API's function_call_output can be either a string (typically JSON encoded) + or an array of content objects when returning images/files. 
+ See: https://platform.openai.com/docs/guides/function-calling + """ + output_parts: list[dict[str, Any]] = [] + has_media = False + + for content in tool_result["content"]: + if "json" in content: + output_parts.append({"type": "input_text", "text": json.dumps(content["json"])}) + elif "text" in content: + output_parts.append({"type": "input_text", "text": content["text"]}) + elif "image" in content: + has_media = True + img = content["image"] + data_url = _encode_media_to_data_url(img["source"]["bytes"], img["format"], "image") + output_parts.append({"type": "input_image", "image_url": data_url}) + elif "document" in content: + has_media = True + doc = content["document"] + data_url = _encode_media_to_data_url(doc["source"]["bytes"], doc["format"], "document") + output_parts.append({"type": "input_file", "file_url": data_url}) + + # Return array if has media content, otherwise join as string for simpler text-only cases + output: list[dict[str, Any]] | str + if has_media: + output = output_parts + else: + output = "\n".join(part.get("text", "") for part in output_parts) if output_parts else "" + + return { + "type": "function_call_output", + "call_id": tool_result["toolUseId"], + "output": output, + } + + def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> tuple[list[StreamEvent], str]: + """Handle switching to a new content stream. + + Args: + data_type: The next content data type. + prev_data_type: The previous content data type. + + Returns: + Tuple containing: + - Stop block for previous content and the start block for the next content. + - Next content data type. 
+ """ + chunks: list[StreamEvent] = [] + if data_type != prev_data_type: + if prev_data_type is not None: + chunks.append(self._format_chunk({"chunk_type": "content_stop", "data_type": prev_data_type})) + chunks.append(self._format_chunk({"chunk_type": "content_start", "data_type": data_type})) + + return chunks, data_type + + def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format an OpenAI response event into a standardized message chunk. + + Args: + event: A response event from the OpenAI compatible model. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. + This error should never be encountered as chunk_type is controlled in the stream method. + """ + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_start": + if event["data_type"] == "tool": + return { + "contentBlockStart": { + "start": { + "toolUse": { + "name": event["data"].function.name, + "toolUseId": event["data"].id, + } + } + } + } + + return {"contentBlockStart": {"start": {}}} + + case "content_delta": + if event["data_type"] == "tool": + return { + "contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}} + } + + if event["data_type"] == "reasoning_content": + return {"contentBlockDelta": {"delta": {"reasoningContent": {"text": event["data"]}}}} + + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + + case "content_stop": + return {"contentBlockStop": {}} + + case "message_stop": + match event["data"]: + case "tool_calls": + return {"messageStop": {"stopReason": "tool_use"}} + case "length": + return {"messageStop": {"stopReason": "max_tokens"}} + case _: + return {"messageStop": {"stopReason": "end_turn"}} + + case "metadata": + # Responses API uses input_tokens/output_tokens naming convention + return { + "metadata": { + "usage": { + "inputTokens": getattr(event["data"], "input_tokens", 0), + "outputTokens": 
getattr(event["data"], "output_tokens", 0), + "totalTokens": getattr(event["data"], "total_tokens", 0), + }, + "metrics": { + "latencyMs": 0, # TODO + }, + }, + } + + case _: + raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type") diff --git a/tests/strands/models/conftest.py b/tests/strands/models/conftest.py new file mode 100644 index 000000000..aaf01a047 --- /dev/null +++ b/tests/strands/models/conftest.py @@ -0,0 +1,25 @@ +"""Pytest configuration for model tests.""" + +import sys +import unittest.mock + +# Mock OpenAI version check before the openai_responses module is imported. +# This is necessary because the version check happens at module import time. +# We patch importlib.metadata.version directly since that's where get_package_version comes from. +if "strands.models.openai_responses" not in sys.modules: + _original_version = None + try: + from importlib.metadata import version as _original_version_func + + _original_version = _original_version_func + except ImportError: + pass + + def _mock_version(package_name: str) -> str: + if package_name == "openai": + return "2.0.0" + if _original_version: + return _original_version(package_name) + raise Exception(f"Package {package_name} not found") + + unittest.mock.patch("importlib.metadata.version", _mock_version).start() diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py new file mode 100644 index 000000000..7b09f1b68 --- /dev/null +++ b/tests/strands/models/test_openai_responses.py @@ -0,0 +1,885 @@ +import unittest.mock + +import openai +import pydantic +import pytest + +import strands +from strands.models.openai_responses import _MAX_MEDIA_SIZE_BYTES, OpenAIResponsesModel +from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException + + +@pytest.fixture +def openai_client(): + with unittest.mock.patch.object(strands.models.openai_responses.openai, "AsyncOpenAI") as mock_client_cls: + mock_client = 
unittest.mock.AsyncMock() + mock_client_cls.return_value.__aenter__.return_value = mock_client + yield mock_client + + +@pytest.fixture +def model_id(): + return "gpt-4o" + + +@pytest.fixture +def model(openai_client, model_id): + _ = openai_client + return OpenAIResponsesModel(model_id=model_id, params={"max_output_tokens": 100}) + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "test"}]}] + + +@pytest.fixture +def tool_specs(): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + }, + }, + }, + ] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +@pytest.fixture +def test_output_model_cls(): + class TestOutputModel(pydantic.BaseModel): + name: str + age: int + + return TestOutputModel + + +def test__init__(model_id): + model = OpenAIResponsesModel(model_id=model_id, params={"max_output_tokens": 100}) + + tru_config = model.get_config() + exp_config = {"model_id": "gpt-4o", "params": {"max_output_tokens": 100}} + + assert tru_config == exp_config + + +def test_update_config(model, model_id): + model.update_config(model_id=model_id) + + tru_model_id = model.get_config().get("model_id") + exp_model_id = model_id + + assert tru_model_id == exp_model_id + + +@pytest.mark.parametrize( + "content, exp_result", + [ + # Document + ( + { + "document": { + "format": "pdf", + "name": "test doc", + "source": {"bytes": b"document"}, + }, + }, + { + "type": "input_file", + "file_url": "data:application/pdf;base64,ZG9jdW1lbnQ=", + }, + ), + # Image + ( + { + "image": { + "format": "jpg", + "source": {"bytes": b"image"}, + }, + }, + { + "type": "input_image", + "image_url": "data:image/jpeg;base64,aW1hZ2U=", + }, + ), + # Text + ( + {"text": "hello"}, + {"type": "input_text", "text": "hello"}, + ), + ], +) +def test_format_request_message_content(content, exp_result): + tru_result = 
OpenAIResponsesModel._format_request_message_content(content) + assert tru_result == exp_result + + +def test_format_request_message_content_unsupported_type(): + content = {"unsupported": {}} + + with pytest.raises(TypeError, match="content_type= | unsupported type"): + OpenAIResponsesModel._format_request_message_content(content) + + +def test_format_request_message_tool_call(): + tool_use = { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + } + + tru_result = OpenAIResponsesModel._format_request_message_tool_call(tool_use) + exp_result = { + "type": "function_call", + "call_id": "c1", + "name": "calculator", + "arguments": '{"expression": "2+2"}', + } + assert tru_result == exp_result + + +def test_format_request_tool_message(): + tool_result = { + "content": [{"text": "4"}, {"json": ["4"]}], + "status": "success", + "toolUseId": "c1", + } + + tru_result = OpenAIResponsesModel._format_request_tool_message(tool_result) + exp_result = { + "type": "function_call_output", + "call_id": "c1", + "output": '4\n["4"]', + } + assert tru_result == exp_result + + +def test_format_request_tool_message_with_image(): + """Test that tool results with images return an array output.""" + tool_result = { + "content": [ + {"text": "Here is the image:"}, + {"image": {"format": "png", "source": {"bytes": b"fake_image_data"}}}, + ], + "status": "success", + "toolUseId": "c2", + } + + tru_result = OpenAIResponsesModel._format_request_tool_message(tool_result) + + assert tru_result["type"] == "function_call_output" + assert tru_result["call_id"] == "c2" + # When images are present, output should be an array + assert isinstance(tru_result["output"], list) + assert len(tru_result["output"]) == 2 + assert tru_result["output"][0]["type"] == "input_text" + assert tru_result["output"][0]["text"] == "Here is the image:" + assert tru_result["output"][1]["type"] == "input_image" + assert "image_url" in tru_result["output"][1] + + +def 
test_format_request_tool_message_with_document(): + """Test that tool results with documents return an array output.""" + tool_result = { + "content": [ + {"document": {"format": "pdf", "name": "test.pdf", "source": {"bytes": b"fake_pdf_data"}}}, + ], + "status": "success", + "toolUseId": "c3", + } + + tru_result = OpenAIResponsesModel._format_request_tool_message(tool_result) + + assert tru_result["type"] == "function_call_output" + assert tru_result["call_id"] == "c3" + # When documents are present, output should be an array + assert isinstance(tru_result["output"], list) + assert len(tru_result["output"]) == 1 + assert tru_result["output"][0]["type"] == "input_file" + assert "file_url" in tru_result["output"][0] + + +def test_format_request_messages(system_prompt): + messages = [ + { + "content": [], + "role": "user", + }, + { + "content": [{"text": "hello"}], + "role": "user", + }, + { + "content": [ + {"text": "call tool"}, + { + "toolUse": { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + }, + }, + ], + "role": "assistant", + }, + { + "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"text": "4"}]}}], + "role": "user", + }, + ] + + tru_result = OpenAIResponsesModel._format_request_messages(messages) + exp_result = [ + { + "role": "user", + "content": [{"type": "input_text", "text": "hello"}], + }, + { + "role": "assistant", + "content": [{"type": "input_text", "text": "call tool"}], + }, + { + "type": "function_call", + "call_id": "c1", + "name": "calculator", + "arguments": '{"expression": "2+2"}', + }, + { + "type": "function_call_output", + "call_id": "c1", + "output": "4", + }, + ] + assert tru_result == exp_result + + +def test_format_request(model, messages, tool_specs, system_prompt): + tru_request = model._format_request(messages, tool_specs, system_prompt) + exp_request = { + "model": "gpt-4o", + "input": [ + { + "role": "user", + "content": [{"type": "input_text", "text": "test"}], + } 
+ ], + "stream": True, + "instructions": system_prompt, + "tools": [ + { + "type": "function", + "name": "test_tool", + "description": "A test tool", + "parameters": { + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + }, + }, + ], + "max_output_tokens": 100, + } + assert tru_request == exp_request + + +@pytest.mark.parametrize( + ("event", "exp_chunk"), + [ + # Message start + ( + {"chunk_type": "message_start"}, + {"messageStart": {"role": "assistant"}}, + ), + # Content Start - Tool Use + ( + { + "chunk_type": "content_start", + "data_type": "tool", + "data": unittest.mock.Mock(**{"function.name": "calculator", "id": "c1"}), + }, + {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}}, + ), + # Content Start - Text + ( + {"chunk_type": "content_start", "data_type": "text"}, + {"contentBlockStart": {"start": {}}}, + ), + # Content Delta - Tool Use + ( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments='{"expression": "2+2"}')), + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, + ), + # Content Delta - Tool Use - None + ( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments=None)), + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}}}, + ), + # Content Delta - Reasoning Text + ( + {"chunk_type": "content_delta", "data_type": "reasoning_content", "data": "I'm thinking"}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "I'm thinking"}}}}, + ), + # Content Delta - Text + ( + {"chunk_type": "content_delta", "data_type": "text", "data": "hello"}, + {"contentBlockDelta": {"delta": {"text": "hello"}}}, + ), + # Content Stop + ( + {"chunk_type": "content_stop"}, + {"contentBlockStop": {}}, + ), + # Message Stop - Tool Use + ( + {"chunk_type": "message_stop", "data": 
"tool_calls"}, + {"messageStop": {"stopReason": "tool_use"}}, + ), + # Message Stop - Max Tokens + ( + {"chunk_type": "message_stop", "data": "length"}, + {"messageStop": {"stopReason": "max_tokens"}}, + ), + # Message Stop - End Turn + ( + {"chunk_type": "message_stop", "data": "stop"}, + {"messageStop": {"stopReason": "end_turn"}}, + ), + # Metadata + ( + { + "chunk_type": "metadata", + "data": unittest.mock.Mock(input_tokens=100, output_tokens=50, total_tokens=150), + }, + { + "metadata": { + "usage": { + "inputTokens": 100, + "outputTokens": 50, + "totalTokens": 150, + }, + "metrics": { + "latencyMs": 0, + }, + }, + }, + ), + ], +) +def test_format_chunk(event, exp_chunk, model): + tru_chunk = model._format_chunk(event) + assert tru_chunk == exp_chunk + + +def test_format_chunk_unknown_type(model): + event = {"chunk_type": "unknown"} + + with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): + model._format_chunk(event) + + +@pytest.mark.asyncio +async def test_stream(openai_client, model_id, model, agenerator, alist): + # Mock response events + mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Hello") + mock_complete_event = unittest.mock.Mock( + type="response.completed", + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)), + ) + + openai_client.responses.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_text_event, mock_complete_event]) + ) + + messages = [{"role": "user", "content": [{"text": "test"}]}] + response = model.stream(messages) + tru_events = await alist(response) + + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "Hello"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + { + "metadata": { + "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}, + "metrics": {"latencyMs": 0}, + } + }, + ] + + assert 
len(tru_events) == len(exp_events) + expected_request = { + "model": model_id, + "input": [{"role": "user", "content": [{"type": "input_text", "text": "test"}]}], + "stream": True, + "max_output_tokens": 100, + } + openai_client.responses.create.assert_called_once_with(**expected_request) + + +@pytest.mark.asyncio +async def test_stream_with_tool_calls(openai_client, model, agenerator, alist): + # Mock tool call events + mock_tool_event = unittest.mock.Mock( + type="response.output_item.added", + item=unittest.mock.Mock(type="function_call", call_id="call_123", name="calculator", id="item_456"), + ) + mock_args_event = unittest.mock.Mock( + type="response.function_call_arguments.delta", delta='{"expression": "2+2"}', item_id="item_456" + ) + mock_complete_event = unittest.mock.Mock( + type="response.completed", + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)), + ) + + openai_client.responses.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_tool_event, mock_args_event, mock_complete_event]) + ) + + messages = [{"role": "user", "content": [{"text": "calculate 2+2"}]}] + response = model.stream(messages) + tru_events = await alist(response) + + # Should include tool call events + assert any("toolUse" in str(event) for event in tru_events) + assert {"messageStop": {"stopReason": "tool_use"}} in tru_events + + +@pytest.mark.asyncio +async def test_stream_with_tool_calls_done_event(openai_client, model, agenerator, alist): + """Test that response.function_call_arguments.done overwrites accumulated deltas.""" + mock_tool_event = unittest.mock.Mock( + type="response.output_item.added", + item=unittest.mock.Mock(type="function_call", call_id="call_1", name="calculator", id="item_1"), + ) + # Simulate partial delta that would produce incomplete JSON + mock_args_delta = unittest.mock.Mock( + type="response.function_call_arguments.delta", delta='{"expr', item_id="item_1" + ) + # The done event provides 
the complete, correct arguments + mock_args_done = unittest.mock.Mock( + type="response.function_call_arguments.done", arguments='{"expression": "2+2"}', item_id="item_1" + ) + mock_complete_event = unittest.mock.Mock( + type="response.completed", + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)), + ) + + openai_client.responses.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_tool_event, mock_args_delta, mock_args_done, mock_complete_event]) + ) + + messages = [{"role": "user", "content": [{"text": "calculate 2+2"}]}] + tru_events = await alist(model.stream(messages)) + + # Find the tool use delta event and verify it has the final (done) arguments, not the partial delta + tool_deltas = [e for e in tru_events if "contentBlockDelta" in e and "toolUse" in e["contentBlockDelta"]["delta"]] + assert len(tool_deltas) == 1 + assert tool_deltas[0]["contentBlockDelta"]["delta"]["toolUse"]["input"] == '{"expression": "2+2"}' + + +@pytest.mark.asyncio +async def test_stream_response_incomplete(openai_client, model, agenerator, alist): + """Test that response.incomplete sets stop_reason to length when max_output_tokens is reached.""" + mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Truncated resp") + mock_incomplete_event = unittest.mock.Mock( + type="response.incomplete", + response=unittest.mock.Mock( + usage=unittest.mock.Mock(input_tokens=10, output_tokens=100, total_tokens=110), + incomplete_details=unittest.mock.Mock(reason="max_output_tokens"), + ), + ) + + openai_client.responses.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_text_event, mock_incomplete_event]) + ) + + messages = [{"role": "user", "content": [{"text": "write a long essay"}]}] + tru_events = await alist(model.stream(messages)) + + assert {"messageStop": {"stopReason": "max_tokens"}} in tru_events + # Verify usage was still captured + metadata_events = [e for e in tru_events if 
"metadata" in e] + assert len(metadata_events) == 1 + assert metadata_events[0]["metadata"]["usage"]["inputTokens"] == 10 + assert metadata_events[0]["metadata"]["usage"]["outputTokens"] == 100 + + +@pytest.mark.asyncio +async def test_stream_reasoning_content(openai_client, model, agenerator, alist): + """Test that reasoning content (o1/o3 models) is streamed correctly.""" + mock_reasoning_event = unittest.mock.Mock(type="response.reasoning_text.delta", delta="Let me think...") + mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="The answer is 42") + mock_complete_event = unittest.mock.Mock( + type="response.completed", + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=20, total_tokens=30)), + ) + + openai_client.responses.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_reasoning_event, mock_text_event, mock_complete_event]) + ) + + messages = [{"role": "user", "content": [{"text": "think step by step"}]}] + tru_events = await alist(model.stream(messages)) + + # Verify reasoning content block was emitted + reasoning_deltas = [ + e for e in tru_events if "contentBlockDelta" in e and "reasoningContent" in e["contentBlockDelta"]["delta"] + ] + assert len(reasoning_deltas) == 1 + assert reasoning_deltas[0]["contentBlockDelta"]["delta"]["reasoningContent"]["text"] == "Let me think..." 
+ + # Verify text content block was also emitted + text_deltas = [e for e in tru_events if "contentBlockDelta" in e and "text" in e["contentBlockDelta"]["delta"]] + assert len(text_deltas) == 1 + assert text_deltas[0]["contentBlockDelta"]["delta"]["text"] == "The answer is 42" + + # Verify content blocks were properly opened and closed (reasoning start/stop, then text start/stop) + content_starts = [e for e in tru_events if "contentBlockStart" in e] + content_stops = [e for e in tru_events if "contentBlockStop" in e] + assert len(content_starts) == 2 # one for reasoning, one for text + assert len(content_stops) == 2 + + +@pytest.mark.asyncio +async def test_structured_output(openai_client, model, test_output_model_cls, alist): + messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + + mock_parsed_instance = test_output_model_cls(name="John", age=30) + mock_response = unittest.mock.Mock(output_parsed=mock_parsed_instance) + + openai_client.responses.parse = unittest.mock.AsyncMock(return_value=mock_response) + + stream = model.structured_output(test_output_model_cls, messages) + events = await alist(stream) + + tru_result = events[-1] + exp_result = {"output": test_output_model_cls(name="John", age=30)} + assert tru_result == exp_result + + +@pytest.mark.asyncio +async def test_stream_context_overflow_exception(openai_client, model, messages): + """Test that OpenAI context overflow errors are properly converted to ContextWindowOverflowException.""" + mock_error = openai.BadRequestError( + message="This model's maximum context length is 4096 tokens.", + response=unittest.mock.MagicMock(), + body={"error": {"code": "context_length_exceeded"}}, + ) + mock_error.code = "context_length_exceeded" + + openai_client.responses.create.side_effect = mock_error + + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "maximum context length" in str(exc_info.value) + assert 
exc_info.value.__cause__ == mock_error + + +@pytest.mark.asyncio +async def test_stream_rate_limit_as_throttle(openai_client, model, messages): + """Test that rate limit errors are converted to ModelThrottledException.""" + mock_error = openai.RateLimitError( + message="Rate limit exceeded", + response=unittest.mock.MagicMock(), + body={"error": {"code": "rate_limit_exceeded"}}, + ) + mock_error.code = "rate_limit_exceeded" + + openai_client.responses.create.side_effect = mock_error + + with pytest.raises(ModelThrottledException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "Rate limit exceeded" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_stream_bad_request_non_context_overflow(openai_client, model, messages): + """Test that non-context-overflow BadRequestErrors are re-raised.""" + mock_error = openai.BadRequestError( + message="Invalid request format", + response=unittest.mock.MagicMock(), + body={"error": {"code": "invalid_request"}}, + ) + mock_error.code = "invalid_request" + + openai_client.responses.create.side_effect = mock_error + + with pytest.raises(openai.BadRequestError) as exc_info: + async for _ in model.stream(messages): + pass + + assert exc_info.value == mock_error + + +@pytest.mark.asyncio +async def test_stream_error_during_iteration(openai_client, model, messages, agenerator): + """Test that errors during streaming iteration are properly handled.""" + mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Hello") + + async def error_generator(): + yield mock_text_event + raise openai.RateLimitError( + message="Rate limit during stream", + response=unittest.mock.MagicMock(), + body={"error": {"code": "rate_limit_exceeded"}}, + ) + + openai_client.responses.create = unittest.mock.AsyncMock(return_value=error_generator()) + + with pytest.raises(ModelThrottledException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "Rate limit during stream" in 
str(exc_info.value) + + +@pytest.mark.asyncio +async def test_stream_context_overflow_during_iteration(openai_client, model, messages): + """Test that context overflow during streaming iteration is properly handled.""" + mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Hello") + + async def error_generator(): + yield mock_text_event + error = openai.BadRequestError( + message="Context length exceeded during stream", + response=unittest.mock.MagicMock(), + body={"error": {"code": "context_length_exceeded"}}, + ) + error.code = "context_length_exceeded" + raise error + + openai_client.responses.create = unittest.mock.AsyncMock(return_value=error_generator()) + + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "Context length exceeded" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_structured_output_context_overflow_exception(openai_client, model, messages, test_output_model_cls): + """Test that structured output handles context overflow properly.""" + mock_error = openai.BadRequestError( + message="This model's maximum context length is 4096 tokens.", + response=unittest.mock.MagicMock(), + body={"error": {"code": "context_length_exceeded"}}, + ) + mock_error.code = "context_length_exceeded" + + openai_client.responses.parse.side_effect = mock_error + + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.structured_output(test_output_model_cls, messages): + pass + + assert "maximum context length" in str(exc_info.value) + assert exc_info.value.__cause__ == mock_error + + +@pytest.mark.asyncio +async def test_structured_output_rate_limit_as_throttle(openai_client, model, messages, test_output_model_cls): + """Test that structured output handles rate limit errors properly.""" + mock_error = openai.RateLimitError( + message="Rate limit exceeded", + response=unittest.mock.MagicMock(), + body={"error": {"code": 
"rate_limit_exceeded"}}, + ) + mock_error.code = "rate_limit_exceeded" + + openai_client.responses.parse.side_effect = mock_error + + with pytest.raises(ModelThrottledException) as exc_info: + async for _ in model.structured_output(test_output_model_cls, messages): + pass + + assert "Rate limit exceeded" in str(exc_info.value) + assert exc_info.value.__cause__ == mock_error + + +@pytest.mark.asyncio +async def test_structured_output_bad_request_non_context_overflow( + openai_client, model, messages, test_output_model_cls +): + """Test that structured output re-raises non-context-overflow BadRequestErrors.""" + mock_error = openai.BadRequestError( + message="Invalid request format", + response=unittest.mock.MagicMock(), + body={"error": {"code": "invalid_request"}}, + ) + mock_error.code = "invalid_request" + + openai_client.responses.parse.side_effect = mock_error + + with pytest.raises(openai.BadRequestError) as exc_info: + async for _ in model.structured_output(test_output_model_cls, messages): + pass + + assert exc_info.value == mock_error + + +@pytest.mark.asyncio +async def test_structured_output_no_parsed_output(openai_client, model, messages, test_output_model_cls, alist): + """Test that structured output raises ValueError when output_parsed is None.""" + mock_response = unittest.mock.Mock(output_parsed=None) + openai_client.responses.parse = unittest.mock.AsyncMock(return_value=mock_response) + + with pytest.raises(ValueError, match="No valid parsed output"): + await alist(model.structured_output(test_output_model_cls, messages)) + + +@pytest.mark.asyncio +async def test_stream_with_empty_tool_result_content(model): + """Test formatting tool result with empty content list.""" + tool_result = { + "content": [], + "status": "success", + "toolUseId": "c1", + } + + result = OpenAIResponsesModel._format_request_tool_message(tool_result) + assert result["output"] == "" + + +def test_config_validation_warns_on_unknown_keys(openai_client, captured_warnings): + 
"""Test that unknown config keys emit a warning.""" + OpenAIResponsesModel({"api_key": "test"}, model_id="test-model", invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) + + +@pytest.mark.parametrize( + ("tool_choice", "expected"), + [ + (None, {}), + ({"auto": {}}, {"tool_choice": "auto"}), + ({"any": {}}, {"tool_choice": "required"}), + ({"tool": {"name": "calculator"}}, {"tool_choice": {"type": "function", "name": "calculator"}}), + ({"unknown": {}}, {"tool_choice": "auto"}), # Test default fallback + ], +) +def test_format_request_tool_choice(tool_choice, expected): + """Test that tool_choice is properly formatted for the Responses API.""" + result = OpenAIResponsesModel._format_request_tool_choice(tool_choice) + assert result == expected + + +def test_format_request_with_tool_choice(model, messages, tool_specs): + """Test that tool_choice is properly included in the request.""" + tool_choice = {"tool": {"name": "test_tool"}} + request = model._format_request(messages, tool_specs, tool_choice=tool_choice) + + assert "tool_choice" in request + assert request["tool_choice"] == {"type": "function", "name": "test_tool"} + + +def test_format_request_message_content_image_size_limit(): + """Test that oversized images raise ValueError.""" + oversized_data = b"x" * (_MAX_MEDIA_SIZE_BYTES + 1) + content = {"image": {"format": "png", "source": {"bytes": oversized_data}}} + + with pytest.raises(ValueError, match="Image size .* exceeds maximum"): + 
OpenAIResponsesModel._format_request_message_content(content) + + +def test_format_request_message_content_document_size_limit(): + """Test that oversized documents raise ValueError.""" + oversized_data = b"x" * (_MAX_MEDIA_SIZE_BYTES + 1) + content = {"document": {"format": "pdf", "name": "large.pdf", "source": {"bytes": oversized_data}}} + + with pytest.raises(ValueError, match="Document size .* exceeds maximum"): + OpenAIResponsesModel._format_request_message_content(content) + + +def test_format_request_tool_message_image_size_limit(): + """Test that oversized images in tool results raise ValueError.""" + oversized_data = b"x" * (_MAX_MEDIA_SIZE_BYTES + 1) + tool_result = { + "content": [{"image": {"format": "png", "source": {"bytes": oversized_data}}}], + "status": "success", + "toolUseId": "c1", + } + + with pytest.raises(ValueError, match="Image size .* exceeds maximum"): + OpenAIResponsesModel._format_request_tool_message(tool_result) + + +def test_format_request_tool_message_document_size_limit(): + """Test that oversized documents in tool results raise ValueError.""" + oversized_data = b"x" * (_MAX_MEDIA_SIZE_BYTES + 1) + tool_result = { + "content": [{"document": {"format": "pdf", "name": "large.pdf", "source": {"bytes": oversized_data}}}], + "status": "success", + "toolUseId": "c1", + } + + with pytest.raises(ValueError, match="Document size .* exceeds maximum"): + OpenAIResponsesModel._format_request_tool_message(tool_result) + + +def test_openai_version_check(): + """Test that module import fails with old OpenAI SDK version.""" + import importlib + + import strands.models.openai_responses as openai_responses_module + + def mock_old_version(package_name: str) -> str: + if package_name == "openai": + return "1.99.0" + from importlib.metadata import version + + return version(package_name) + + def mock_valid_version(package_name: str) -> str: + if package_name == "openai": + return "2.0.0" + from importlib.metadata import version + + return 
version(package_name) + + with unittest.mock.patch("importlib.metadata.version", mock_old_version): + with pytest.raises(ImportError, match="OpenAIResponsesModel requires openai>=2.0.0"): + importlib.reload(openai_responses_module) + + # Reload with valid version to restore module state + with unittest.mock.patch("importlib.metadata.version", mock_valid_version): + importlib.reload(openai_responses_module) diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index ab8551391..15161b9cb 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -18,6 +18,13 @@ from strands.models.openai import OpenAIModel from strands.models.writer import WriterModel +try: + from strands.models.openai_responses import OpenAIResponsesModel + + _openai_responses_available = True +except ImportError: + _openai_responses_available = False + class ProviderInfo: """Provider-based info for providers that require an APIKey via environment variables.""" @@ -118,6 +125,19 @@ def __init__(self): }, ), ) +if _openai_responses_available: + openai_responses = ProviderInfo( + id="openai_responses", + environment_variable="OPENAI_API_KEY", + factory=lambda: OpenAIResponsesModel( + model_id="gpt-4o", + client_args={ + "api_key": os.getenv("OPENAI_API_KEY"), + }, + ), + ) +else: + openai_responses = None writer = ProviderInfo( id="writer", environment_variable="WRITER_API_KEY", @@ -141,13 +161,18 @@ def __init__(self): all_providers = [ - bedrock, - anthropic, - cohere, - gemini, - llama, - litellm, - mistral, - openai, - writer, + provider + for provider in [ + bedrock, + anthropic, + cohere, + gemini, + llama, + litellm, + mistral, + openai, + openai_responses, + writer, + ] + if provider is not None ] diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py index d31ef3333..bccf2d82b 100644 --- a/tests_integ/models/test_model_openai.py +++ b/tests_integ/models/test_model_openai.py @@ -9,15 +9,27 @@ from 
strands.models.openai import OpenAIModel from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException from tests_integ.models import providers +from tests_integ.models.providers import _openai_responses_available + +if _openai_responses_available: + from strands.models.openai_responses import OpenAIResponsesModel # these tests only run if we have the openai api key pytestmark = providers.openai.mark -@pytest.fixture -def model(): - return OpenAIModel( - model_id="gpt-4o", +def _model_params(): + params = [(OpenAIModel, "gpt-4o")] + if _openai_responses_available: + params.append((OpenAIResponsesModel, "gpt-4o")) + return params + + +@pytest.fixture(params=_model_params()) +def model(request): + model_class, model_id = request.param + return model_class( + model_id=model_id, client_args={ "api_key": os.getenv("OPENAI_API_KEY"), }, @@ -73,7 +85,7 @@ def test_image_path(request): return request.config.rootpath / "tests_integ" / "test_image.png" -def test_agent_invoke(agent): +def test_agent_invoke(agent, model): result = agent("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() @@ -81,7 +93,7 @@ def test_agent_invoke(agent): @pytest.mark.asyncio -async def test_agent_invoke_async(agent): +async def test_agent_invoke_async(agent, model): result = await agent.invoke_async("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() @@ -89,7 +101,7 @@ async def test_agent_invoke_async(agent): @pytest.mark.asyncio -async def test_agent_stream_async(agent): +async def test_agent_stream_async(agent, model): stream = agent.stream_async("What is the time and weather in New York?") async for event in stream: _ = event @@ -170,15 +182,23 @@ def tool_with_image_return(): agent("Run the the tool and analyze the image") -def test_context_window_overflow_integration(): +def _mini_model_params(): + params = [(OpenAIModel, "gpt-4o-mini-2024-07-18")] + if 
_openai_responses_available: + params.append((OpenAIResponsesModel, "gpt-4o-mini-2024-07-18")) + return params + + +@pytest.mark.parametrize("model_class,model_id", _mini_model_params()) +def test_context_window_overflow_integration(model_class, model_id): """Integration test for context window overflow with OpenAI. This test verifies that when a request exceeds the model's context window, the OpenAI model properly raises a ContextWindowOverflowException. """ # Use gpt-4o-mini which has a smaller context window to make this test more reliable - mini_model = OpenAIModel( - model_id="gpt-4o-mini-2024-07-18", + mini_model = model_class( + model_id=model_id, client_args={ "api_key": os.getenv("OPENAI_API_KEY"), }, @@ -198,14 +218,27 @@ def test_context_window_overflow_integration(): agent(long_text) -def test_rate_limit_throttling_integration_no_retries(model): +def _rate_limit_params(): + params = [(OpenAIModel, "gpt-4o")] + if _openai_responses_available: + params.append((OpenAIResponsesModel, "gpt-4o")) + return params + + +@pytest.mark.parametrize("model_class,model_id", _rate_limit_params()) +def test_rate_limit_throttling_integration_no_retries(model_class, model_id): """Integration test for rate limit handling with retries disabled. This test verifies that when a request exceeds OpenAI's rate limits, the model properly raises a ModelThrottledException. We disable retries to avoid waiting for the exponential backoff during testing. 
""" - # Patch the event loop constants to disable retries for this test + model = model_class( + model_id=model_id, + client_args={ + "api_key": os.getenv("OPENAI_API_KEY"), + }, + ) agent = Agent(model=model, retry_strategy=ModelRetryStrategy(max_attempts=1)) # Create a message that's very long to trigger token-per-minute rate limits From 32caa8939f493de4d9144c214f0d54206de8b080 Mon Sep 17 00:00:00 2001 From: Kihyeon Myung <51226101+kevmyung@users.noreply.github.com> Date: Thu, 5 Mar 2026 08:29:26 -0700 Subject: [PATCH 164/279] feat: add "anthropic" cache strategy to bypass model ID check (#1808) --- src/strands/models/bedrock.py | 27 +++++++++------ src/strands/models/model.py | 5 +-- tests/strands/models/test_bedrock.py | 50 ++++++++++++++++++++++++---- 3 files changed, 62 insertions(+), 20 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 4a48d7229..3fa907995 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -178,13 +178,15 @@ def __init__( logger.debug("region=<%s> | bedrock client created", self.client.meta.region_name) @property - def _supports_caching(self) -> bool: - """Whether this model supports prompt caching. + def _cache_strategy(self) -> str | None: + """The cache strategy for this model based on its model ID. - Returns True for Claude models on Bedrock. + Returns the appropriate cache strategy name, or None if automatic caching is not supported for this model. 
""" model_id = self.config.get("model_id", "").lower() - return "claude" in model_id or "anthropic" in model_id + if "claude" in model_id or "anthropic" in model_id: + return "anthropic" + return None @override def update_config(self, **model_config: Unpack[BedrockConfig]) -> None: # type: ignore @@ -459,14 +461,17 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: # Inject cache point into cleaned_messages (not original messages) if cache_config is set cache_config = self.config.get("cache_config") - if cache_config and cache_config.strategy == "auto": - if self._supports_caching: + if cache_config: + strategy: str | None = cache_config.strategy + if strategy == "auto": + strategy = self._cache_strategy + if not strategy: + logger.warning( + "model_id=<%s> | cache_config is enabled but this model does not support automatic caching", + self.config.get("model_id"), + ) + if strategy == "anthropic": self._inject_cache_point(cleaned_messages) - else: - logger.warning( - "model_id=<%s> | cache_config is enabled but this model does not support caching", - self.config.get("model_id"), - ) return cleaned_messages diff --git a/src/strands/models/model.py b/src/strands/models/model.py index 550ee22e9..9d83a72eb 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -23,10 +23,11 @@ class CacheConfig: Attributes: strategy: Caching strategy to use. 
- - "auto": Automatically inject cachePoint at optimal positions + - "auto": Automatically detect model support and inject cachePoint to maximize cache coverage + - "anthropic": Inject cachePoint in Anthropic-compatible format without model support check """ - strategy: Literal["auto"] = "auto" + strategy: Literal["auto", "anthropic"] = "auto" class Model(abc.ABC): diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 9dae16be7..66fe8ab00 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2582,19 +2582,19 @@ async def test_format_request_with_guardrail_multiple_tool_results_same_message( assert formatted_messages[0]["content"][0]["guardContent"]["text"]["text"] == "Question requiring multiple tools" -def test_supports_caching_true_for_claude(bedrock_client): - """Test that supports_caching returns True for Claude models.""" +def test_cache_strategy_anthropic_for_claude(bedrock_client): + """Test that _cache_strategy returns 'anthropic' for Claude models.""" model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0") - assert model._supports_caching is True + assert model._cache_strategy == "anthropic" model2 = BedrockModel(model_id="anthropic.claude-3-haiku-20240307-v1:0") - assert model2._supports_caching is True + assert model2._cache_strategy == "anthropic" -def test_supports_caching_false_for_non_claude(bedrock_client): - """Test that supports_caching returns False for non-Claude models.""" +def test_cache_strategy_none_for_non_claude(bedrock_client): + """Test that _cache_strategy returns None for unsupported models.""" model = BedrockModel(model_id="amazon.nova-pro-v1:0") - assert model._supports_caching is False + assert model._cache_strategy is None def test_inject_cache_point_adds_to_last_assistant(bedrock_client): @@ -2693,6 +2693,42 @@ def test_inject_cache_point_strips_existing_cache_points(bedrock_client): assert "cachePoint" in 
cleaned_messages[3]["content"][-1] +def test_inject_cache_point_anthropic_strategy_skips_model_check(bedrock_client): + """Test that anthropic strategy injects cache point without model support check.""" + model = BedrockModel( + model_id="arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/a1b2c3d4e5f6", + cache_config=CacheConfig(strategy="anthropic"), + ) + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Response"}]}, + ] + + formatted = model._format_bedrock_messages(messages) + + assert len(formatted[1]["content"]) == 2 + assert "cachePoint" in formatted[1]["content"][-1] + assert formatted[1]["content"][-1]["cachePoint"]["type"] == "default" + + +def test_inject_cache_point_auto_strategy_resolves_to_anthropic_for_claude(bedrock_client): + """Test that auto strategy resolves to anthropic strategy for Claude models.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") + ) + + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Response"}]}, + ] + + formatted = model._format_bedrock_messages(messages) + + assert len(formatted[1]["content"]) == 2 + assert "cachePoint" in formatted[1]["content"][-1] + + def test_find_last_user_text_message_index_no_user_messages(bedrock_client): """Test _find_last_user_text_message_index returns None when no user text messages exist.""" model = BedrockModel(model_id="test-model") From 12fd856c7920e28dd9f967a1ab190023e285c5f5 Mon Sep 17 00:00:00 2001 From: Clare Liguori Date: Thu, 5 Mar 2026 13:13:38 -0800 Subject: [PATCH 165/279] feat: serialize tool results as JSON when possible (#1752) --- src/strands/tools/decorator.py | 19 +++- tests/strands/tools/test_decorator.py | 141 ++++++++++++++++++++++++-- 2 files changed, 149 insertions(+), 11 deletions(-) diff --git a/src/strands/tools/decorator.py 
b/src/strands/tools/decorator.py index 70552d6ba..0f91349d2 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -43,6 +43,7 @@ def my_tool(param1: str, param2: int = 42) -> dict: import asyncio import functools import inspect +import json import logging from collections.abc import Callable from typing import ( @@ -61,6 +62,7 @@ def my_tool(param1: str, param2: int = 42) -> dict: import docstring_parser from pydantic import BaseModel, Field, create_model from pydantic.fields import FieldInfo +from pydantic_core import PydanticSerializationError from typing_extensions import override from ..interrupt import InterruptException @@ -644,12 +646,25 @@ def _wrap_tool_result(self, tool_use_d: str, result: Any, exception: Exception | return ToolResultEvent(cast(ToolResult, result), exception=exception) else: # Wrap any other return value in the standard format - # Always include at least one content item for consistency + # Serialize to JSON for consistent, parseable output (except strings) + if isinstance(result, str): + text = result + elif isinstance(result, BaseModel): + try: + text = result.model_dump_json() + except PydanticSerializationError: + text = str(result) + else: + try: + text = json.dumps(result) + except (TypeError, ValueError): + text = str(result) + return ToolResultEvent( { "toolUseId": tool_use_d, "status": "success", - "content": [{"text": str(result)}], + "content": [{"text": text}], }, exception=exception, ) diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index f3d6eda02..cc1158983 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -136,7 +136,7 @@ def identity(a: int, agent: dict = None): tru_events = await alist(stream) exp_events = [ - ToolResultEvent({"toolUseId": "unknown", "status": "success", "content": [{"text": "(2, {'state': 1})"}]}) + ToolResultEvent({"toolUseId": "unknown", "status": "success", "content": [{"text": 
'[2, {"state": 1}]'}]}) ] assert tru_events == exp_events @@ -595,12 +595,12 @@ def none_return_tool(param: str) -> None: assert result["tool_result"]["status"] == "success" assert result["tool_result"]["content"][0]["text"] == "Result: test" - # Test None return - should still create valid ToolResult with "None" text + # Test None return - should still create valid ToolResult with "null" stream = none_return_tool.stream(tool_use, {}) result = (await alist(stream))[-1] assert result["tool_result"]["status"] == "success" - assert result["tool_result"]["content"][0]["text"] == "None" + assert result["tool_result"]["content"][0]["text"] == "null" @pytest.mark.asyncio @@ -861,7 +861,7 @@ def int_return_tool(param: str) -> int: result = (await alist(stream))[-1] assert result["tool_result"]["status"] == "success" - assert result["tool_result"]["content"][0]["text"] == "None" + assert result["tool_result"]["content"][0]["text"] == "null" # Define tool with Union return type @strands.tool @@ -884,10 +884,7 @@ def union_return_tool(param: str) -> dict[str, Any] | str | None: result = (await alist(stream))[-1] assert result["tool_result"]["status"] == "success" - assert ( - "{'key': 'value'}" in result["tool_result"]["content"][0]["text"] - or '{"key": "value"}' in result["tool_result"]["content"][0]["text"] - ) + assert result["tool_result"]["content"][0]["text"] == '{"key": "value"}' tool_use = {"toolUseId": "test-id", "input": {"param": "str"}} stream = union_return_tool.stream(tool_use, {}) @@ -901,7 +898,7 @@ def union_return_tool(param: str) -> dict[str, Any] | str | None: result = (await alist(stream))[-1] assert result["tool_result"]["status"] == "success" - assert result["tool_result"]["content"][0]["text"] == "None" + assert result["tool_result"]["content"][0]["text"] == "null" @pytest.mark.asyncio @@ -992,6 +989,132 @@ def custom_result_tool(param: str) -> dict[str, Any]: assert result["tool_result"]["content"][1]["type"] == "markdown" +@pytest.mark.asyncio 
+async def test_tool_result_json_serialization_dict(alist): + """Test that dict results are serialized as JSON.""" + + @strands.tool + def dict_tool() -> dict: + """Returns a dict.""" + return {"key": "value", "number": 42} + + tool_use = {"toolUseId": "test-id", "input": {}} + stream = dict_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + text = result["tool_result"]["content"][0]["text"] + + assert text == '{"key": "value", "number": 42}' + + +@pytest.mark.asyncio +async def test_tool_result_json_serialization_list(alist): + """Test that list results are serialized as JSON.""" + + @strands.tool + def list_tool() -> list: + """Returns a list.""" + return [1, "two", {"three": 3}] + + tool_use = {"toolUseId": "test-id", "input": {}} + stream = list_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + text = result["tool_result"]["content"][0]["text"] + + assert text == '[1, "two", {"three": 3}]' + + +@pytest.mark.asyncio +async def test_tool_result_json_serialization_pydantic(alist): + """Test that Pydantic model results are serialized as JSON.""" + from pydantic import BaseModel + + class MyModel(BaseModel): + name: str + count: int + + @strands.tool + def pydantic_tool() -> MyModel: + """Returns a Pydantic model.""" + return MyModel(name="test", count=5) + + tool_use = {"toolUseId": "test-id", "input": {}} + stream = pydantic_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + text = result["tool_result"]["content"][0]["text"] + + assert text == '{"name":"test","count":5}' + + +@pytest.mark.asyncio +async def test_tool_result_json_serialization_pydantic_non_serializable(alist): + """Test that Pydantic models with non-serializable fields fall back to str().""" + from pydantic import BaseModel + + class NonSerializable: + def __repr__(self): + return "NonSerializable()" + + class MyModel(BaseModel): + model_config = {"arbitrary_types_allowed": True} + data: NonSerializable + + @strands.tool + def pydantic_tool() -> 
MyModel: + """Returns a Pydantic model with non-serializable field.""" + return MyModel(data=NonSerializable()) + + tool_use = {"toolUseId": "test-id", "input": {}} + stream = pydantic_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + text = result["tool_result"]["content"][0]["text"] + + assert text == "data=NonSerializable()" + + +@pytest.mark.asyncio +async def test_tool_result_json_serialization_non_serializable(alist): + """Test that non-JSON-serializable results fall back to str().""" + + class CustomClass: + def __str__(self): + return "custom_str_repr" + + @strands.tool + def custom_tool() -> Any: + """Returns a non-serializable object.""" + return CustomClass() + + tool_use = {"toolUseId": "test-id", "input": {}} + stream = custom_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + text = result["tool_result"]["content"][0]["text"] + + assert text == "custom_str_repr" + + +@pytest.mark.asyncio +async def test_tool_result_string_not_json_encoded(alist): + """Test that string results are NOT JSON-encoded (no extra quotes).""" + + @strands.tool + def string_tool() -> str: + """Returns a string.""" + return "hello world" + + tool_use = {"toolUseId": "test-id", "input": {}} + stream = string_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + text = result["tool_result"]["content"][0]["text"] + + assert text == "hello world" + + def test_docstring_parsing(): """Test that function docstring is correctly parsed into tool spec.""" From a7d19cca20f94e8ff3b86ae7c738f1a84b740623 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 6 Mar 2026 10:10:19 -0500 Subject: [PATCH 166/279] fix: summary manager using structured output (#1805) --- .../summarizing_conversation_manager.py | 8 + .../test_summarizing_conversation_manager.py | 33 ++ tests/strands/telemetry/test_tracer.py | 292 ++++++++++-------- ...rizing_conversation_manager_integration.py | 66 ++++ 4 files changed, 269 insertions(+), 130 deletions(-) diff --git 
a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index 12b04dcea..abd4d08b5 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -220,8 +220,14 @@ def _generate_summary_with_agent(self, messages: list[Message]) -> Message: original_system_prompt = summarization_agent.system_prompt original_messages = summarization_agent.messages.copy() original_tool_registry = summarization_agent.tool_registry + original_structured_output_model = getattr(summarization_agent, "_default_structured_output_model", None) try: + # Disable structured output for summarization. Summaries are plain text and + # structured output adds toolUse blocks that are invalid in user messages. + if hasattr(summarization_agent, "_default_structured_output_model"): + summarization_agent._default_structured_output_model = None + # Add no-op tool if agent has no tools to satisfy tool spec requirement if not summarization_agent.tool_names: tool_registry = ToolRegistry() @@ -237,6 +243,8 @@ def _generate_summary_with_agent(self, messages: list[Message]) -> Message: summarization_agent.system_prompt = original_system_prompt summarization_agent.messages = original_messages summarization_agent.tool_registry = original_tool_registry + if hasattr(summarization_agent, "_default_structured_output_model"): + summarization_agent._default_structured_output_model = original_structured_output_model # ------------------------------------------------------------------ # Path 2 – default case: call model.stream() directly diff --git a/tests/strands/agent/test_summarizing_conversation_manager.py b/tests/strands/agent/test_summarizing_conversation_manager.py index b105eba86..c49c69de6 100644 --- a/tests/strands/agent/test_summarizing_conversation_manager.py +++ 
b/tests/strands/agent/test_summarizing_conversation_manager.py @@ -50,6 +50,7 @@ def __init__(self, summary_response="This is a summary of the conversation."): self.call_tracker = Mock() self.tool_registry = Mock() self.tool_names = [] + self._default_structured_output_model = None def __call__(self, prompt): """Mock agent call that returns a summary.""" @@ -769,3 +770,35 @@ def test_summarizing_conversation_manager_generate_summary_with_tools_agent_path manager._generate_summary(messages, parent_agent) mock_registry.register_tool.assert_not_called() + + +def test_generate_summary_disables_structured_output_on_summarization_agent(): + """Test that structured output is disabled during summarization to avoid toolUse in user messages. + + When a summarization agent has structured_output_model configured, the response contains toolUse blocks. + Since the summary is converted to a user message, toolUse blocks would violate the model API constraint + that user messages cannot contain tool uses. The fix disables structured output during summarization. 
+ """ + summary_agent = create_mock_agent() + structured_output_model = Mock() + summary_agent._default_structured_output_model = structured_output_model + + original_call = summary_agent.__class__.__call__ + observed_values = [] + + def tracking_call(self, prompt): + observed_values.append(self._default_structured_output_model) + return original_call(self, prompt) + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + + manager = SummarizingConversationManager(summarization_agent=summary_agent) + + with patch.object(MockAgent, "__call__", tracking_call): + manager._generate_summary(messages, create_mock_agent()) + + assert observed_values == [None], "structured output should be disabled during summarization" + assert summary_agent._default_structured_output_model is structured_output_model, "should be restored after" diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index da7f010e2..50c0cc9b9 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -148,14 +148,16 @@ def test_start_model_invoke_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "chat" assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.operation.name": "chat", - "gen_ai.system": "strands-agents", - "custom_key": "custom_value", - "user_id": "12345", - "gen_ai.request.model": model_id, - "agent_name": "TestAgent", - }) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "chat", + "gen_ai.system": "strands-agents", + "custom_key": "custom_value", + "user_id": "12345", + "gen_ai.request.model": model_id, + "agent_name": "TestAgent", + } + ) mock_span.add_event.assert_called_with( "gen_ai.user.message", attributes={"content": 
json.dumps(messages[0]["content"])} ) @@ -188,13 +190,15 @@ def test_start_model_invoke_span_latest_conventions(mock_tracer, monkeypatch): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "chat" assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL - - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.operation.name": "chat", - "gen_ai.provider.name": "strands-agents", - "gen_ai.request.model": model_id, - "agent_name": "TestAgent", - }) + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "chat", + "gen_ai.provider.name": "strands-agents", + "gen_ai.request.model": model_id, + "agent_name": "TestAgent", + } + ) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -232,15 +236,17 @@ def test_end_model_invoke_span(mock_span): tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.usage.prompt_tokens": 10, - "gen_ai.usage.input_tokens": 10, - "gen_ai.usage.completion_tokens": 20, - "gen_ai.usage.output_tokens": 20, - "gen_ai.usage.total_tokens": 30, - "gen_ai.server.time_to_first_token": 10, - "gen_ai.server.request.duration": 20, - }) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.server.time_to_first_token": 10, + "gen_ai.server.request.duration": 20, + } + ) mock_span.add_event.assert_called_with( "gen_ai.choice", attributes={"message": json.dumps(message["content"]), "finish_reason": "end_turn"}, @@ -259,15 +265,17 @@ def test_end_model_invoke_span_latest_conventions(mock_span, monkeypatch): tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) - mock_span.set_attributes.assert_called_once_with({ - 
"gen_ai.usage.prompt_tokens": 10, - "gen_ai.usage.input_tokens": 10, - "gen_ai.usage.completion_tokens": 20, - "gen_ai.usage.output_tokens": 20, - "gen_ai.usage.total_tokens": 30, - "gen_ai.server.time_to_first_token": 10, - "gen_ai.server.request.duration": 20, - }) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.server.time_to_first_token": 10, + "gen_ai.server.request.duration": 20, + } + ) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -300,15 +308,17 @@ def test_start_tool_call_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool" - - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.tool.name": "test-tool", - "gen_ai.system": "strands-agents", - "gen_ai.operation.name": "execute_tool", - "gen_ai.tool.call.id": "123", - "session_id": "abc123", - "environment": "production", - }) + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.tool.name": "test-tool", + "gen_ai.system": "strands-agents", + "gen_ai.operation.name": "execute_tool", + "gen_ai.tool.call.id": "123", + "session_id": "abc123", + "environment": "production", + } + ) mock_span.add_event.assert_any_call( "gen_ai.tool.message", attributes={"role": "tool", "content": json.dumps({"param": "value"}), "id": "123"} ) @@ -331,13 +341,15 @@ def test_start_tool_call_span_latest_conventions(mock_tracer, monkeypatch): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool" - - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.tool.name": "test-tool", - "gen_ai.provider.name": "strands-agents", - "gen_ai.operation.name": "execute_tool", - "gen_ai.tool.call.id": "123", - }) 
+ + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.tool.name": "test-tool", + "gen_ai.provider.name": "strands-agents", + "gen_ai.operation.name": "execute_tool", + "gen_ai.tool.call.id": "123", + } + ) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", attributes={ @@ -377,14 +389,16 @@ def test_start_swarm_call_span_with_string_task(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm" - - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.operation.name": "invoke_swarm", - "gen_ai.system": "strands-agents", - "gen_ai.agent.name": "swarm", - "workflow_id": "wf-789", - "priority": "high", - }) + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "invoke_swarm", + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": "swarm", + "workflow_id": "wf-789", + "priority": "high", + } + ) mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": "Design foo bar"}) assert span is not None @@ -404,12 +418,14 @@ def test_start_swarm_span_with_contentblock_task(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm" - - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.operation.name": "invoke_swarm", - "gen_ai.system": "strands-agents", - "gen_ai.agent.name": "swarm", - }) + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "invoke_swarm", + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": "swarm", + } + ) mock_span.add_event.assert_any_call( "gen_ai.user.message", attributes={"content": '[{"text": "Original Task: foo bar"}]'} ) @@ -460,12 +476,14 @@ def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer, mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm" - - 
mock_span.set_attributes.assert_called_once_with({ - "gen_ai.operation.name": "invoke_swarm", - "gen_ai.provider.name": "strands-agents", - "gen_ai.agent.name": "swarm", - }) + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "invoke_swarm", + "gen_ai.provider.name": "strands-agents", + "gen_ai.agent.name": "swarm", + } + ) mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", attributes={ @@ -528,13 +546,15 @@ def test_start_graph_call_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool" - - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.operation.name": "execute_tool", - "gen_ai.system": "strands-agents", - "gen_ai.tool.name": "test-tool", - "gen_ai.tool.call.id": "123", - }) + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "execute_tool", + "gen_ai.system": "strands-agents", + "gen_ai.tool.name": "test-tool", + "gen_ai.tool.call.id": "123", + } + ) mock_span.add_event.assert_any_call( "gen_ai.tool.message", attributes={"role": "tool", "content": json.dumps({"param": "value"}), "id": "123"} ) @@ -608,12 +628,14 @@ def test_start_event_loop_cycle_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_event_loop_cycle" - - mock_span.set_attributes.assert_called_once_with({ - "event_loop.cycle_id": "cycle-123", - "request_id": "req-456", - "trace_level": "debug", - }) + + mock_span.set_attributes.assert_called_once_with( + { + "event_loop.cycle_id": "cycle-123", + "request_id": "req-456", + "trace_level": "debug", + } + ) mock_span.add_event.assert_any_call( "gen_ai.user.message", attributes={"content": json.dumps([{"text": "Hello"}])} ) @@ -637,7 +659,7 @@ def test_start_event_loop_cycle_span_latest_conventions(mock_tracer, monkeypatch mock_tracer.start_span.assert_called_once() assert 
mock_tracer.start_span.call_args[1]["name"] == "execute_event_loop_cycle" - + mock_span.set_attributes.assert_called_once_with({"event_loop.cycle_id": "cycle-123"}) mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", @@ -731,14 +753,16 @@ def test_start_agent_span(mock_tracer): assert mock_tracer.start_span.call_args[1]["name"] == "invoke_agent WeatherAgent" assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.INTERNAL - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.operation.name": "invoke_agent", - "gen_ai.system": "strands-agents", - "gen_ai.agent.name": "WeatherAgent", - "gen_ai.request.model": model_id, - "gen_ai.agent.tools": json.dumps(tools), - "custom_attr": "value", - }) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "invoke_agent", + "gen_ai.system": "strands-agents", + "gen_ai.agent.name": "WeatherAgent", + "gen_ai.request.model": model_id, + "gen_ai.agent.tools": json.dumps(tools), + "custom_attr": "value", + } + ) mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": json.dumps(content)}) assert span is not None @@ -768,15 +792,17 @@ def test_start_agent_span_latest_conventions(mock_tracer, monkeypatch): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "invoke_agent WeatherAgent" - - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.operation.name": "invoke_agent", - "gen_ai.provider.name": "strands-agents", - "gen_ai.agent.name": "WeatherAgent", - "gen_ai.request.model": model_id, - "gen_ai.agent.tools": json.dumps(tools), - "custom_attr": "value", - }) + + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "invoke_agent", + "gen_ai.provider.name": "strands-agents", + "gen_ai.agent.name": "WeatherAgent", + "gen_ai.request.model": model_id, + "gen_ai.agent.tools": json.dumps(tools), + "custom_attr": "value", + } + ) 
mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", attributes={ @@ -919,17 +945,19 @@ def test_end_model_invoke_span_with_cache_metrics(mock_span): tracer.end_model_invoke_span(mock_span, message, usage, metrics, stop_reason) - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.usage.prompt_tokens": 10, - "gen_ai.usage.input_tokens": 10, - "gen_ai.usage.completion_tokens": 20, - "gen_ai.usage.output_tokens": 20, - "gen_ai.usage.total_tokens": 30, - "gen_ai.usage.cache_read_input_tokens": 5, - "gen_ai.usage.cache_write_input_tokens": 3, - "gen_ai.server.request.duration": 10, - "gen_ai.server.time_to_first_token": 5, - }) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.usage.cache_read_input_tokens": 5, + "gen_ai.usage.cache_write_input_tokens": 3, + "gen_ai.server.request.duration": 10, + "gen_ai.server.time_to_first_token": 5, + } + ) def test_end_agent_span_with_cache_metrics(mock_span): @@ -953,15 +981,17 @@ def test_end_agent_span_with_cache_metrics(mock_span): tracer.end_agent_span(mock_span, mock_response) - mock_span.set_attributes.assert_called_once_with({ - "gen_ai.usage.prompt_tokens": 50, - "gen_ai.usage.input_tokens": 50, - "gen_ai.usage.completion_tokens": 100, - "gen_ai.usage.output_tokens": 100, - "gen_ai.usage.total_tokens": 150, - "gen_ai.usage.cache_read_input_tokens": 25, - "gen_ai.usage.cache_write_input_tokens": 10, - }) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.usage.prompt_tokens": 50, + "gen_ai.usage.input_tokens": 50, + "gen_ai.usage.completion_tokens": 100, + "gen_ai.usage.output_tokens": 100, + "gen_ai.usage.total_tokens": 150, + "gen_ai.usage.cache_read_input_tokens": 25, + "gen_ai.usage.cache_write_input_tokens": 10, + } + ) 
mock_span.set_status.assert_called_once_with(StatusCode.OK) mock_span.end.assert_called_once() @@ -1519,18 +1549,20 @@ def test_end_model_invoke_span_langfuse_adds_attributes(mock_span, monkeypatch): } ] ) - + assert mock_span.set_attributes.call_count == 2 mock_span.set_attributes.assert_any_call({"gen_ai.output.messages": expected_output}) - mock_span.set_attributes.assert_any_call({ - "gen_ai.usage.prompt_tokens": 10, - "gen_ai.usage.input_tokens": 10, - "gen_ai.usage.completion_tokens": 20, - "gen_ai.usage.output_tokens": 20, - "gen_ai.usage.total_tokens": 30, - "gen_ai.server.time_to_first_token": 10, - "gen_ai.server.request.duration": 20, - }) + mock_span.set_attributes.assert_any_call( + { + "gen_ai.usage.prompt_tokens": 10, + "gen_ai.usage.input_tokens": 10, + "gen_ai.usage.completion_tokens": 20, + "gen_ai.usage.output_tokens": 20, + "gen_ai.usage.total_tokens": 30, + "gen_ai.server.time_to_first_token": 10, + "gen_ai.server.request.duration": 20, + } + ) mock_span.add_event.assert_called_with( "gen_ai.client.inference.operation.details", diff --git a/tests_integ/test_summarizing_conversation_manager_integration.py b/tests_integ/test_summarizing_conversation_manager_integration.py index 91fb5b910..d6508edce 100644 --- a/tests_integ/test_summarizing_conversation_manager_integration.py +++ b/tests_integ/test_summarizing_conversation_manager_integration.py @@ -16,6 +16,7 @@ import os import pytest +from pydantic import BaseModel import strands from strands import Agent @@ -408,3 +409,68 @@ def test_summarization_with_tool_messages_and_no_tools(): summary = str(agent.messages[0]).lower() assert "12:00" in summary + + +def test_dedicated_summarization_agent_with_structured_output(model, summarization_model): + """Test that summarization works when the summarization agent has structured_output_model configured. + + When structured_output_model is set on the summarization agent, the response would contain toolUse + blocks. 
Since the summary is converted to a user message, those blocks would cause a + ValidationException. This test verifies that structured output is properly disabled during + summarization. + """ + + class SummaryOutput(BaseModel): + topics: list[str] + key_points: list[str] + + # Create a summarization agent with structured_output_model configured + summarization_agent = Agent( + model=summarization_model, + system_prompt="You are a conversation summarizer. Create concise, structured summaries.", + structured_output_model=SummaryOutput, + load_tools_from_directory=False, + ) + + agent = Agent( + model=model, + conversation_manager=SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + summarization_agent=summarization_agent, + ), + load_tools_from_directory=False, + ) + + # Build conversation history + agent.messages.extend( + [ + {"role": "user", "content": [{"text": "Tell me about Python programming."}]}, + {"role": "assistant", "content": [{"text": "Python is a high-level programming language."}]}, + {"role": "user", "content": [{"text": "What about its type system?"}]}, + {"role": "assistant", "content": [{"text": "Python uses dynamic typing with optional type hints."}]}, + {"role": "user", "content": [{"text": "How does async work in Python?"}]}, + {"role": "assistant", "content": [{"text": "Python uses asyncio with async/await syntax."}]}, + {"role": "user", "content": [{"text": "What about decorators?"}]}, + {"role": "assistant", "content": [{"text": "Decorators are functions that modify other functions."}]}, + ] + ) + + original_length = len(agent.messages) + agent.conversation_manager.reduce_context(agent) + + assert len(agent.messages) < original_length + + summary_message = agent.messages[0] + assert summary_message["role"] == "user" + + # Summary should contain only valid user message content (no toolUse blocks) + for content_block in summary_message["content"]: + assert "toolUse" not in content_block, "Summary user message 
should not contain toolUse blocks" + + # Should have text content + assert any("text" in cb for cb in summary_message["content"]) + + # Invoke the agent with the summarized messages to verify the provider accepts them + result = agent("Thanks for the overview!") + assert result.message["role"] == "assistant" From 316f54edc993863b4b099c903f58d9e60c9fbc6b Mon Sep 17 00:00:00 2001 From: Shotaro Kataoka <42331656+ShotaroKataoka@users.noreply.github.com> Date: Sat, 7 Mar 2026 00:50:03 +0900 Subject: [PATCH 167/279] feat(mcp): expose server instructions from InitializeResult on MCPClient (#1814) --- src/strands/tools/mcp/mcp_client.py | 5 ++++- tests/strands/tools/mcp/conftest.py | 4 +++- tests/strands/tools/mcp/test_mcp_client.py | 14 ++++++++++++++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index f064f7def..51a627c7c 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -153,6 +153,7 @@ def __init__( self._background_thread_event_loop: AbstractEventLoop | None = None self._loaded_tools: list[MCPAgentTool] | None = None self._tool_provider_started = False + self.server_instructions: str | None = None self._consumers: set[Any] = set() # Task support configuration and caching @@ -732,9 +733,11 @@ async def _async_background_thread(self) -> None: elicitation_callback=self._elicitation_callback, ) as session: self._log_debug_with_thread("initializing MCP session") - await session.initialize() + init_result = await session.initialize() self._log_debug_with_thread("session initialized successfully") + # Store server instructions from InitializeResult for Host applications + self.server_instructions = init_result.instructions # Store the session for use while we await the close event self._background_thread_session = session diff --git a/tests/strands/tools/mcp/conftest.py b/tests/strands/tools/mcp/conftest.py index 0cfce470a..d0ac46bdc 100644 
--- a/tests/strands/tools/mcp/conftest.py +++ b/tests/strands/tools/mcp/conftest.py @@ -26,7 +26,9 @@ def mock_transport(): def mock_session(): """Create a mock MCP session.""" mock_session = AsyncMock() - mock_session.initialize = AsyncMock() + mock_init_result = MagicMock() + mock_init_result.instructions = None + mock_session.initialize = AsyncMock(return_value=mock_init_result) # Default: no task support (get_server_capabilities is sync, not async!) mock_session.get_server_capabilities = MagicMock(return_value=None) diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index e477c64d5..5eedd1e33 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -50,6 +50,20 @@ def test_mcp_client_context_manager(mock_transport, mock_session): assert client._background_thread is None +def test_server_instructions_default(mock_transport, mock_session): + """Test that server_instructions defaults to None when server returns None.""" + mock_session.initialize.return_value.instructions = None + with MCPClient(mock_transport["transport_callable"]) as client: + assert client.server_instructions is None + + +def test_server_instructions_from_server(mock_transport, mock_session): + """Test that server_instructions is populated from InitializeResult.""" + mock_session.initialize.return_value.instructions = "Use tool A before tool B." + with MCPClient(mock_transport["transport_callable"]) as client: + assert client.server_instructions == "Use tool A before tool B." 
+ + def test_list_tools_sync(mock_transport, mock_session): """Test that list_tools_sync correctly retrieves and adapts tools.""" mock_tool = MCPTool(name="test_tool", description="A test tool", inputSchema={"type": "object", "properties": {}}) From 697e55c12fa1530b2d1d0272ea4dba77c1cec5b5 Mon Sep 17 00:00:00 2001 From: poshinchen Date: Fri, 6 Mar 2026 10:52:10 -0500 Subject: [PATCH 168/279] fix: added LANGFUSE_BASE_URL check for additinoal attribute (#1826) --- src/strands/telemetry/tracer.py | 5 +-- tests/strands/telemetry/test_tracer.py | 47 ++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 80fb86c40..0471a7fcc 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -117,8 +117,9 @@ def is_langfuse(self) -> bool: Returns: True if Langfuse is the OTLP endpoint, False otherwise. """ - return "langfuse" in os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "") or "langfuse" in os.getenv( - "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "" + return any( + "langfuse" in os.getenv(var, "") + for var in ("OTEL_EXPORTER_OTLP_ENDPOINT", "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "LANGFUSE_BASE_URL") ) def _start_span( diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 50c0cc9b9..410db0c0c 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -1604,3 +1604,50 @@ def test_end_model_invoke_span_non_langfuse_no_extra_attributes(mock_span, monke "gen_ai.client.inference.operation.details", attributes={"gen_ai.output.messages": expected_output}, ) + + +class TestIsLangfuse: + """Tests for the is_langfuse property.""" + + def test_is_langfuse_with_otel_exporter_otlp_endpoint(self, monkeypatch): + """Test is_langfuse returns True when OTEL_EXPORTER_OTLP_ENDPOINT contains langfuse.""" + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "https://us.cloud.langfuse.com") + 
tracer = Tracer() + assert tracer.is_langfuse is True + + def test_is_langfuse_with_otel_exporter_otlp_traces_endpoint(self, monkeypatch): + """Test is_langfuse returns True when OTEL_EXPORTER_OTLP_TRACES_ENDPOINT contains langfuse.""" + monkeypatch.setenv( + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "https://us.cloud.langfuse.com/api/public/otel/v1/traces" + ) + tracer = Tracer() + assert tracer.is_langfuse is True + + def test_is_langfuse_with_langfuse_base_url(self, monkeypatch): + """Test is_langfuse returns True when LANGFUSE_BASE_URL contains langfuse.""" + monkeypatch.setenv("LANGFUSE_BASE_URL", "https://us.cloud.langfuse.com") + tracer = Tracer() + assert tracer.is_langfuse is True + + def test_is_langfuse_false_when_no_langfuse_env_vars(self, monkeypatch): + """Test is_langfuse returns False when no Langfuse-related env vars are set.""" + monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) + monkeypatch.delenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", raising=False) + monkeypatch.delenv("LANGFUSE_BASE_URL", raising=False) + tracer = Tracer() + assert tracer.is_langfuse is False + + def test_is_langfuse_false_with_non_langfuse_endpoint(self, monkeypatch): + """Test is_langfuse returns False when endpoint is not Langfuse.""" + monkeypatch.setenv("OTEL_EXPORTER_OTLP_ENDPOINT", "https://api.honeycomb.io") + monkeypatch.delenv("LANGFUSE_BASE_URL", raising=False) + tracer = Tracer() + assert tracer.is_langfuse is False + + def test_is_langfuse_false_with_non_langfuse_base_url(self, monkeypatch): + """Test is_langfuse returns False when LANGFUSE_BASE_URL doesn't contain langfuse.""" + monkeypatch.setenv("LANGFUSE_BASE_URL", "https://some-other-service.com") + monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) + monkeypatch.delenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", raising=False) + tracer = Tracer() + assert tracer.is_langfuse is False From 2d766c4e6974732590439a79ab10c4f07f2d1806 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Fri, 6 Mar 
2026 12:36:13 -0500 Subject: [PATCH 169/279] feat(session): add dirty flag to skip unnecessary agent state persistence (#1803) Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> --- src/strands/interrupt.py | 22 +- .../session/repository_session_manager.py | 55 +++++ src/strands/types/json_dict.py | 15 ++ .../test_repository_session_manager.py | 216 ++++++++++++++++++ tests/strands/test_interrupt.py | 64 ++++++ tests/strands/types/test_json_dict.py | 65 ++++++ 6 files changed, 436 insertions(+), 1 deletion(-) diff --git a/src/strands/interrupt.py b/src/strands/interrupt.py index 85997c9be..7d02b50ff 100644 --- a/src/strands/interrupt.py +++ b/src/strands/interrupt.py @@ -52,10 +52,12 @@ class _InterruptState: interrupts: dict[str, Interrupt] = field(default_factory=dict) context: dict[str, Any] = field(default_factory=dict) activated: bool = False + _version: int = field(default=0, compare=False, repr=False) def activate(self) -> None: """Activate the interrupt state.""" self.activated = True + self._version += 1 def deactivate(self) -> None: """Deacitvate the interrupt state. @@ -65,6 +67,7 @@ def deactivate(self) -> None: self.interrupts = {} self.context = {} self.activated = False + self._version += 1 def resume(self, prompt: "AgentInput") -> None: """Configure the interrupt state if resuming from an interrupt event. @@ -100,10 +103,27 @@ def resume(self, prompt: "AgentInput") -> None: self.interrupts[interrupt_id].response = interrupt_response self.context["responses"] = contents + self._version += 1 + + def _get_version(self) -> int: + """Get the current version number of the interrupt state. + + The version is incremented each time activate(), deactivate(), or resume() is called. + Consumers can compare versions to detect changes without requiring + explicit dirty flag clearing. + + Returns: + The current version number. 
+ """ + return self._version def to_dict(self) -> dict[str, Any]: """Serialize to dict for session management.""" - return asdict(self) + return { + "interrupts": {k: v.to_dict() for k, v in self.interrupts.items()}, + "context": self.context, + "activated": self.activated, + } @classmethod def from_dict(cls, data: dict[str, Any]) -> "_InterruptState": diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index d23c4a94f..dd3562289 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -1,5 +1,6 @@ """Repository session manager implementation.""" +import copy import logging from typing import TYPE_CHECKING, Any @@ -59,6 +60,9 @@ def __init__( # Keep track of the latest message of each agent in case we need to redact it. self._latest_agent_message: dict[str, SessionMessage | None] = {} + # Track the previously synced internal state for each agent to detect changes. + self._last_synced_internal_state: dict[str, dict[str, Any]] = {} + def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None: """Append a message to the agent's session. @@ -95,15 +99,66 @@ def redact_latest_message(self, redact_message: Message, agent: "Agent", **kwarg def sync_agent(self, agent: "Agent", **kwargs: Any) -> None: """Serialize and update the agent into the session repository. + Only updates the agent if state has been modified or internal state has changed. + This optimization reduces unnecessary I/O operations when the agent processes + messages without modifying its state. + Args: agent: Agent to sync to the session. **kwargs: Additional keyword arguments for future extensibility. 
""" + # Get current versions and conversation manager state + current_state_version = agent.state._get_version() + current_interrupt_state_version = agent._interrupt_state._get_version() + current_conversation_manager_state = agent.conversation_manager.get_state() + + # Check if we have a previous state to compare against + last_synced = self._last_synced_internal_state.get(agent.agent_id) + + # Determine if we need to update by comparing versions + if last_synced is None: + # First sync for this agent - always update + state_changed = True + internal_state_changed = True + conversation_manager_state_changed = True + else: + state_changed = current_state_version != last_synced.get("state_version") + internal_state_changed = current_interrupt_state_version != last_synced.get("interrupt_state_version") + conversation_manager_state_changed = ( + current_conversation_manager_state != last_synced.get("conversation_manager_state") + ) + + if not state_changed and not internal_state_changed and not conversation_manager_state_changed: + logger.debug( + "agent_id=<%s> | session_id=<%s> | skipping sync, no changes detected", + agent.agent_id, + self.session_id, + ) + return + + logger.debug( + "agent_id=<%s> | session_id=<%s> | state_changed=<%s>, internal_state_changed=<%s>, " + "conversation_manager_state_changed=<%s> | syncing agent", + agent.agent_id, + self.session_id, + state_changed, + internal_state_changed, + conversation_manager_state_changed, + ) + + # Perform the update self.session_repository.update_agent( self.session_id, SessionAgent.from_agent(agent), ) + # Update tracked versions after successful sync + self._last_synced_internal_state[agent.agent_id] = { + "state_version": current_state_version, + "interrupt_state_version": current_interrupt_state_version, + "conversation_manager_state": copy.deepcopy(current_conversation_manager_state), + } + def initialize(self, agent: "Agent", **kwargs: Any) -> None: """Initialize an agent with a session. 
diff --git a/src/strands/types/json_dict.py b/src/strands/types/json_dict.py index a8636ab10..dc6ae6565 100644 --- a/src/strands/types/json_dict.py +++ b/src/strands/types/json_dict.py @@ -15,6 +15,7 @@ class JSONSerializableDict: def __init__(self, initial_state: dict[str, Any] | None = None): """Initialize JSONSerializableDict.""" self._data: dict[str, Any] + self._version: int = 0 if initial_state: self._validate_json_serializable(initial_state) self._data = copy.deepcopy(initial_state) @@ -34,6 +35,7 @@ def set(self, key: str, value: Any) -> None: self._validate_key(key) self._validate_json_serializable(value) self._data[key] = copy.deepcopy(value) + self._version += 1 def get(self, key: str | None = None) -> Any: """Get a value or entire data. @@ -57,6 +59,19 @@ def delete(self, key: str) -> None: """ self._validate_key(key) self._data.pop(key, None) + self._version += 1 + + def _get_version(self) -> int: + """Get the current version number of the store. + + The version is incremented each time set() or delete() is called. + Consumers can compare versions to detect changes without requiring + explicit dirty flag clearing. + + Returns: + The current version number. + """ + return self._version def _validate_key(self, key: str) -> None: """Validate that a key is valid. 
diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index 22de9f964..f8f044a9b 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -595,3 +595,219 @@ def test_fix_broken_tool_use_does_not_affect_normal_conversations(session_manage # Should remain unchanged assert fixed_messages == messages + + +# ============================================================================ +# Conditional Sync Tests +# ============================================================================ + + +def test_sync_agent_skips_update_when_state_not_dirty_and_internal_state_unchanged(mock_repository): + """Test that sync_agent() skips update_agent() when state is not dirty and internal state unchanged.""" + session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + # Create and initialize agent + agent = Agent(agent_id="test-agent", session_manager=session_manager) + + # Track update_agent calls + update_agent_calls = [] + original_update_agent = mock_repository.update_agent + + def tracking_update_agent(session_id, session_agent): + update_agent_calls.append((session_id, session_agent)) + return original_update_agent(session_id, session_agent) + + mock_repository.update_agent = tracking_update_agent + + # First sync should update (to establish baseline) + session_manager.sync_agent(agent) + assert len(update_agent_calls) == 1 + + # Clear tracking + update_agent_calls.clear() + + # Second sync without changes should skip update + session_manager.sync_agent(agent) + assert len(update_agent_calls) == 0 + + +def test_sync_agent_calls_update_when_state_is_dirty(mock_repository): + """Test that sync_agent() calls update_agent() when agent.state is dirty.""" + session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + # Create and 
initialize agent + agent = Agent(agent_id="test-agent", session_manager=session_manager) + + # Track update_agent calls + update_agent_calls = [] + original_update_agent = mock_repository.update_agent + + def tracking_update_agent(session_id, session_agent): + update_agent_calls.append((session_id, session_agent)) + return original_update_agent(session_id, session_agent) + + mock_repository.update_agent = tracking_update_agent + + # First sync to establish baseline + session_manager.sync_agent(agent) + update_agent_calls.clear() + + # Modify state (makes it dirty) + agent.state.set("key", "value") + + # Sync should call update_agent because state is dirty + session_manager.sync_agent(agent) + assert len(update_agent_calls) == 1 + + +def test_sync_agent_calls_update_when_internal_state_changed(mock_repository): + """Test that sync_agent() calls update_agent() when internal state (interrupt_state) is dirty.""" + session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + # Create and initialize agent + agent = Agent(agent_id="test-agent", session_manager=session_manager) + + # Track update_agent calls + update_agent_calls = [] + original_update_agent = mock_repository.update_agent + + def tracking_update_agent(session_id, session_agent): + update_agent_calls.append((session_id, session_agent)) + return original_update_agent(session_id, session_agent) + + mock_repository.update_agent = tracking_update_agent + + # First sync to establish baseline + session_manager.sync_agent(agent) + update_agent_calls.clear() + + # Modify internal state (activate interrupt state which sets dirty flag) + agent._interrupt_state.activate() + + # Sync should call update_agent because internal state is dirty + session_manager.sync_agent(agent) + assert len(update_agent_calls) == 1 + + +def test_sync_agent_calls_update_when_conversation_manager_state_changed(mock_repository): + """Test that sync_agent() calls update_agent() when conversation 
manager state changed.""" + session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + # Create and initialize agent + agent = Agent(agent_id="test-agent", session_manager=session_manager) + + # Track update_agent calls + update_agent_calls = [] + original_update_agent = mock_repository.update_agent + + def tracking_update_agent(session_id, session_agent): + update_agent_calls.append((session_id, session_agent)) + return original_update_agent(session_id, session_agent) + + mock_repository.update_agent = tracking_update_agent + + # First sync to establish baseline + session_manager.sync_agent(agent) + update_agent_calls.clear() + + # Modify conversation manager state + agent.conversation_manager.removed_message_count = 5 + + # Sync should call update_agent because conversation manager state changed + session_manager.sync_agent(agent) + assert len(update_agent_calls) == 1 + + +def test_sync_agent_tracks_version_after_successful_sync(mock_repository): + """Test that sync_agent() tracks version after successful sync.""" + session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + # Create and initialize agent + agent = Agent(agent_id="test-agent", session_manager=session_manager) + + # First sync to establish baseline + session_manager.sync_agent(agent) + initial_version = agent.state._get_version() + + # Modify state (increments version) + agent.state.set("key", "value") + assert agent.state._get_version() == initial_version + 1 + + # Track update_agent calls + update_agent_calls = [] + original_update_agent = mock_repository.update_agent + + def tracking_update_agent(session_id, session_agent): + update_agent_calls.append((session_id, session_agent)) + return original_update_agent(session_id, session_agent) + + mock_repository.update_agent = tracking_update_agent + + # Sync should update because version changed + session_manager.sync_agent(agent) + assert 
len(update_agent_calls) == 1 + + # Second sync without changes should skip + update_agent_calls.clear() + session_manager.sync_agent(agent) + assert len(update_agent_calls) == 0 + + +def test_sync_agent_retries_on_failure(mock_repository): + """Test that sync_agent() retries on next call if update_agent() fails.""" + session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + # Create and initialize agent + agent = Agent(agent_id="test-agent", session_manager=session_manager) + + # First sync to establish baseline + session_manager.sync_agent(agent) + + # Modify state (increments version) + agent.state.set("key", "value") + + # Make update_agent fail + def failing_update_agent(session_id, session_agent): + raise SessionException("Update failed") + + mock_repository.update_agent = failing_update_agent + + # Sync should fail + with pytest.raises(SessionException, match="Update failed"): + session_manager.sync_agent(agent) + + # Restore working update_agent + update_agent_calls = [] + original_update_agent = MockedSessionRepository.update_agent + + def tracking_update_agent(self, session_id, session_agent): + update_agent_calls.append((session_id, session_agent)) + return original_update_agent(self, session_id, session_agent) + + mock_repository.update_agent = lambda sid, sa: tracking_update_agent(mock_repository, sid, sa) + + # Retry should work because version wasn't updated on failure + session_manager.sync_agent(agent) + assert len(update_agent_calls) == 1 + + +def test_sync_agent_first_sync_always_updates(mock_repository): + """Test that the first sync_agent() call always updates (no previous state to compare).""" + session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + # Create and initialize agent + agent = Agent(agent_id="test-agent", session_manager=session_manager) + + # Track update_agent calls + update_agent_calls = [] + original_update_agent = 
mock_repository.update_agent + + def tracking_update_agent(session_id, session_agent): + update_agent_calls.append((session_id, session_agent)) + return original_update_agent(session_id, session_agent) + + mock_repository.update_agent = tracking_update_agent + + # First sync should always update (no previous state) + session_manager.sync_agent(agent) + assert len(update_agent_calls) == 1 diff --git a/tests/strands/test_interrupt.py b/tests/strands/test_interrupt.py index 9c14cc63b..5c928cc81 100644 --- a/tests/strands/test_interrupt.py +++ b/tests/strands/test_interrupt.py @@ -127,3 +127,67 @@ def test_interrupt_resume_invalid_id(): exp_message = r"interrupt_id= \| no interrupt found" with pytest.raises(KeyError, match=exp_message): interrupt_state.resume([{"interruptResponse": {"interruptId": "invalid", "response": None}}]) + + +# ============================================================================ +# Version Tracking Tests +# ============================================================================ + + +def test_interrupt_state_version_is_zero_after_initialization(): + """Test that _get_version() returns 0 after initialization.""" + interrupt_state = _InterruptState() + assert interrupt_state._get_version() == 0 + + +def test_interrupt_state_version_increments_after_activate(): + """Test that _get_version() increments after activate() is called.""" + interrupt_state = _InterruptState() + assert interrupt_state._get_version() == 0 + + interrupt_state.activate() + assert interrupt_state._get_version() == 1 + + +def test_interrupt_state_version_increments_after_deactivate(): + """Test that _get_version() increments after deactivate() is called.""" + interrupt_state = _InterruptState(activated=True) + initial_version = interrupt_state._get_version() + + interrupt_state.deactivate() + assert interrupt_state._get_version() == initial_version + 1 + + +def test_interrupt_state_version_increments_after_resume(): + """Test that _get_version() increments after 
resume() is called.""" + interrupt_state = _InterruptState( + interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, + activated=True, + ) + initial_version = interrupt_state._get_version() + + prompt = [{"interruptResponse": {"interruptId": "test_id", "response": "test response"}}] + interrupt_state.resume(prompt) + assert interrupt_state._get_version() == initial_version + 1 + + +def test_interrupt_state_version_increments_independently(): + """Test that version increments independently for each operation.""" + interrupt_state = _InterruptState() + assert interrupt_state._get_version() == 0 + + interrupt_state.activate() + assert interrupt_state._get_version() == 1 + + interrupt_state.deactivate() + assert interrupt_state._get_version() == 2 + + +def test_interrupt_state_version_not_in_to_dict(): + """Test that _version is not included in to_dict() output.""" + interrupt_state = _InterruptState() + interrupt_state.activate() + + data = interrupt_state.to_dict() + assert "_version" not in data + assert "version" not in data diff --git a/tests/strands/types/test_json_dict.py b/tests/strands/types/test_json_dict.py index caa010bac..ad4f4660d 100644 --- a/tests/strands/types/test_json_dict.py +++ b/tests/strands/types/test_json_dict.py @@ -109,3 +109,68 @@ def test_initial_state(): assert state.get("key1") == "value1" assert state.get("key2") == "value2" assert state.get() == initial + + +# ============================================================================ +# Version Tracking Tests +# ============================================================================ + + +def test_version_is_zero_after_initialization(): + """Test that _get_version() returns 0 after initialization.""" + state = JSONSerializableDict() + assert state._get_version() == 0 + + +def test_version_is_zero_after_initialization_with_initial_state(): + """Test that _get_version() returns 0 when initialized with initial_state.""" + state = 
JSONSerializableDict(initial_state={"key": "value"}) + assert state._get_version() == 0 + + +def test_version_increments_after_set(): + """Test that _get_version() increments after set() is called.""" + state = JSONSerializableDict() + assert state._get_version() == 0 + + state.set("key", "value") + assert state._get_version() == 1 + + state.set("key2", "value2") + assert state._get_version() == 2 + + +def test_version_increments_after_delete(): + """Test that _get_version() increments after delete() is called.""" + state = JSONSerializableDict(initial_state={"key": "value"}) + assert state._get_version() == 0 + + state.delete("key") + assert state._get_version() == 1 + + +def test_version_increments_after_delete_nonexistent_key(): + """Test that _get_version() increments after delete() on nonexistent key.""" + state = JSONSerializableDict() + assert state._get_version() == 0 + + state.delete("nonexistent") + assert state._get_version() == 1 + + +def test_version_increments_independently(): + """Test that version increments independently for each operation.""" + state = JSONSerializableDict() + initial_version = state._get_version() + + state.set("key1", "value1") + version_after_first_set = state._get_version() + assert version_after_first_set == initial_version + 1 + + state.set("key2", "value2") + version_after_second_set = state._get_version() + assert version_after_second_set == version_after_first_set + 1 + + state.delete("key1") + version_after_delete = state._get_version() + assert version_after_delete == version_after_second_set + 1 From 98636aeeedb3ecf3953d9e7300c8cf92f011e717 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 9 Mar 2026 13:32:14 -0400 Subject: [PATCH 170/279] feat: add public tool_spec setter (#1822) --- src/strands/tools/decorator.py | 25 ++ src/strands/tools/tools.py | 25 ++ tests/strands/tools/test_tool_spec_setter.py | 253 +++++++++++++++++++ 3 files changed, 303 insertions(+) create mode 100644 
tests/strands/tools/test_tool_spec_setter.py diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 0f91349d2..9207df9b8 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -543,6 +543,31 @@ def tool_spec(self) -> ToolSpec: """ return self._tool_spec + @tool_spec.setter + def tool_spec(self, value: ToolSpec) -> None: + """Set the tool specification. + + This allows runtime modification of the tool's schema, enabling dynamic + tool configurations based on feature flags or other runtime conditions. + + Args: + value: The new tool specification. + + Raises: + ValueError: If the spec fails structural validation (wrong name or + missing required field). + """ + if value.get("name") != self._tool_name: + raise ValueError( + f"cannot change tool name via tool_spec (expected '{self._tool_name}', got '{value.get('name')}')" + ) + + for field in ("description", "inputSchema"): + if field not in value: + raise ValueError(f"tool_spec must contain '{field}'") + + self._tool_spec = value + @property def tool_type(self) -> str: """Get the type of the tool. diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index 39e2f3723..ccfeac323 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -197,6 +197,31 @@ def tool_spec(self) -> ToolSpec: """ return self._tool_spec + @tool_spec.setter + def tool_spec(self, value: ToolSpec) -> None: + """Set the tool specification. + + This allows runtime modification of the tool's schema, enabling dynamic + tool configurations based on feature flags or other runtime conditions. + + Args: + value: The new tool specification. + + Raises: + ValueError: If the spec fails structural validation (wrong name or + missing required field). 
+ """ + if value.get("name") != self._tool_name: + raise ValueError( + f"cannot change tool name via tool_spec (expected '{self._tool_name}', got '{value.get('name')}')" + ) + + for field in ("description", "inputSchema"): + if field not in value: + raise ValueError(f"tool_spec must contain '{field}'") + + self._tool_spec = value + @property def supports_hot_reload(self) -> bool: """Check if this tool supports automatic reloading when modified. diff --git a/tests/strands/tools/test_tool_spec_setter.py b/tests/strands/tools/test_tool_spec_setter.py new file mode 100644 index 000000000..842146c72 --- /dev/null +++ b/tests/strands/tools/test_tool_spec_setter.py @@ -0,0 +1,253 @@ +"""Tests for tool_spec setter on DecoratedFunctionTool and PythonAgentTool.""" + +import pytest + +from strands.tools.decorator import tool +from strands.tools.tools import PythonAgentTool +from strands.types.tools import ToolSpec + + +class TestDecoratedFunctionToolSpecSetter: + """Tests for DecoratedFunctionTool.tool_spec setter.""" + + def test_set_tool_spec_replaces_spec(self): + @tool + def my_tool(query: str) -> str: + """A test tool.""" + return query + + new_spec: ToolSpec = { + "name": "my_tool", + "description": "Updated tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "The query"}, + "limit": {"type": "integer", "description": "Max results"}, + }, + "required": ["query"], + } + }, + } + my_tool.tool_spec = new_spec + assert my_tool.tool_spec is new_spec + assert "limit" in my_tool.tool_spec["inputSchema"]["json"]["properties"] + + def test_set_tool_spec_persists_across_reads(self): + @tool + def another_tool(x: int) -> int: + """Another test tool.""" + return x + + new_spec: ToolSpec = { + "name": "another_tool", + "description": "Modified", + "inputSchema": { + "json": { + "type": "object", + "properties": {"x": {"type": "integer"}, "y": {"type": "integer"}}, + "required": ["x"], + } + }, + } + 
another_tool.tool_spec = new_spec + assert another_tool.tool_spec["description"] == "Modified" + assert another_tool.tool_spec["description"] == "Modified" + + def test_add_property_via_setter(self): + @tool + def dynamic_tool(base: str) -> str: + """A dynamic tool.""" + return base + + spec = dynamic_tool.tool_spec.copy() + spec["inputSchema"] = dynamic_tool.tool_spec["inputSchema"].copy() + spec["inputSchema"]["json"] = dynamic_tool.tool_spec["inputSchema"]["json"].copy() + spec["inputSchema"]["json"]["properties"] = dynamic_tool.tool_spec["inputSchema"]["json"]["properties"].copy() + spec["inputSchema"]["json"]["properties"]["extra"] = { + "type": "string", + "description": "Extra param", + } + dynamic_tool.tool_spec = spec + assert "extra" in dynamic_tool.tool_spec["inputSchema"]["json"]["properties"] + + def test_set_tool_spec_rejects_name_change(self): + @tool + def my_tool(query: str) -> str: + """A test tool.""" + return query + + bad_spec: ToolSpec = { + "name": "wrong_name", + "description": "Updated tool", + "inputSchema": {"json": {"type": "object", "properties": {}, "required": []}}, + } + with pytest.raises(ValueError, match="cannot change tool name via tool_spec"): + my_tool.tool_spec = bad_spec + + def test_set_tool_spec_rejects_missing_description(self): + @tool + def my_tool(query: str) -> str: + """A test tool.""" + return query + + bad_spec: ToolSpec = { + "name": "my_tool", + "inputSchema": {"json": {"type": "object", "properties": {}, "required": []}}, + } + with pytest.raises(ValueError, match="tool_spec must contain 'description'"): + my_tool.tool_spec = bad_spec + + def test_set_tool_spec_rejects_missing_input_schema(self): + @tool + def my_tool(query: str) -> str: + """A test tool.""" + return query + + bad_spec: ToolSpec = { + "name": "my_tool", + "description": "Updated tool", + } + with pytest.raises(ValueError, match="tool_spec must contain 'inputSchema'"): + my_tool.tool_spec = bad_spec + + def 
test_set_tool_spec_accepts_bare_input_schema(self): + @tool + def my_tool(query: str) -> str: + """A test tool.""" + return query + + bare_spec: ToolSpec = { + "name": "my_tool", + "description": "Bare schema", + "inputSchema": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}, + } + my_tool.tool_spec = bare_spec + assert my_tool.tool_spec is bare_spec + + def test_set_tool_spec_accepts_valid_spec(self): + @tool + def my_tool(query: str) -> str: + """A test tool.""" + return query + + valid_spec: ToolSpec = { + "name": "my_tool", + "description": "A valid updated spec", + "inputSchema": { + "json": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + } + }, + } + my_tool.tool_spec = valid_spec + assert my_tool.tool_spec is valid_spec + + +class TestPythonAgentToolSpecSetter: + """Tests for PythonAgentTool.tool_spec setter.""" + + def _make_tool(self) -> PythonAgentTool: + def func(tool_use, **kwargs): + return {"status": "success", "content": [{"text": "ok"}], "toolUseId": tool_use["toolUseId"]} + + spec: ToolSpec = { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"input": {"type": "string"}}, + "required": ["input"], + } + }, + } + return PythonAgentTool("test_tool", spec, func) + + def test_set_tool_spec(self): + t = self._make_tool() + new_spec: ToolSpec = { + "name": "test_tool", + "description": "Updated", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": {"type": "string"}, + "extra": {"type": "integer"}, + }, + "required": ["input"], + } + }, + } + t.tool_spec = new_spec + assert t.tool_spec is new_spec + assert "extra" in t.tool_spec["inputSchema"]["json"]["properties"] + + def test_set_tool_spec_persists(self): + t = self._make_tool() + new_spec: ToolSpec = { + "name": "test_tool", + "description": "Persisted", + "inputSchema": {"json": {"type": "object", "properties": {}, 
"required": []}}, + } + t.tool_spec = new_spec + assert t.tool_spec["description"] == "Persisted" + assert t.tool_spec["description"] == "Persisted" + + def test_set_tool_spec_rejects_name_change(self): + t = self._make_tool() + bad_spec: ToolSpec = { + "name": "wrong_name", + "description": "Updated", + "inputSchema": {"json": {"type": "object", "properties": {}, "required": []}}, + } + with pytest.raises(ValueError, match="cannot change tool name via tool_spec"): + t.tool_spec = bad_spec + + def test_set_tool_spec_rejects_missing_description(self): + t = self._make_tool() + bad_spec: ToolSpec = { + "name": "test_tool", + "inputSchema": {"json": {"type": "object", "properties": {}, "required": []}}, + } + with pytest.raises(ValueError, match="tool_spec must contain 'description'"): + t.tool_spec = bad_spec + + def test_set_tool_spec_rejects_missing_input_schema(self): + t = self._make_tool() + bad_spec: ToolSpec = { + "name": "test_tool", + "description": "Updated", + } + with pytest.raises(ValueError, match="tool_spec must contain 'inputSchema'"): + t.tool_spec = bad_spec + + def test_set_tool_spec_accepts_bare_input_schema(self): + t = self._make_tool() + bare_spec: ToolSpec = { + "name": "test_tool", + "description": "Bare schema", + "inputSchema": {"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]}, + } + t.tool_spec = bare_spec + assert t.tool_spec is bare_spec + + def test_set_tool_spec_accepts_valid_spec(self): + t = self._make_tool() + valid_spec: ToolSpec = { + "name": "test_tool", + "description": "A valid updated spec", + "inputSchema": { + "json": { + "type": "object", + "properties": {"input": {"type": "string"}}, + "required": ["input"], + } + }, + } + t.tool_spec = valid_spec + assert t.tool_spec is valid_spec From 73fe9cc18ede5b47c1467990aad67c991f51d985 Mon Sep 17 00:00:00 2001 From: Jay Goyani <135654128+jgoyani1@users.noreply.github.com> Date: Mon, 9 Mar 2026 11:00:53 -0700 Subject: [PATCH 171/279] feat: add 
CancellationToken for graceful agent execution cancellation (#1772) --- src/strands/agent/agent.py | 37 +++ src/strands/event_loop/event_loop.py | 42 +++ src/strands/event_loop/streaming.py | 23 +- .../session/repository_session_manager.py | 4 +- src/strands/types/event_loop.py | 2 + .../strands/agent/test_agent_cancellation.py | 289 ++++++++++++++++++ tests/strands/event_loop/test_event_loop.py | 3 + .../test_event_loop_structured_output.py | 2 + tests_integ/test_cancellation.py | 156 ++++++++++ 9 files changed, 554 insertions(+), 4 deletions(-) create mode 100644 tests/strands/agent/test_agent_cancellation.py create mode 100644 tests_integ/test_cancellation.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index ebead3b7d..8f4167d9b 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -240,6 +240,9 @@ def __init__( self.record_direct_tool_call = record_direct_tool_call self.load_tools_from_directory = load_tools_from_directory + # Create internal cancel signal for graceful cancellation using threading.Event + self._cancel_signal = threading.Event() + self.tool_registry = ToolRegistry() # Process tool list if provided @@ -327,6 +330,37 @@ def __init__( self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) + def cancel(self) -> None: + """Cancel the currently running agent invocation. + + This method is thread-safe and can be called from any context + (e.g., another thread, web request handler, background task). + + The agent will stop gracefully at the next checkpoint: + - During model response streaming + - Before tool execution + + The agent will return a result with stop_reason="cancelled". 
+ + Example: + ```python + agent = Agent(model=model) + + # Start agent in background + task = asyncio.create_task(agent.invoke_async("Hello")) + + # Cancel from another context + agent.cancel() + + result = await task + assert result.stop_reason == "cancelled" + ``` + + Note: + Multiple calls to cancel() are safe and idempotent. + """ + self._cancel_signal.set() + @property def system_prompt(self) -> str | None: """Get the system prompt as a string for backwards compatibility. @@ -756,6 +790,9 @@ async def stream_async( raise finally: + # Clear cancel signal to allow agent reuse after cancellation + self._cancel_signal.clear() + if self._invocation_lock.locked(): self._invocation_lock.release() diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 3113ddb79..3b1e2d76a 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -336,6 +336,7 @@ async def _handle_model_execution( system_prompt_content=agent._system_prompt_content, tool_choice=structured_output_context.tool_choice, invocation_state=invocation_state, + cancel_signal=agent._cancel_signal, ): yield event @@ -465,6 +466,47 @@ async def _handle_tool_execution( tool_uses = [tool_use for tool_use in tool_uses if tool_use["toolUseId"] not in tool_use_ids] interrupts = [] + + # Check for cancellation before tool execution + # Add tool_result for each tool_use to maintain valid conversation state + if agent._cancel_signal.is_set(): + logger.debug("tool_count=<%d> | cancellation detected before tool execution", len(tool_uses)) + + # Create cancellation tool_result for each tool_use to avoid invalid message state + # (tool_use without tool_result would be rejected on next invocation) + for tool_use in tool_uses: + cancel_result: ToolResult = { + "toolUseId": str(tool_use.get("toolUseId")), + "status": "error", + "content": [{"text": "Tool execution cancelled"}], + } + tool_results.append(cancel_result) + + # Add tool results message to 
conversation if any tools were cancelled + cancelled_tool_result_message: Message | None = None + if tool_results: + _cancelled_msg: Message = { + "role": "user", + "content": [{"toolResult": result} for result in tool_results], + } + cancelled_tool_result_message = _cancelled_msg + agent.messages.append(_cancelled_msg) + await agent.hooks.invoke_callbacks_async(MessageAddedEvent(agent=agent, message=_cancelled_msg)) + yield ToolResultMessageEvent(message=_cancelled_msg) + + agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) + yield EventLoopStopEvent( + "cancelled", + message, + agent.event_loop_metrics, + invocation_state["request_state"], + ) + if cycle_span: + tracer.end_event_loop_cycle_span( + span=cycle_span, message=message, tool_result_message=cancelled_tool_result_message + ) + return + tool_events = agent.tool_executor._execute( agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context ) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index b157f740e..b7d85ca30 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -2,6 +2,7 @@ import json import logging +import threading import time import warnings from collections.abc import AsyncGenerator, AsyncIterable @@ -368,13 +369,16 @@ def extract_usage_metrics(event: MetadataEvent, time_to_first_byte_ms: int | Non async def process_stream( - chunks: AsyncIterable[StreamEvent], start_time: float | None = None + chunks: AsyncIterable[StreamEvent], + start_time: float | None = None, + cancel_signal: threading.Event | None = None, ) -> AsyncGenerator[TypedEvent, None]: """Processes the response stream from the API, constructing the final message and extracting usage metrics. Args: chunks: The chunks of the response stream from the model. start_time: Time when the model request is initiated + cancel_signal: Optional threading.Event to check for cancellation during streaming. 
Yields: The reason for stopping, the constructed message, and the usage metrics. @@ -395,6 +399,19 @@ async def process_stream( metrics: Metrics = Metrics(latencyMs=0, timeToFirstByteMs=0) async for chunk in chunks: + # Check for cancellation during stream processing + if cancel_signal and cancel_signal.is_set(): + logger.debug("cancellation detected during stream processing") + # Return cancelled stop reason with cancellation message + # The incomplete message in state["message"] is discarded and never added to agent.messages + yield ModelStopReason( + stop_reason="cancelled", + message={"role": "assistant", "content": [{"text": "Cancelled by user"}]}, + usage=usage, + metrics=metrics, + ) + return + # Track first byte time when we get first content if first_byte_time is None and ("contentBlockDelta" in chunk or "contentBlockStart" in chunk): first_byte_time = time.time() @@ -431,6 +448,7 @@ async def stream_messages( tool_choice: Any | None = None, system_prompt_content: list[SystemContentBlock] | None = None, invocation_state: dict[str, Any] | None = None, + cancel_signal: threading.Event | None = None, **kwargs: Any, ) -> AsyncGenerator[TypedEvent, None]: """Streams messages to the model and processes the response. @@ -444,6 +462,7 @@ async def stream_messages( system_prompt_content: The authoritative system prompt content blocks that always contains the system prompt data. invocation_state: Caller-provided state/context that was passed to the agent when it was invoked. + cancel_signal: Optional threading.Event to check for cancellation during streaming. **kwargs: Additional keyword arguments for future extensibility. 
Yields: @@ -463,5 +482,5 @@ async def stream_messages( invocation_state=invocation_state, ) - async for event in process_stream(chunks, start_time): + async for event in process_stream(chunks, start_time, cancel_signal): yield event diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index dd3562289..0d49c847d 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -124,8 +124,8 @@ def sync_agent(self, agent: "Agent", **kwargs: Any) -> None: else: state_changed = current_state_version != last_synced.get("state_version") internal_state_changed = current_interrupt_state_version != last_synced.get("interrupt_state_version") - conversation_manager_state_changed = ( - current_conversation_manager_state != last_synced.get("conversation_manager_state") + conversation_manager_state_changed = current_conversation_manager_state != last_synced.get( + "conversation_manager_state" ) if not state_changed and not internal_state_changed and not conversation_manager_state_changed: diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index 2a7ad344e..fca141327 100644 --- a/src/strands/types/event_loop.py +++ b/src/strands/types/event_loop.py @@ -37,6 +37,7 @@ class Metrics(TypedDict, total=False): StopReason = Literal[ + "cancelled", "content_filtered", "end_turn", "guardrail_intervened", @@ -47,6 +48,7 @@ class Metrics(TypedDict, total=False): ] """Reason for the model ending its response generation. 
+- "cancelled": Agent execution was cancelled via agent.cancel() - "content_filtered": Content was filtered due to policy violation - "end_turn": Normal completion of the response - "guardrail_intervened": Guardrail system intervened diff --git a/tests/strands/agent/test_agent_cancellation.py b/tests/strands/agent/test_agent_cancellation.py new file mode 100644 index 000000000..6af153f4a --- /dev/null +++ b/tests/strands/agent/test_agent_cancellation.py @@ -0,0 +1,289 @@ +"""Tests for agent cancellation functionality using agent.cancel() API.""" + +import asyncio +import threading + +import pytest + +from strands import Agent, tool +from strands.hooks import AfterModelCallEvent +from tests.fixtures.mocked_model_provider import MockedModelProvider + +# Default agent response for simple tests +DEFAULT_RESPONSE = { + "role": "assistant", + "content": [{"text": "Hello! How can I help you?"}], +} + + +@pytest.mark.asyncio +async def test_agent_cancel_before_invocation(): + """Test agent.cancel() before invocation starts. + + Verifies that calling cancel() before invoke_async() results in + immediate cancellation without any model calls. + """ + agent = Agent(model=MockedModelProvider([DEFAULT_RESPONSE])) + + # Cancel before invocation + agent.cancel() + + result = await agent.invoke_async("Hello") + + assert result.stop_reason == "cancelled" + assert result.message == {"role": "assistant", "content": [{"text": "Cancelled by user"}]} + + +@pytest.mark.asyncio +async def test_agent_cancel_during_execution(): + """Test agent.cancel() during execution. + + Verifies that calling cancel() while the agent is running + stops execution at the next checkpoint. 
+ """ + streaming_started = asyncio.Event() + cancel_ready = asyncio.Event() + + class DelayedModelProvider(MockedModelProvider): + async def stream(self, *args, **kwargs): + streaming_started.set() + # Block until cancel has been called + await cancel_ready.wait() + async for event in super().stream(*args, **kwargs): + yield event + + agent = Agent(model=DelayedModelProvider([DEFAULT_RESPONSE])) + + async def cancel_when_ready(): + await streaming_started.wait() + agent.cancel() + cancel_ready.set() + + cancel_task = asyncio.create_task(cancel_when_ready()) + result = await agent.invoke_async("Hello") + await cancel_task + + assert result.stop_reason == "cancelled" + + +@pytest.mark.asyncio +async def test_agent_cancel_with_tools(): + """Test agent.cancel() during tool execution. + + Verifies that cancellation works correctly when tools are being executed. + Uses AfterModelCallEvent hook to cancel deterministically after model returns tool_use. + """ + tool_executed = [] + + @tool + def slow_tool(x: int) -> int: + """A tool for testing.""" + tool_executed.append(x) + return x * 2 + + tool_use_response = { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool_1", + "name": "slow_tool", + "input": {"x": 5}, + } + } + ], + } + + agent = Agent( + model=MockedModelProvider([tool_use_response, DEFAULT_RESPONSE]), + tools=[slow_tool], + ) + + # Cancel deterministically after model returns tool_use + async def cancel_after_model(event: AfterModelCallEvent): + if event.stop_response and event.stop_response.stop_reason == "tool_use": + agent.cancel() + + agent.add_hook(cancel_after_model, AfterModelCallEvent) + + result = await agent.invoke_async("Use the tool") + + assert result.stop_reason == "cancelled" + + +@pytest.mark.asyncio +async def test_agent_cancel_idempotent(): + """Test that calling cancel() multiple times is safe. + + Verifies that multiple cancel() calls are idempotent and don't + cause any issues. 
+ """ + agent = Agent(model=MockedModelProvider([DEFAULT_RESPONSE])) + + # Cancel multiple times + agent.cancel() + agent.cancel() + agent.cancel() + + result = await agent.invoke_async("Hello") + + assert result.stop_reason == "cancelled" + + +@pytest.mark.asyncio +async def test_agent_cancel_from_thread(): + """Test agent.cancel() from another thread. + + Verifies thread-safety of the cancel() method when called + from a background thread. + """ + streaming_started = asyncio.Event() + cancel_ready = asyncio.Event() + loop = asyncio.get_running_loop() + + class DelayedModelProvider(MockedModelProvider): + async def stream(self, *args, **kwargs): + streaming_started.set() + await cancel_ready.wait() + async for event in super().stream(*args, **kwargs): + yield event + + agent = Agent(model=DelayedModelProvider([DEFAULT_RESPONSE])) + + def cancel_from_thread(): + # Wait for streaming to start before cancelling + asyncio.run_coroutine_threadsafe(streaming_started.wait(), loop).result() + agent.cancel() + loop.call_soon_threadsafe(cancel_ready.set) + + thread = threading.Thread(target=cancel_from_thread) + thread.start() + + result = await agent.invoke_async("Hello") + thread.join() + + assert result.stop_reason == "cancelled" + + +@pytest.mark.asyncio +async def test_agent_cancel_streaming(): + """Test cancellation during streaming response. + + Verifies that cancellation works correctly when using + the streaming API (stream_async). 
+ """ + chunks_yielded = asyncio.Event() + cancel_done = asyncio.Event() + + class SlowStreamingModelProvider(MockedModelProvider): + async def stream(self, *args, **kwargs): + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + + for i in range(10): + yield {"contentBlockDelta": {"delta": {"text": f"chunk {i} "}}} + if i == 2: + # Signal after a few chunks so cancel can fire + chunks_yielded.set() + # Wait for cancel to complete before continuing + await cancel_done.wait() + + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + agent = Agent(model=SlowStreamingModelProvider([DEFAULT_RESPONSE])) + + async def cancel_after_chunks(): + await chunks_yielded.wait() + agent.cancel() + cancel_done.set() + + cancel_task = asyncio.create_task(cancel_after_chunks()) + + events = [] + async for event in agent.stream_async("Hello"): + events.append(event) + if event.get("result"): + break + + await cancel_task + + result_event = next((e for e in events if e.get("result")), None) + assert result_event is not None + assert result_event["result"].stop_reason == "cancelled" + + +@pytest.mark.asyncio +async def test_agent_cancel_before_tool_execution_adds_tool_results(): + """Test that cancelling before tool execution adds tool_result messages. + + Verifies that when cancellation occurs after model returns tool_use but before + tools execute, proper tool_result messages are added to maintain valid conversation state. + This prevents the "tool_use without tool_result" error on next invocation. 
+ """ + + @tool + def calculator(x: int, y: int) -> int: + """Add two numbers.""" + return x + y + + tool_use_response = { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool_1", + "name": "calculator", + "input": {"x": 5, "y": 3}, + } + } + ], + } + + agent = Agent( + model=MockedModelProvider([tool_use_response, DEFAULT_RESPONSE]), + tools=[calculator], + ) + + async def cancel_after_model(event: AfterModelCallEvent): + if event.stop_response and event.stop_response.stop_reason == "tool_use": + agent.cancel() + + agent.add_hook(cancel_after_model, AfterModelCallEvent) + + result = await agent.invoke_async("Calculate 5 + 3") + + assert result.stop_reason == "cancelled" + + # Should have: user message, assistant message with tool_use, user message with tool_result + assert len(agent.messages) == 3 + assert agent.messages[0]["role"] == "user" + assert agent.messages[1]["role"] == "assistant" + assert agent.messages[2]["role"] == "user" + + tool_result_content = agent.messages[2]["content"] + assert len(tool_result_content) == 1 + assert "toolResult" in tool_result_content[0] + + tool_result = tool_result_content[0]["toolResult"] + assert tool_result["toolUseId"] == "tool_1" + assert tool_result["status"] == "error" + assert "cancelled" in tool_result["content"][0]["text"].lower() + + +@pytest.mark.asyncio +async def test_agent_cancel_continue_after(): + """Test that agent is reusable after cancellation. + + Verifies that the cancel signal is cleared after an invocation completes, + allowing subsequent invocations to run normally. 
+ """ + agent = Agent(model=MockedModelProvider([DEFAULT_RESPONSE, DEFAULT_RESPONSE])) + + agent.cancel() + result1 = await agent.invoke_async("Hello") + assert result1.stop_reason == "cancelled" + + # Second invocation should work normally + result2 = await agent.invoke_async("Hello again") + assert result2.stop_reason == "end_turn" diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 8c6155e20..0cabeaeee 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -1,5 +1,6 @@ import asyncio import concurrent +import threading import unittest.mock from unittest.mock import ANY, AsyncMock, MagicMock, call, patch @@ -150,6 +151,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock.hooks = hook_registry mock.tool_executor = tool_executor mock._interrupt_state = _InterruptState() + mock._cancel_signal = threading.Event() mock.trace_attributes = {} mock.retry_strategy = ModelRetryStrategy() @@ -756,6 +758,7 @@ async def test_request_state_initialization(alist): mock_agent = MagicMock() # not setting this to False results in endless recursion mock_agent._interrupt_state.activated = False + mock_agent._cancel_signal = threading.Event() mock_agent.event_loop_metrics.start_cycle.return_value = (0, MagicMock()) mock_agent.hooks.invoke_callbacks_async = AsyncMock() diff --git a/tests/strands/event_loop/test_event_loop_structured_output.py b/tests/strands/event_loop/test_event_loop_structured_output.py index 6f75d6083..ad792f52c 100644 --- a/tests/strands/event_loop/test_event_loop_structured_output.py +++ b/tests/strands/event_loop/test_event_loop_structured_output.py @@ -1,5 +1,6 @@ """Tests for structured output integration in the event loop.""" +import threading from unittest.mock import AsyncMock, Mock, patch import pytest @@ -52,6 +53,7 @@ def mock_agent(): agent._interrupt_state = Mock() agent._interrupt_state.activated = False 
agent._interrupt_state.context = {} + agent._cancel_signal = threading.Event() return agent diff --git a/tests_integ/test_cancellation.py b/tests_integ/test_cancellation.py new file mode 100644 index 000000000..1f0b7b1c1 --- /dev/null +++ b/tests_integ/test_cancellation.py @@ -0,0 +1,156 @@ +"""Integration tests for agent cancellation with Amazon Bedrock. + +These tests verify that cancellation works correctly with the Bedrock model provider. +They require valid AWS credentials and may incur API costs. + +To run these tests: + hatch run test-integ tests_integ/test_cancellation.py +""" + +import asyncio +import os +import threading + +import pytest + +from strands import Agent, tool +from strands.hooks import AfterModelCallEvent, BeforeModelCallEvent +from strands.models import BedrockModel + +# Skip all tests if no AWS credentials are available +pytestmark = [ + pytest.mark.skipif(not os.getenv("AWS_REGION"), reason="AWS credentials not available"), + pytest.mark.asyncio, +] + + +async def test_cancel_with_bedrock(): + """Test agent.cancel() with Amazon Bedrock model. + + Verifies that cancellation works correctly with a real Bedrock + model by cancelling before the model call starts. + """ + + agent = Agent(model=BedrockModel(model_id="anthropic.claude-3-haiku-20240307-v1:0")) + + # Cancel deterministically before the model call + async def cancel_before_model(event: BeforeModelCallEvent): + agent.cancel() + + agent.add_hook(cancel_before_model, BeforeModelCallEvent) + + result = await agent.invoke_async( + "Write a detailed 1000-word essay about the history of space exploration, " + "including major milestones, key figures, and technological breakthroughs." + ) + + assert result.stop_reason == "cancelled" + assert result.message["role"] == "assistant" + assert result.message["content"] == [{"text": "Cancelled by user"}] + + +async def test_cancel_during_streaming_bedrock(): + """Test agent.cancel() during streaming with Bedrock. 
+ + Verifies that cancellation works correctly when using the + streaming API with a real Bedrock model. + """ + + agent = Agent(model=BedrockModel(model_id="anthropic.claude-3-haiku-20240307-v1:0")) + + events = [] + async for event in agent.stream_async( + "Write a detailed story about a space adventure. Make it at least 500 words long." + ): + events.append(event) + # Cancel after receiving the first model delta event + if "data" in event: + agent.cancel() + if event.get("result"): + break + + # Find the result event + result_event = next((e for e in events if e.get("result")), None) + assert result_event is not None + assert result_event["result"].stop_reason == "cancelled" + + +async def test_cancel_with_tools_bedrock(): + """Test agent.cancel() during tool execution with Bedrock. + + Verifies that cancellation works correctly when the agent + is executing tools with a real Bedrock model. + """ + + @tool + async def slow_calculation(x: int, y: int) -> int: + """Perform a slow calculation that takes time. + + Args: + x: First number + y: Second number + + Returns: + The sum of x and y + """ + await asyncio.sleep(2) + return x + y + + @tool + async def another_calculation(a: int, b: int) -> int: + """Another slow calculation. + + Args: + a: First number + b: Second number + + Returns: + The product of a and b + """ + await asyncio.sleep(2) + return a * b + + agent = Agent( + model=BedrockModel(model_id="anthropic.claude-3-haiku-20240307-v1:0"), + tools=[slow_calculation, another_calculation], + ) + + # Cancel deterministically after model returns tool_use + async def cancel_after_model(event: AfterModelCallEvent): + if event.stop_response and event.stop_response.stop_reason == "tool_use": + agent.cancel() + + agent.add_hook(cancel_after_model, AfterModelCallEvent) + + result = await agent.invoke_async( + "Please use the slow_calculation tool to add 5 and 10, then use another_calculation to multiply 3 and 7." 
+ ) + + assert result.stop_reason == "cancelled" + + +async def test_cancel_from_thread_bedrock(): + """Test agent.cancel() from a different thread with Bedrock. + + Simulates a real-world scenario where cancellation is triggered + from a different thread (e.g., a web request handler) while the agent + is executing. + """ + + agent = Agent(model=BedrockModel(model_id="anthropic.claude-3-haiku-20240307-v1:0")) + + # Cancel deterministically from a different thread before the model call + def cancel_before_model(event: BeforeModelCallEvent): + thread = threading.Thread(target=agent.cancel) + thread.start() + thread.join() + + agent.add_hook(cancel_before_model, BeforeModelCallEvent) + + result = await agent.invoke_async( + "Write a comprehensive guide about machine learning, " + "covering supervised learning, unsupervised learning, and deep learning. " + "Make it at least 800 words." + ) + + assert result.stop_reason == "cancelled" From 32d703cf833e663c452d2c98881dff8a3e22af26 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Mon, 9 Mar 2026 14:58:02 -0400 Subject: [PATCH 172/279] feat(session): optimize session manager initialization (#1829) Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> --- .../session/repository_session_manager.py | 22 +- .../test_repository_session_manager.py | 288 +++++++++++++++--- 2 files changed, 273 insertions(+), 37 deletions(-) diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index 0d49c847d..b3eed6474 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -52,8 +52,11 @@ def __init__( # Create a session if it does not exist yet if session is None: logger.debug("session_id=<%s> | session not found, creating new session", self.session_id) + self._is_new_session = True session = Session(session_id=session_id, session_type=SessionType.AGENT) 
session_repository.create_session(session) + else: + self._is_new_session = False self.session = session @@ -170,7 +173,11 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: raise SessionException("The `agent_id` of an agent must be unique in a session.") self._latest_agent_message[agent.agent_id] = None - session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id) + # Skip read_agent call for new sessions since no agents can exist yet + if self._is_new_session: + session_agent = None + else: + session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id) if session_agent is None: logger.debug( @@ -299,7 +306,12 @@ def initialize_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> Non source: Multi-agent source object to restore state into **kwargs: Additional keyword arguments for future extensibility. """ - state = self.session_repository.read_multi_agent(self.session_id, source.id, **kwargs) + # Skip read_multi_agent call for new sessions since no multi-agents can exist yet + if self._is_new_session: + state = None + else: + state = self.session_repository.read_multi_agent(self.session_id, source.id, **kwargs) + if state is None: self.session_repository.create_multi_agent(self.session_id, source, **kwargs) else: @@ -317,7 +329,11 @@ def initialize_bidi_agent(self, agent: "BidiAgent", **kwargs: Any) -> None: raise SessionException("The `agent_id` of an agent must be unique in a session.") self._latest_agent_message[agent.agent_id] = None - session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id) + # Skip read_agent call for new sessions since no agents can exist yet + if self._is_new_session: + session_agent = None + else: + session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id) if session_agent is None: logger.debug( diff --git a/tests/strands/session/test_repository_session_manager.py 
b/tests/strands/session/test_repository_session_manager.py index f8f044a9b..9b2d84a51 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -28,6 +28,15 @@ def session_manager(mock_repository): return RepositorySessionManager(session_id="test-session", session_repository=mock_repository) +@pytest.fixture +def existing_session_manager(mock_repository): + """Create a session manager with a pre-existing session in the repository.""" + # Create session first so the manager sees it as existing + session = Session(session_id="test-session", session_type=SessionType.AGENT) + mock_repository.create_session(session) + return RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + @pytest.fixture def agent(): """Create a mock agent.""" @@ -100,7 +109,7 @@ def test_initialize_multiple_agents_without_id(session_manager, agent): session_manager.initialize(agent2) -def test_initialize_restores_existing_agent(session_manager, agent): +def test_initialize_restores_existing_agent(existing_session_manager, agent): """Test that initializing an existing agent restores its state.""" # Set agent ID agent.agent_id = "existing-agent" @@ -112,7 +121,7 @@ def test_initialize_restores_existing_agent(session_manager, agent): conversation_manager_state=SlidingWindowConversationManager().get_state(), _internal_state={"interrupt_state": {"interrupts": {}, "context": {"test": "init"}, "activated": False}}, ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) # Create some messages message = SessionMessage( @@ -122,10 +131,10 @@ def test_initialize_restores_existing_agent(session_manager, agent): }, message_id=0, ) - session_manager.session_repository.create_message("test-session", "existing-agent", message) + 
existing_session_manager.session_repository.create_message("test-session", "existing-agent", message) # Initialize agent - session_manager.initialize(agent) + existing_session_manager.initialize(agent) # Verify agent state restored assert agent.state.get("key") == "value" @@ -135,7 +144,7 @@ def test_initialize_restores_existing_agent(session_manager, agent): assert agent._interrupt_state == _InterruptState(interrupts={}, context={"test": "init"}, activated=False) -def test_initialize_restores_existing_agent_with_summarizing_conversation_manager(session_manager): +def test_initialize_restores_existing_agent_with_summarizing_conversation_manager(existing_session_manager): """Test that initializing an existing agent restores its state.""" conversation_manager = SummarizingConversationManager() conversation_manager.removed_message_count = 1 @@ -147,7 +156,7 @@ def test_initialize_restores_existing_agent_with_summarizing_conversation_manage state={"key": "value"}, conversation_manager_state=conversation_manager.get_state(), ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) # Create some messages message = SessionMessage( @@ -158,13 +167,13 @@ def test_initialize_restores_existing_agent_with_summarizing_conversation_manage message_id=0, ) # Create two messages as one will be removed by the conversation manager - session_manager.session_repository.create_message("test-session", "existing-agent", message) + existing_session_manager.session_repository.create_message("test-session", "existing-agent", message) message.message_id = 1 - session_manager.session_repository.create_message("test-session", "existing-agent", message) + existing_session_manager.session_repository.create_message("test-session", "existing-agent", message) # Initialize agent agent = Agent(agent_id="existing-agent", conversation_manager=SummarizingConversationManager()) - 
session_manager.initialize(agent) + existing_session_manager.initialize(agent) # Verify agent state restored assert agent.state.get("key") == "value" @@ -217,26 +226,26 @@ def test_initialize_multi_agent_new(session_manager, mock_multi_agent): assert state["state"] == {"key": "value"} -def test_initialize_multi_agent_existing(session_manager, mock_multi_agent): +def test_initialize_multi_agent_existing(existing_session_manager, mock_multi_agent): """Test initializing existing multi-agent state.""" # Create existing state first - session_manager.session_repository.create_multi_agent("test-session", mock_multi_agent) + existing_session_manager.session_repository.create_multi_agent("test-session", mock_multi_agent) # Create a mock with updated state for the update call updated_mock = Mock() updated_mock.id = "test-multi-agent" existing_state = {"id": "test-multi-agent", "state": {"restored": "data"}} updated_mock.serialize_state.return_value = existing_state - session_manager.session_repository.update_multi_agent("test-session", updated_mock) + existing_session_manager.session_repository.update_multi_agent("test-session", updated_mock) # Initialize multi-agent - session_manager.initialize_multi_agent(mock_multi_agent) + existing_session_manager.initialize_multi_agent(mock_multi_agent) # Verify deserialize_state was called with existing state mock_multi_agent.deserialize_state.assert_called_once_with(existing_state) -def test_fix_broken_tool_use_adds_missing_tool_results(session_manager): +def test_fix_broken_tool_use_adds_missing_tool_results(existing_session_manager): """Test that _fix_broken_tool_use adds missing toolResult messages.""" conversation_manager = SlidingWindowConversationManager() @@ -246,7 +255,7 @@ def test_fix_broken_tool_use_adds_missing_tool_results(session_manager): state={"key": "value"}, conversation_manager_state=conversation_manager.get_state(), ) - session_manager.session_repository.create_agent("test-session", session_agent) + 
existing_session_manager.session_repository.create_agent("test-session", session_agent) broken_messages = [ { @@ -261,11 +270,13 @@ def test_fix_broken_tool_use_adds_missing_tool_results(session_manager): message=broken_message, message_id=index, ) - session_manager.session_repository.create_message("test-session", "existing-agent", broken_session_message) + existing_session_manager.session_repository.create_message( + "test-session", "existing-agent", broken_session_message + ) # Initialize agent agent = Agent(agent_id="existing-agent") - session_manager.initialize(agent) + existing_session_manager.initialize(agent) fixed_messages = agent.messages @@ -277,7 +288,7 @@ def test_fix_broken_tool_use_adds_missing_tool_results(session_manager): assert fixed_messages[1]["content"][0]["toolResult"]["content"][0]["text"] == "Tool was interrupted." -def test_fix_broken_tool_use_extends_partial_tool_results(session_manager): +def test_fix_broken_tool_use_extends_partial_tool_results(existing_session_manager): """Test fixing messages where some toolResults are missing.""" conversation_manager = SlidingWindowConversationManager() # Create agent in repository first @@ -286,7 +297,7 @@ def test_fix_broken_tool_use_extends_partial_tool_results(session_manager): state={"key": "value"}, conversation_manager_state=conversation_manager.get_state(), ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) broken_messages = [ { @@ -309,11 +320,13 @@ def test_fix_broken_tool_use_extends_partial_tool_results(session_manager): message=broken_message, message_id=index, ) - session_manager.session_repository.create_message("test-session", "existing-agent", broken_session_message) + existing_session_manager.session_repository.create_message( + "test-session", "existing-agent", broken_session_message + ) # Initialize agent agent = Agent(agent_id="existing-agent") - 
session_manager.initialize(agent) + existing_session_manager.initialize(agent) fixed_messages = agent.messages @@ -330,7 +343,7 @@ def test_fix_broken_tool_use_extends_partial_tool_results(session_manager): assert missing_result["toolResult"]["content"][0]["text"] == "Tool was interrupted." -def test_fix_broken_tool_use_handles_multiple_orphaned_tools(session_manager): +def test_fix_broken_tool_use_handles_multiple_orphaned_tools(existing_session_manager): """Test fixing multiple orphaned toolUse messages.""" conversation_manager = SlidingWindowConversationManager() @@ -340,7 +353,7 @@ def test_fix_broken_tool_use_handles_multiple_orphaned_tools(session_manager): state={"key": "value"}, conversation_manager_state=conversation_manager.get_state(), ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) broken_messages = [ { @@ -358,11 +371,13 @@ def test_fix_broken_tool_use_handles_multiple_orphaned_tools(session_manager): message=broken_message, message_id=index, ) - session_manager.session_repository.create_message("test-session", "existing-agent", broken_session_message) + existing_session_manager.session_repository.create_message( + "test-session", "existing-agent", broken_session_message + ) # Initialize agent agent = Agent(agent_id="existing-agent") - session_manager.initialize(agent) + existing_session_manager.initialize(agent) fixed_messages = agent.messages @@ -449,7 +464,7 @@ def test_initialize_bidi_agent_creates_new(session_manager, mock_bidi_agent): assert messages[0].message["role"] == "user" -def test_initialize_bidi_agent_restores_existing(session_manager, mock_bidi_agent): +def test_initialize_bidi_agent_restores_existing(existing_session_manager, mock_bidi_agent): """Test initializing BidiAgent restores from existing session.""" # Create existing session data session_agent = SessionAgent( @@ -457,16 +472,16 @@ def 
test_initialize_bidi_agent_restores_existing(session_manager, mock_bidi_agen state={"restored": "state"}, conversation_manager_state={}, # Empty for BidiAgent ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) # Add messages msg1 = SessionMessage.from_message({"role": "user", "content": [{"text": "Message 1"}]}, 0) msg2 = SessionMessage.from_message({"role": "assistant", "content": [{"text": "Response 1"}]}, 1) - session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg1) - session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg2) + existing_session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg1) + existing_session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg2) # Initialize agent - session_manager.initialize_bidi_agent(mock_bidi_agent) + existing_session_manager.initialize_bidi_agent(mock_bidi_agent) # Verify state restored assert mock_bidi_agent.state.get() == {"restored": "state"} @@ -532,7 +547,7 @@ def test_bidi_agent_unique_id_constraint(session_manager, mock_bidi_agent): session_manager.initialize_bidi_agent(agent2) -def test_bidi_agent_messages_with_offset_zero(session_manager, mock_bidi_agent): +def test_bidi_agent_messages_with_offset_zero(existing_session_manager, mock_bidi_agent): """Test that BidiAgent uses offset=0 for message restoration (no conversation_manager).""" # Create session with messages session_agent = SessionAgent( @@ -540,15 +555,15 @@ def test_bidi_agent_messages_with_offset_zero(session_manager, mock_bidi_agent): state={}, conversation_manager_state={}, ) - session_manager.session_repository.create_agent("test-session", session_agent) + existing_session_manager.session_repository.create_agent("test-session", session_agent) # Add 5 messages for i in range(5): msg = 
SessionMessage.from_message({"role": "user", "content": [{"text": f"Message {i}"}]}, i) - session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg) + existing_session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg) # Initialize agent - session_manager.initialize_bidi_agent(mock_bidi_agent) + existing_session_manager.initialize_bidi_agent(mock_bidi_agent) # Verify all messages restored (offset=0, no removed_message_count) assert len(mock_bidi_agent.messages) == 5 @@ -811,3 +826,208 @@ def tracking_update_agent(session_id, session_agent): # First sync should always update (no previous state) session_manager.sync_agent(agent) assert len(update_agent_calls) == 1 + + +# ============================================================================ +# New Session Optimization Tests (Issue #1828) +# ============================================================================ + + +def test_is_new_session_true_when_session_created(mock_repository): + """Test that _is_new_session is True when creating a new session.""" + # Session doesn't exist yet + assert mock_repository.read_session("new-session") is None + + # Creating manager should set _is_new_session to True + manager = RepositorySessionManager(session_id="new-session", session_repository=mock_repository) + + assert manager._is_new_session is True + + +def test_is_new_session_false_when_session_exists(mock_repository): + """Test that _is_new_session is False when using an existing session.""" + # Create session first + session = Session(session_id="existing-session", session_type=SessionType.AGENT) + mock_repository.create_session(session) + + # Creating manager should set _is_new_session to False + manager = RepositorySessionManager(session_id="existing-session", session_repository=mock_repository) + + assert manager._is_new_session is False + + +def test_initialize_skips_read_agent_for_new_session(mock_repository): + """Test that initialize() skips 
read_agent() call when _is_new_session is True.""" + # Create manager (new session) + manager = RepositorySessionManager(session_id="new-session", session_repository=mock_repository) + assert manager._is_new_session is True + + # Track read_agent calls + read_agent_calls = [] + original_read_agent = mock_repository.read_agent + + def tracking_read_agent(session_id, agent_id): + read_agent_calls.append((session_id, agent_id)) + return original_read_agent(session_id, agent_id) + + mock_repository.read_agent = tracking_read_agent + + # Initialize agent + agent = Agent(agent_id="test-agent") + manager.initialize(agent) + + # read_agent should NOT be called for new session + assert len(read_agent_calls) == 0 + + +def test_initialize_calls_read_agent_for_existing_session(mock_repository): + """Test that initialize() calls read_agent() when _is_new_session is False.""" + # Create session first + session = Session(session_id="existing-session", session_type=SessionType.AGENT) + mock_repository.create_session(session) + + # Create manager (existing session) + manager = RepositorySessionManager(session_id="existing-session", session_repository=mock_repository) + assert manager._is_new_session is False + + # Track read_agent calls + read_agent_calls = [] + original_read_agent = mock_repository.read_agent + + def tracking_read_agent(session_id, agent_id): + read_agent_calls.append((session_id, agent_id)) + return original_read_agent(session_id, agent_id) + + mock_repository.read_agent = tracking_read_agent + + # Initialize agent + agent = Agent(agent_id="test-agent") + manager.initialize(agent) + + # read_agent should be called for existing session + assert len(read_agent_calls) == 1 + assert read_agent_calls[0] == ("existing-session", "test-agent") + + +def test_initialize_bidi_agent_skips_read_agent_for_new_session(mock_repository): + """Test that initialize_bidi_agent() skips read_agent() call when _is_new_session is True.""" + # Create manager (new session) + manager = 
RepositorySessionManager(session_id="new-session", session_repository=mock_repository) + assert manager._is_new_session is True + + # Track read_agent calls + read_agent_calls = [] + original_read_agent = mock_repository.read_agent + + def tracking_read_agent(session_id, agent_id): + read_agent_calls.append((session_id, agent_id)) + return original_read_agent(session_id, agent_id) + + mock_repository.read_agent = tracking_read_agent + + # Create mock BidiAgent + bidi_agent = Mock() + bidi_agent.agent_id = "bidi-agent-1" + bidi_agent.messages = [{"role": "user", "content": [{"text": "Hello!"}]}] + bidi_agent.state = AgentState({}) + + # Initialize bidi agent + manager.initialize_bidi_agent(bidi_agent) + + # read_agent should NOT be called for new session + assert len(read_agent_calls) == 0 + + +def test_initialize_bidi_agent_calls_read_agent_for_existing_session(mock_repository): + """Test that initialize_bidi_agent() calls read_agent() when _is_new_session is False.""" + # Create session first + session = Session(session_id="existing-session", session_type=SessionType.AGENT) + mock_repository.create_session(session) + + # Create manager (existing session) + manager = RepositorySessionManager(session_id="existing-session", session_repository=mock_repository) + assert manager._is_new_session is False + + # Track read_agent calls + read_agent_calls = [] + original_read_agent = mock_repository.read_agent + + def tracking_read_agent(session_id, agent_id): + read_agent_calls.append((session_id, agent_id)) + return original_read_agent(session_id, agent_id) + + mock_repository.read_agent = tracking_read_agent + + # Create mock BidiAgent + bidi_agent = Mock() + bidi_agent.agent_id = "bidi-agent-1" + bidi_agent.messages = [{"role": "user", "content": [{"text": "Hello!"}]}] + bidi_agent.state = AgentState({}) + + # Initialize bidi agent + manager.initialize_bidi_agent(bidi_agent) + + # read_agent should be called for existing session + assert len(read_agent_calls) == 1 + 
assert read_agent_calls[0] == ("existing-session", "bidi-agent-1") + + +def test_initialize_multi_agent_skips_read_for_new_session(mock_repository): + """Test that initialize_multi_agent() skips read_multi_agent() call when _is_new_session is True.""" + # Create manager (new session) + manager = RepositorySessionManager(session_id="new-session", session_repository=mock_repository) + assert manager._is_new_session is True + + # Track read_multi_agent calls + read_multi_agent_calls = [] + original_read_multi_agent = mock_repository.read_multi_agent + + def tracking_read_multi_agent(session_id, multi_agent_id, **kwargs): + read_multi_agent_calls.append((session_id, multi_agent_id)) + return original_read_multi_agent(session_id, multi_agent_id, **kwargs) + + mock_repository.read_multi_agent = tracking_read_multi_agent + + # Create mock multi-agent + multi_agent = Mock() + multi_agent.id = "test-multi-agent" + multi_agent.serialize_state.return_value = {"id": "test-multi-agent", "state": {}} + + # Initialize multi-agent + manager.initialize_multi_agent(multi_agent) + + # read_multi_agent should NOT be called for new session + assert len(read_multi_agent_calls) == 0 + + +def test_initialize_multi_agent_calls_read_for_existing_session(mock_repository): + """Test that initialize_multi_agent() calls read_multi_agent() when _is_new_session is False.""" + # Create session first + session = Session(session_id="existing-session", session_type=SessionType.AGENT) + mock_repository.create_session(session) + + # Create manager (existing session) + manager = RepositorySessionManager(session_id="existing-session", session_repository=mock_repository) + assert manager._is_new_session is False + + # Track read_multi_agent calls + read_multi_agent_calls = [] + original_read_multi_agent = mock_repository.read_multi_agent + + def tracking_read_multi_agent(session_id, multi_agent_id, **kwargs): + read_multi_agent_calls.append((session_id, multi_agent_id)) + return 
original_read_multi_agent(session_id, multi_agent_id, **kwargs) + + mock_repository.read_multi_agent = tracking_read_multi_agent + + # Create mock multi-agent + multi_agent = Mock() + multi_agent.id = "test-multi-agent" + multi_agent.serialize_state.return_value = {"id": "test-multi-agent", "state": {}} + + # Initialize multi-agent + manager.initialize_multi_agent(multi_agent) + + # read_multi_agent should be called for existing session + assert len(read_multi_agent_calls) == 1 + assert read_multi_agent_calls[0] == ("existing-session", "test-multi-agent") From 3406ef4a5a500a6f28c6825b7989f41f3b979599 Mon Sep 17 00:00:00 2001 From: Jack Walker <250010855+jackatorcflo@users.noreply.github.com> Date: Mon, 9 Mar 2026 17:17:41 -0500 Subject: [PATCH 173/279] fix(mistral): report usage metrics in streaming mode (#1697) --- src/strands/models/mistral.py | 4 +-- tests/strands/models/test_mistral.py | 32 +++++++++++++++++++++--- tests_integ/models/test_model_mistral.py | 5 ++++ 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 504e81c92..f44a11d30 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -496,8 +496,8 @@ async def stream( yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason}) - if hasattr(chunk, "usage"): - yield self.format_chunk({"chunk_type": "metadata", "data": chunk.usage}) + if hasattr(chunk, "data") and hasattr(chunk.data, "usage") and chunk.data.usage: + yield self.format_chunk({"chunk_type": "metadata", "data": chunk.data.usage}) except Exception as e: if "rate" in str(e).lower() or "429" in str(e): diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py index ad74bae89..57189748e 100644 --- a/tests/strands/models/test_mistral.py +++ b/tests/strands/models/test_mistral.py @@ -451,9 +451,9 @@ async def test_stream(mistral_client, model, agenerator, alist, captured_warning 
delta=unittest.mock.Mock(content="test stream", tool_calls=None), finish_reason="end_turn", ) - ] + ], + usage=mock_usage, ), - usage=mock_usage, ) mistral_client.chat.stream_async = unittest.mock.AsyncMock(return_value=agenerator([mock_event])) @@ -476,6 +476,30 @@ async def test_stream(mistral_client, model, agenerator, alist, captured_warning assert len(captured_warnings) == 0 +@pytest.mark.asyncio +async def test_stream_no_usage(mistral_client, model, agenerator, alist): + mock_event = unittest.mock.Mock( + data=unittest.mock.Mock( + choices=[ + unittest.mock.Mock( + delta=unittest.mock.Mock(content="test stream", tool_calls=None), + finish_reason="end_turn", + ) + ], + usage=None, + ), + ) + + mistral_client.chat.stream_async = unittest.mock.AsyncMock(return_value=agenerator([mock_event])) + + messages = [{"role": "user", "content": [{"text": "test"}]}] + response = model.stream(messages, None, None) + + # Should complete without error and not yield a metadata chunk + chunks = await alist(response) + assert not any("metadata" in c for c in chunks if isinstance(c, dict)) + + @pytest.mark.asyncio async def test_tool_choice_not_supported_warns(mistral_client, model, agenerator, alist, captured_warnings): tool_choice = {"auto": {}} @@ -492,9 +516,9 @@ async def test_tool_choice_not_supported_warns(mistral_client, model, agenerator delta=unittest.mock.Mock(content="test stream", tool_calls=None), finish_reason="end_turn", ) - ] + ], + usage=mock_usage, ), - usage=mock_usage, ) mistral_client.chat.stream_async = unittest.mock.AsyncMock(return_value=agenerator([mock_event])) diff --git a/tests_integ/models/test_model_mistral.py b/tests_integ/models/test_model_mistral.py index 3b13e5911..83f6af499 100644 --- a/tests_integ/models/test_model_mistral.py +++ b/tests_integ/models/test_model_mistral.py @@ -106,6 +106,11 @@ async def test_agent_stream_async(agent): assert all(string in text for string in ["12:00", "sunny"]) + assert result.metrics.accumulated_usage is not 
None + assert result.metrics.accumulated_usage["inputTokens"] > 0 + assert result.metrics.accumulated_usage["outputTokens"] > 0 + assert result.metrics.accumulated_usage["totalTokens"] > 0 + def test_agent_structured_output(non_streaming_agent, weather): tru_weather = non_streaming_agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny") From 021344b2015aeea441bc843b3b9cb378ff5aab67 Mon Sep 17 00:00:00 2001 From: Giulio Leone Date: Mon, 9 Mar 2026 23:38:37 +0100 Subject: [PATCH 174/279] fix(openai_responses): use output_text for assistant messages in multi-turn conversations (#1851) Signed-off-by: Giulio Leone <6887247+giulio-leone@users.noreply.github.com> Co-authored-by: giulio-leone Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/strands/models/openai_responses.py | 13 +++-- tests/strands/models/test_openai_responses.py | 54 ++++++++++++++++++- 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index 96d4bee59..0ace9645f 100644 --- a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -50,7 +50,7 @@ import openai # noqa: E402 - must import after version check -from ..types.content import ContentBlock, Messages # noqa: E402 +from ..types.content import ContentBlock, Messages, Role # noqa: E402 from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException # noqa: E402 from ..types.streaming import StreamEvent # noqa: E402 from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse # noqa: E402 @@ -467,7 +467,7 @@ def _format_request_messages(cls, messages: Messages) -> list[dict[str, Any]]: contents = message["content"] formatted_contents = [ - cls._format_request_message_content(content) + cls._format_request_message_content(content, role=role) for content in contents if not any(block_type in content for block_type in ["toolResult", "toolUse"]) ] @@ 
-502,11 +502,15 @@ def _format_request_messages(cls, messages: Messages) -> list[dict[str, Any]]: ] @classmethod - def _format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + def _format_request_message_content( + cls, content: ContentBlock, *, role: Role = "user" + ) -> dict[str, Any]: """Format an OpenAI compatible content block. Args: content: Message content. + role: Message role ("user" or "assistant"). Controls text content + type: "input_text" for user, "output_text" for assistant. Returns: OpenAI compatible content block. @@ -526,7 +530,8 @@ def _format_request_message_content(cls, content: ContentBlock) -> dict[str, Any return {"type": "input_image", "image_url": data_url} if "text" in content: - return {"type": "input_text", "text": content["text"]} + text_type = "output_text" if role == "assistant" else "input_text" + return {"type": text_type, "text": content["text"]} raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py index 7b09f1b68..9c84f4ed4 100644 --- a/tests/strands/models/test_openai_responses.py +++ b/tests/strands/models/test_openai_responses.py @@ -248,7 +248,7 @@ def test_format_request_messages(system_prompt): }, { "role": "assistant", - "content": [{"type": "input_text", "text": "call tool"}], + "content": [{"type": "output_text", "text": "call tool"}], }, { "type": "function_call", @@ -265,6 +265,58 @@ def test_format_request_messages(system_prompt): assert tru_result == exp_result +def test_format_request_messages_assistant_text_uses_output_text(): + """Assistant text content must use output_text, not input_text. + + Regression test for multi-turn conversations failing because the OpenAI + Responses API rejects input_text in assistant messages. 
+ See: https://github.com/strands-agents/sdk-python/issues/1850 + """ + messages = [ + { + "content": [{"text": "Say hello"}], + "role": "user", + }, + { + "content": [{"text": "Hello!"}], + "role": "assistant", + }, + { + "content": [{"text": "Say goodbye"}], + "role": "user", + }, + ] + + result = OpenAIResponsesModel._format_request_messages(messages) + + assert result[0] == { + "role": "user", + "content": [{"type": "input_text", "text": "Say hello"}], + } + assert result[1] == { + "role": "assistant", + "content": [{"type": "output_text", "text": "Hello!"}], + } + assert result[2] == { + "role": "user", + "content": [{"type": "input_text", "text": "Say goodbye"}], + } + + +def test_format_request_message_content_role_assistant(): + """_format_request_message_content uses output_text for assistant role.""" + content = {"text": "response text"} + result = OpenAIResponsesModel._format_request_message_content(content, role="assistant") + assert result == {"type": "output_text", "text": "response text"} + + +def test_format_request_message_content_role_user(): + """_format_request_message_content uses input_text for user role (default).""" + content = {"text": "question"} + result = OpenAIResponsesModel._format_request_message_content(content, role="user") + assert result == {"type": "input_text", "text": "question"} + + def test_format_request(model, messages, tool_specs, system_prompt): tru_request = model._format_request(messages, tool_specs, system_prompt) exp_request = { From bfe9d02001178ad1d6c29d0b92bf67411f637528 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 9 Mar 2026 22:03:27 -0400 Subject: [PATCH 175/279] feat(hooks): add resume flag to AfterInvocationEvent (#1767) --- src/strands/agent/agent.py | 91 ++++--- src/strands/hooks/events.py | 15 ++ tests/strands/agent/hooks/test_events.py | 30 +++ tests/strands/agent/test_agent_hooks.py | 327 +++++++++++++++++++++++ 4 files changed, 426 insertions(+), 37 deletions(-) diff --git 
a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 8f4167d9b..f378a886a 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -814,49 +814,66 @@ async def _run_loop( Yields: Events from the event loop cycle. """ - before_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async( - BeforeInvocationEvent(agent=self, invocation_state=invocation_state, messages=messages) - ) - messages = before_invocation_event.messages if before_invocation_event.messages is not None else messages + current_messages: Messages | None = messages - agent_result: AgentResult | None = None - try: - yield InitEventLoopEvent() + while current_messages is not None: + before_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async( + BeforeInvocationEvent(agent=self, invocation_state=invocation_state, messages=current_messages) + ) + current_messages = ( + before_invocation_event.messages if before_invocation_event.messages is not None else current_messages + ) - await self._append_messages(*messages) + agent_result: AgentResult | None = None + try: + yield InitEventLoopEvent() - structured_output_context = StructuredOutputContext( - structured_output_model or self._default_structured_output_model, - structured_output_prompt=structured_output_prompt or self._structured_output_prompt, - ) + await self._append_messages(*current_messages) - # Execute the event loop cycle with retry logic for context limits - events = self._execute_event_loop_cycle(invocation_state, structured_output_context) - async for event in events: - # Signal from the model provider that the message sent by the user should be redacted, - # likely due to a guardrail. 
- if ( - isinstance(event, ModelStreamChunkEvent) - and event.chunk - and event.chunk.get("redactContent") - and event.chunk["redactContent"].get("redactUserContentMessage") - ): - self.messages[-1]["content"] = self._redact_user_content( - self.messages[-1]["content"], str(event.chunk["redactContent"]["redactUserContentMessage"]) - ) - if self._session_manager: - self._session_manager.redact_latest_message(self.messages[-1], self) - yield event + structured_output_context = StructuredOutputContext( + structured_output_model or self._default_structured_output_model, + structured_output_prompt=structured_output_prompt or self._structured_output_prompt, + ) - # Capture the result from the final event if available - if isinstance(event, EventLoopStopEvent): - agent_result = AgentResult(*event["stop"]) + # Execute the event loop cycle with retry logic for context limits + events = self._execute_event_loop_cycle(invocation_state, structured_output_context) + async for event in events: + # Signal from the model provider that the message sent by the user should be redacted, + # likely due to a guardrail. 
+ if ( + isinstance(event, ModelStreamChunkEvent) + and event.chunk + and event.chunk.get("redactContent") + and event.chunk["redactContent"].get("redactUserContentMessage") + ): + self.messages[-1]["content"] = self._redact_user_content( + self.messages[-1]["content"], + str(event.chunk["redactContent"]["redactUserContentMessage"]), + ) + if self._session_manager: + self._session_manager.redact_latest_message(self.messages[-1], self) + yield event + + # Capture the result from the final event if available + if isinstance(event, EventLoopStopEvent): + agent_result = AgentResult(*event["stop"]) - finally: - self.conversation_manager.apply_management(self) - await self.hooks.invoke_callbacks_async( - AfterInvocationEvent(agent=self, invocation_state=invocation_state, result=agent_result) - ) + finally: + self.conversation_manager.apply_management(self) + after_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async( + AfterInvocationEvent(agent=self, invocation_state=invocation_state, result=agent_result) + ) + + # Convert resume input to messages for next iteration, or None to stop + if after_invocation_event.resume is not None: + logger.debug("resume= | hook requested agent resume with new input") + # If in interrupt state, process interrupt responses before continuing. + # This mirrors the _interrupt_state.resume() call in stream_async and will + # raise TypeError if the resume input is not valid interrupt responses. 
+ self._interrupt_state.resume(after_invocation_event.resume) + current_messages = await self._convert_prompt_to_messages(after_invocation_event.resume) + else: + current_messages = None async def _execute_event_loop_cycle( self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 8d3e5d280..9186e0e70 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from ..agent.agent_result import AgentResult +from ..types.agent import AgentInput from ..types.content import Message, Messages from ..types.interrupt import _Interruptible from ..types.streaming import StopReason @@ -78,6 +79,13 @@ class AfterInvocationEvent(HookEvent): - Agent.stream_async - Agent.structured_output + Resume: + When ``resume`` is set to a non-None value by a hook callback, the agent will + automatically re-invoke itself with the provided input. This enables hooks to + implement autonomous looping patterns where the agent continues processing + based on its previous result. The resume triggers a full new invocation cycle + including ``BeforeInvocationEvent``. + Attributes: invocation_state: State and configuration passed through the agent invocation. This can include shared context for multi-agent coordination, request tracking, @@ -85,10 +93,17 @@ class AfterInvocationEvent(HookEvent): result: The result of the agent invocation, if available. This will be None when invoked from structured_output methods, as those return typed output directly rather than AgentResult. + resume: When set to a non-None agent input by a hook callback, the agent will + re-invoke itself with this input. The value can be any valid AgentInput + (str, content blocks, messages, etc.). Defaults to None (no resume). 
""" invocation_state: dict[str, Any] = field(default_factory=dict) result: "AgentResult | None" = None + resume: AgentInput = None + + def _can_write(self, name: str) -> bool: + return name == "resume" @property def should_reverse_callbacks(self) -> bool: diff --git a/tests/strands/agent/hooks/test_events.py b/tests/strands/agent/hooks/test_events.py index de551d137..0e03fbbcd 100644 --- a/tests/strands/agent/hooks/test_events.py +++ b/tests/strands/agent/hooks/test_events.py @@ -230,3 +230,33 @@ def test_before_invocation_event_agent_not_writable(start_request_event_with_mes """Test that BeforeInvocationEvent.agent is not writable.""" with pytest.raises(AttributeError, match="Property agent is not writable"): start_request_event_with_messages.agent = Mock() + + +def test_after_invocation_event_resume_defaults_to_none(agent): + """Test that AfterInvocationEvent.resume defaults to None.""" + event = AfterInvocationEvent(agent=agent, result=None) + assert event.resume is None + + +def test_after_invocation_event_resume_is_writable(agent): + """Test that AfterInvocationEvent.resume can be set by hooks.""" + event = AfterInvocationEvent(agent=agent, result=None) + event.resume = "continue with this input" + assert event.resume == "continue with this input" + + +def test_after_invocation_event_resume_accepts_various_input_types(agent): + """Test that resume accepts all AgentInput types.""" + event = AfterInvocationEvent(agent=agent, result=None) + + # String input + event.resume = "hello" + assert event.resume == "hello" + + # Content block list + event.resume = [{"text": "hello"}] + assert event.resume == [{"text": "hello"}] + + # None to stop + event.resume = None + assert event.resume is None diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 4397b9628..1da245d70 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -694,3 +694,330 @@ async def 
capture_messages_hook(event: BeforeInvocationEvent): # structured_output_async uses deprecated path that doesn't pass messages assert received_messages is None + + +def test_after_invocation_resume_triggers_new_invocation(): + """Test that setting resume on AfterInvocationEvent re-invokes the agent.""" + mock_provider = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "First response"}]}, + {"role": "assistant", "content": [{"text": "Second response"}]}, + ] + ) + + resume_count = 0 + + async def resume_once(event: AfterInvocationEvent): + nonlocal resume_count + if resume_count == 0: + resume_count += 1 + event.resume = "continue" + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(AfterInvocationEvent, resume_once) + + result = agent("start") + + # Agent should have been invoked twice + assert resume_count == 1 + assert result.message["content"][0]["text"] == "Second response" + # 4 messages: user1, assistant1, user2 (resume), assistant2 + assert len(agent.messages) == 4 + assert agent.messages[0]["content"][0]["text"] == "start" + assert agent.messages[2]["content"][0]["text"] == "continue" + + +def test_after_invocation_resume_none_does_not_loop(): + """Test that resume=None (default) does not re-invoke the agent.""" + mock_provider = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "Only response"}]}, + ] + ) + + call_count = 0 + + async def no_resume(event: AfterInvocationEvent): + nonlocal call_count + call_count += 1 + # Don't set resume - should remain None + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(AfterInvocationEvent, no_resume) + + result = agent("hello") + + assert call_count == 1 + assert result.message["content"][0]["text"] == "Only response" + + +def test_after_invocation_resume_fires_before_invocation_event(): + """Test that resume triggers BeforeInvocationEvent on each iteration.""" + mock_provider = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": 
"First"}]}, + {"role": "assistant", "content": [{"text": "Second"}]}, + ] + ) + + before_invocation_count = 0 + after_invocation_count = 0 + + async def count_before(event: BeforeInvocationEvent): + nonlocal before_invocation_count + before_invocation_count += 1 + + async def resume_once(event: AfterInvocationEvent): + nonlocal after_invocation_count + after_invocation_count += 1 + if after_invocation_count == 1: + event.resume = "next" + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeInvocationEvent, count_before) + agent.hooks.add_callback(AfterInvocationEvent, resume_once) + + agent("start") + + # BeforeInvocationEvent should fire for both the initial and resumed invocation + assert before_invocation_count == 2 + assert after_invocation_count == 2 + + +def test_after_invocation_resume_multiple_times(): + """Test that resume can chain multiple re-invocations.""" + mock_provider = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "Response 1"}]}, + {"role": "assistant", "content": [{"text": "Response 2"}]}, + {"role": "assistant", "content": [{"text": "Response 3"}]}, + ] + ) + + resume_count = 0 + + async def resume_twice(event: AfterInvocationEvent): + nonlocal resume_count + if resume_count < 2: + resume_count += 1 + event.resume = f"iteration {resume_count + 1}" + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(AfterInvocationEvent, resume_twice) + + result = agent("iteration 1") + + assert resume_count == 2 + assert result.message["content"][0]["text"] == "Response 3" + # 6 messages: 3 user + 3 assistant + assert len(agent.messages) == 6 + + +def test_after_invocation_resume_handles_interrupt_with_responses(): + """Test that a hook can handle an interrupt by resuming with interrupt responses.""" + + @strands.tools.tool(name="interruptable_tool") + def interruptable_tool(value: str) -> str: + return value + + tool_use_id = "tool-1" + mock_provider = MockedModelProvider( + [ + # First invocation: model 
calls the tool, which will be interrupted + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": tool_use_id, + "name": "interruptable_tool", + "input": {"value": "test"}, + } + } + ], + }, + # Second invocation (after interrupt resume): model gives final response + {"role": "assistant", "content": [{"text": "Completed after interrupt"}]}, + ] + ) + + def interrupt_tool(event: BeforeToolCallEvent): + """Interrupt before tool execution; returns stored response on second call.""" + if event.tool_use["name"] == "interruptable_tool": + event.interrupt("approval_needed", reason="Need human approval") + + async def handle_interrupt_via_resume(event: AfterInvocationEvent): + """Hook that automatically handles interrupts by resuming with responses.""" + if event.result and event.result.stop_reason == "interrupt": + responses = [] + for interrupt in event.result.interrupts: + responses.append({"interruptResponse": {"interruptId": interrupt.id, "response": "approved"}}) + event.resume = responses + + agent = Agent(model=mock_provider, tools=[interruptable_tool], callback_handler=None) + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_tool) + agent.hooks.add_callback(AfterInvocationEvent, handle_interrupt_via_resume) + + result = agent("do something") + + # The hook handled the interrupt automatically — agent completed normally + assert result.stop_reason == "end_turn" + assert result.message["content"][0]["text"] == "Completed after interrupt" + # Interrupt state should be cleared after successful resume + assert agent._interrupt_state.activated is False + + +def test_after_invocation_resume_with_invalid_input_during_interrupt(): + """Test that resuming with non-interrupt input while interrupt is active raises TypeError.""" + + @strands.tools.tool(name="interruptable_tool") + def interruptable_tool(value: str) -> str: + return value + + tool_use_id = "tool-1" + mock_provider = MockedModelProvider( + [ + # First invocation: model calls the tool, 
which will be interrupted + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": tool_use_id, + "name": "interruptable_tool", + "input": {"value": "test"}, + } + } + ], + }, + ] + ) + + def interrupt_tool(event: BeforeToolCallEvent): + if event.tool_use["name"] == "interruptable_tool": + event.interrupt("approval_needed", reason="Need approval") + + async def resume_with_bad_input(event: AfterInvocationEvent): + """Hook that incorrectly tries to resume with a plain string during interrupt.""" + if event.result and event.result.stop_reason == "interrupt": + event.resume = "this is wrong" + + agent = Agent(model=mock_provider, tools=[interruptable_tool], callback_handler=None) + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_tool) + agent.hooks.add_callback(AfterInvocationEvent, resume_with_bad_input) + + with pytest.raises(TypeError, match="must resume from interrupt with list of interruptResponse's"): + agent("do something") + + +def test_after_invocation_resume_interrupt_without_resume_returns_to_caller(): + """Test that an interrupt without resume set returns the interrupt to the caller.""" + + @strands.tools.tool(name="interruptable_tool") + def interruptable_tool(value: str) -> str: + return value + + tool_use_id = "tool-1" + mock_provider = MockedModelProvider( + [ + # First invocation: model calls the tool, which will be interrupted + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": tool_use_id, + "name": "interruptable_tool", + "input": {"value": "test"}, + } + } + ], + }, + # Second invocation (caller resumes manually): final response + {"role": "assistant", "content": [{"text": "Done after manual resume"}]}, + ] + ) + + def interrupt_tool(event: BeforeToolCallEvent): + if event.tool_use["name"] == "interruptable_tool": + event.interrupt("approval_needed", reason="Need approval") + + agent = Agent(model=mock_provider, tools=[interruptable_tool], callback_handler=None) + 
agent.hooks.add_callback(BeforeToolCallEvent, interrupt_tool) + + # First call: hits interrupt, no hook handles it, returns to caller + result = agent("do something") + assert result.stop_reason == "interrupt" + assert len(result.interrupts) == 1 + assert result.interrupts[0].name == "approval_needed" + assert agent._interrupt_state.activated is True + + # Caller manually resumes with interrupt responses + interrupt_id = result.interrupts[0].id + result = agent([{"interruptResponse": {"interruptId": interrupt_id, "response": "yes"}}]) + assert result.stop_reason == "end_turn" + assert result.message["content"][0]["text"] == "Done after manual resume" + assert agent._interrupt_state.activated is False + + +def test_after_invocation_resume_interrupt_during_resumed_invocation(): + """Test that an interrupt during a resumed invocation can be handled by the hook.""" + + @strands.tools.tool(name="interruptable_tool") + def interruptable_tool(value: str) -> str: + return value + + tool_use_id = "tool-1" + mock_provider = MockedModelProvider( + [ + # First invocation: simple text response (no tool call) + {"role": "assistant", "content": [{"text": "First response"}]}, + # Second invocation (resumed): triggers a tool call which will be interrupted + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": tool_use_id, + "name": "interruptable_tool", + "input": {"value": "test"}, + } + } + ], + }, + # Third invocation (after interrupt handled via resume): final response + {"role": "assistant", "content": [{"text": "Final response"}]}, + ] + ) + + invocation_count = 0 + + async def resume_hook(event: AfterInvocationEvent): + """Resume with new input on first call, handle interrupt on second.""" + nonlocal invocation_count + invocation_count += 1 + if invocation_count == 1: + # First invocation done, resume with new input + event.resume = "continue" + elif event.result and event.result.stop_reason == "interrupt": + # Second invocation hit interrupt, handle it 
+ responses = [] + for interrupt in event.result.interrupts: + responses.append({"interruptResponse": {"interruptId": interrupt.id, "response": "approved"}}) + event.resume = responses + + def interrupt_tool(event: BeforeToolCallEvent): + if event.tool_use["name"] == "interruptable_tool": + event.interrupt("approval_needed", reason="Need approval") + + agent = Agent(model=mock_provider, tools=[interruptable_tool], callback_handler=None) + agent.hooks.add_callback(AfterInvocationEvent, resume_hook) + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_tool) + + result = agent("start") + + # All three invocations happened within a single agent call + assert invocation_count == 3 + assert result.stop_reason == "end_turn" + assert result.message["content"][0]["text"] == "Final response" + assert agent._interrupt_state.activated is False From b0fc7961d5e0f6f2f494b17f7380ed84f43582d9 Mon Sep 17 00:00:00 2001 From: Kihyeon Myung <51226101+kevmyung@users.noreply.github.com> Date: Mon, 9 Mar 2026 21:03:19 -0600 Subject: [PATCH 176/279] fix: place cache point on last user message instead of assistant (#1821) --- src/strands/models/bedrock.py | 14 +++--- tests/strands/models/test_bedrock.py | 75 +++++++++++++++++++++------- 2 files changed, 64 insertions(+), 25 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 3fa907995..bab4031ed 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -339,7 +339,7 @@ def _get_additional_request_fields(self, tool_choice: ToolChoice | None) -> dict return {"additionalModelRequestFields": additional_fields} def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None: - """Inject a cache point at the end of the last assistant message. + """Inject a cache point at the end of the last user message. Args: messages: List of messages to inject cache point into (modified in place). 
@@ -347,7 +347,7 @@ def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None: if not messages: return - last_assistant_idx: int | None = None + last_user_idx: int | None = None for msg_idx, msg in enumerate(messages): content = msg.get("content", []) for block_idx, block in reversed(list(enumerate(content))): @@ -358,12 +358,12 @@ def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None: msg_idx, block_idx, ) - if msg.get("role") == "assistant": - last_assistant_idx = msg_idx + if msg.get("role") == "user": + last_user_idx = msg_idx - if last_assistant_idx is not None and messages[last_assistant_idx].get("content"): - messages[last_assistant_idx]["content"].append({"cachePoint": {"type": "default"}}) - logger.debug("msg_idx=<%s> | added cache point to last assistant message", last_assistant_idx) + if last_user_idx is not None and messages[last_user_idx].get("content"): + messages[last_user_idx]["content"].append({"cachePoint": {"type": "default"}}) + logger.debug("msg_idx=<%s> | added cache point to last user message", last_user_idx) def _find_last_user_text_message_index(self, messages: Messages) -> int | None: """Find the index of the last user message containing text or image content. 
diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 66fe8ab00..89c4df70d 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2597,8 +2597,8 @@ def test_cache_strategy_none_for_non_claude(bedrock_client): assert model._cache_strategy is None -def test_inject_cache_point_adds_to_last_assistant(bedrock_client): - """Test that _inject_cache_point adds cache point to last assistant message.""" +def test_inject_cache_point_adds_to_last_user(bedrock_client): + """Test that _inject_cache_point adds cache point to last user message.""" model = BedrockModel( model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") ) @@ -2611,13 +2611,14 @@ def test_inject_cache_point_adds_to_last_assistant(bedrock_client): model._inject_cache_point(cleaned_messages) - assert len(cleaned_messages[1]["content"]) == 2 - assert "cachePoint" in cleaned_messages[1]["content"][-1] - assert cleaned_messages[1]["content"][-1]["cachePoint"]["type"] == "default" + assert len(cleaned_messages[2]["content"]) == 2 + assert "cachePoint" in cleaned_messages[2]["content"][-1] + assert cleaned_messages[2]["content"][-1]["cachePoint"]["type"] == "default" + assert len(cleaned_messages[1]["content"]) == 1 -def test_inject_cache_point_no_assistant_message(bedrock_client): - """Test that _inject_cache_point does nothing when no assistant message exists.""" +def test_inject_cache_point_single_user_message(bedrock_client): + """Test that _inject_cache_point adds cache point to single user message.""" model = BedrockModel( model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") ) @@ -2629,6 +2630,39 @@ def test_inject_cache_point_no_assistant_message(bedrock_client): model._inject_cache_point(cleaned_messages) assert len(cleaned_messages) == 1 + assert len(cleaned_messages[0]["content"]) == 2 + assert "cachePoint" in 
cleaned_messages[0]["content"][-1] + + +def test_inject_cache_point_empty_messages(bedrock_client): + """Test that _inject_cache_point handles empty messages list.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") + ) + + cleaned_messages = [] + model._inject_cache_point(cleaned_messages) + + assert cleaned_messages == [] + + +def test_inject_cache_point_with_tool_result_last_user(bedrock_client): + """Test that cache point is added to last user message even when it contains toolResult.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", cache_config=CacheConfig(strategy="auto") + ) + + cleaned_messages = [ + {"role": "user", "content": [{"text": "Use the tool"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "test_tool", "input": {}}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "Result"}]}}]}, + ] + + model._inject_cache_point(cleaned_messages) + + assert len(cleaned_messages[2]["content"]) == 2 + assert "cachePoint" in cleaned_messages[2]["content"][-1] + assert cleaned_messages[2]["content"][-1]["cachePoint"]["type"] == "default" assert len(cleaned_messages[0]["content"]) == 1 @@ -2643,6 +2677,8 @@ def test_inject_cache_point_skipped_for_non_claude(bedrock_client): formatted = model._format_bedrock_messages(messages) + assert len(formatted[0]["content"]) == 1 + assert "cachePoint" not in formatted[0]["content"][0] assert len(formatted[1]["content"]) == 1 assert "cachePoint" not in formatted[1]["content"][0] @@ -2664,8 +2700,8 @@ def test_format_bedrock_messages_does_not_mutate_original(bedrock_client): formatted = model._format_bedrock_messages(original_messages) assert original_messages == messages_before - assert "cachePoint" not in original_messages[1]["content"][-1] - assert "cachePoint" in formatted[1]["content"][-1] + assert "cachePoint" not in 
original_messages[2]["content"][-1] + assert "cachePoint" in formatted[2]["content"][-1] def test_inject_cache_point_strips_existing_cache_points(bedrock_client): @@ -2685,12 +2721,13 @@ def test_inject_cache_point_strips_existing_cache_points(bedrock_client): model._inject_cache_point(cleaned_messages) # All old cache points should be stripped - assert len(cleaned_messages[0]["content"]) == 1 # user: only text + assert len(cleaned_messages[0]["content"]) == 1 # first user: only text assert len(cleaned_messages[1]["content"]) == 1 # first assistant: only text + assert len(cleaned_messages[3]["content"]) == 1 # last assistant: only text - # New cache point should be at end of last assistant message - assert len(cleaned_messages[3]["content"]) == 2 - assert "cachePoint" in cleaned_messages[3]["content"][-1] + # New cache point should be at end of last user message + assert len(cleaned_messages[2]["content"]) == 2 + assert "cachePoint" in cleaned_messages[2]["content"][-1] def test_inject_cache_point_anthropic_strategy_skips_model_check(bedrock_client): @@ -2707,9 +2744,10 @@ def test_inject_cache_point_anthropic_strategy_skips_model_check(bedrock_client) formatted = model._format_bedrock_messages(messages) - assert len(formatted[1]["content"]) == 2 - assert "cachePoint" in formatted[1]["content"][-1] - assert formatted[1]["content"][-1]["cachePoint"]["type"] == "default" + assert len(formatted[0]["content"]) == 2 + assert "cachePoint" in formatted[0]["content"][-1] + assert formatted[0]["content"][-1]["cachePoint"]["type"] == "default" + assert len(formatted[1]["content"]) == 1 def test_inject_cache_point_auto_strategy_resolves_to_anthropic_for_claude(bedrock_client): @@ -2725,8 +2763,9 @@ def test_inject_cache_point_auto_strategy_resolves_to_anthropic_for_claude(bedro formatted = model._format_bedrock_messages(messages) - assert len(formatted[1]["content"]) == 2 - assert "cachePoint" in formatted[1]["content"][-1] + assert len(formatted[0]["content"]) == 2 + assert 
"cachePoint" in formatted[0]["content"][-1] + assert len(formatted[1]["content"]) == 1 def test_find_last_user_text_message_index_no_user_messages(bedrock_client): From 4a26f4a519906eeb925a1f0caf13cecc5a6d6e19 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Tue, 10 Mar 2026 14:29:21 -0400 Subject: [PATCH 177/279] feat(skills): Add agent skills (#1755) Co-authored-by: Nicholas Clegg Co-authored-by: Containerized Agent Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> Co-authored-by: Nick Clegg --- AGENTS.md | 6 +- pyproject.toml | 1 + src/strands/__init__.py | 4 +- src/strands/plugins/__init__.py | 3 + src/strands/plugins/skills/__init__.py | 31 + src/strands/plugins/skills/agent_skills.py | 393 ++++++++++ src/strands/plugins/skills/skill.py | 377 ++++++++++ tests/strands/plugins/skills/__init__.py | 1 + .../plugins/skills/test_agent_skills.py | 699 ++++++++++++++++++ tests/strands/plugins/skills/test_skill.py | 561 ++++++++++++++ tests_integ/test_skills_plugin.py | 81 ++ 11 files changed, 2155 insertions(+), 2 deletions(-) create mode 100644 src/strands/plugins/skills/__init__.py create mode 100644 src/strands/plugins/skills/agent_skills.py create mode 100644 src/strands/plugins/skills/skill.py create mode 100644 tests/strands/plugins/skills/__init__.py create mode 100644 tests/strands/plugins/skills/test_agent_skills.py create mode 100644 tests/strands/plugins/skills/test_skill.py create mode 100644 tests_integ/test_skills_plugin.py diff --git a/AGENTS.md b/AGENTS.md index 10a66fcd7..21c32539c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -130,7 +130,11 @@ strands-agents/ │ ├── plugins/ # Plugin system │ │ ├── plugin.py # Plugin base class │ │ ├── decorator.py # @hook decorator -│ │ └── registry.py # PluginRegistry for tracking plugins +│ │ ├── registry.py # PluginRegistry for tracking plugins +│ │ └── skills/ # Agent Skills integration +│ │ ├── __init__.py # Skills package exports +│ │ ├── skill.py # Skill dataclass +│ │ └── 
agent_skills.py # AgentSkills plugin implementation │ │ │ ├── handlers/ # Event handlers │ │ └── callback_handler.py # Callback handling diff --git a/pyproject.toml b/pyproject.toml index b53194486..e07f3bac4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "mcp>=1.23.0,<2.0.0", "pydantic>=2.4.0,<3.0.0", "typing-extensions>=4.13.2,<5.0.0", + "pyyaml>=6.0.0,<7.0.0", "watchdog>=6.0.0,<7.0.0", "opentelemetry-api>=1.30.0,<2.0.0", "opentelemetry-sdk>=1.30.0,<2.0.0", diff --git a/src/strands/__init__.py b/src/strands/__init__.py index be939d5b1..3e1528fa6 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -4,17 +4,19 @@ from .agent.agent import Agent from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy -from .plugins import Plugin +from .plugins import AgentSkills, Plugin, Skill from .tools.decorator import tool from .types.tools import ToolContext __all__ = [ "Agent", "AgentBase", + "AgentSkills", "agent", "models", "ModelRetryStrategy", "Plugin", + "Skill", "tool", "ToolContext", "types", diff --git a/src/strands/plugins/__init__.py b/src/strands/plugins/__init__.py index c4b7c72c7..d7ca4c9b2 100644 --- a/src/strands/plugins/__init__.py +++ b/src/strands/plugins/__init__.py @@ -6,8 +6,11 @@ from .decorator import hook from .plugin import Plugin +from .skills import AgentSkills, Skill __all__ = [ + "AgentSkills", "Plugin", + "Skill", "hook", ] diff --git a/src/strands/plugins/skills/__init__.py b/src/strands/plugins/skills/__init__.py new file mode 100644 index 000000000..f6cf8728b --- /dev/null +++ b/src/strands/plugins/skills/__init__.py @@ -0,0 +1,31 @@ +"""AgentSkills.io integration for Strands Agents. + +This module provides the AgentSkills plugin for integrating AgentSkills.io skills +into Strands agents. Skills enable progressive disclosure of instructions: +metadata is injected into the system prompt upfront, and full instructions +are loaded on demand via a tool. 
+ +Example Usage: + ```python + from strands import Agent + from strands.plugins.skills import Skill, AgentSkills + + # Load from filesystem via classmethods + skill = Skill.from_file("./skills/pdf-processing") + skills = Skill.from_directory("./skills/") + + # Or let the plugin resolve paths automatically + plugin = AgentSkills(skills=["./skills/pdf-processing"]) + agent = Agent(plugins=[plugin]) + ``` +""" + +from .agent_skills import AgentSkills, SkillSource, SkillSources +from .skill import Skill + +__all__ = [ + "AgentSkills", + "Skill", + "SkillSource", + "SkillSources", +] diff --git a/src/strands/plugins/skills/agent_skills.py b/src/strands/plugins/skills/agent_skills.py new file mode 100644 index 000000000..97ac86d93 --- /dev/null +++ b/src/strands/plugins/skills/agent_skills.py @@ -0,0 +1,393 @@ +"""AgentSkills plugin for integrating Agent Skills into Strands agents. + +This module provides the AgentSkills class that extends the Plugin base class +to add Agent Skills support. The plugin registers a tool for activating +skills, and injects skill metadata into the system prompt. 
+""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any, TypeAlias +from xml.sax.saxutils import escape + +from ...hooks.events import BeforeInvocationEvent +from ...plugins import Plugin, hook +from ...tools.decorator import tool +from ...types.tools import ToolContext +from .skill import Skill + +if TYPE_CHECKING: + from ...agent.agent import Agent + +logger = logging.getLogger(__name__) + +_DEFAULT_STATE_KEY = "agent_skills" +_RESOURCE_DIRS = ("scripts", "references", "assets") +_DEFAULT_MAX_RESOURCE_FILES = 20 + +SkillSource: TypeAlias = str | Path | Skill +"""A single skill source: path string, Path object, or Skill instance.""" + +SkillSources: TypeAlias = SkillSource | list[SkillSource] +"""One or more skill sources.""" + + +def _normalize_sources(sources: SkillSources) -> list[SkillSource]: + """Normalize a single source or list of sources into a list.""" + if isinstance(sources, list): + return sources + return [sources] + + +class AgentSkills(Plugin): + """Plugin that integrates Agent Skills into a Strands agent. + + The AgentSkills plugin extends the Plugin base class and provides: + + 1. A ``skills`` tool that allows the agent to activate skills on demand + 2. System prompt injection of available skill metadata before each invocation + 3. Session persistence of active skill state via ``agent.state`` + + Skills can be provided as filesystem paths (to individual skill directories or + parent directories containing multiple skills) or as pre-built ``Skill`` instances. 
+ + Example: + ```python + from strands import Agent + from strands.plugins.skills import Skill, AgentSkills + + # Load from filesystem + plugin = AgentSkills(skills=["./skills/pdf-processing", "./skills/"]) + + # Or provide Skill instances directly + skill = Skill(name="my-skill", description="A custom skill", instructions="Do the thing") + plugin = AgentSkills(skills=[skill]) + + agent = Agent(plugins=[plugin]) + ``` + """ + + name = "agent_skills" + + def __init__( + self, + skills: SkillSources, + state_key: str = _DEFAULT_STATE_KEY, + max_resource_files: int = _DEFAULT_MAX_RESOURCE_FILES, + strict: bool = False, + ) -> None: + """Initialize the AgentSkills plugin. + + Args: + skills: One or more skill sources. Can be a single value or a list. Each element can be: + + - A ``str`` or ``Path`` to a skill directory (containing SKILL.md) + - A ``str`` or ``Path`` to a parent directory (containing skill subdirectories) + - A ``Skill`` dataclass instance + state_key: Key used to store plugin state in ``agent.state``. + max_resource_files: Maximum number of resource files to list in skill responses. + strict: If True, raise on skill validation issues. If False (default), warn and load anyway. + """ + self._strict = strict + self._skills: dict[str, Skill] = self._resolve_skills(_normalize_sources(skills)) + self._state_key = state_key + self._max_resource_files = max_resource_files + super().__init__() + + def init_agent(self, agent: Agent) -> None: + """Initialize the plugin with an agent instance. + + Decorated hooks and tools are auto-registered by the plugin registry. + + Args: + agent: The agent instance to extend with skills support. 
+ """ + if not self._skills: + logger.warning("no skills were loaded, the agent will have no skills available") + logger.debug("skill_count=<%d> | skills plugin initialized", len(self._skills)) + + @tool(context=True) + def skills(self, skill_name: str, tool_context: ToolContext) -> str: # noqa: D417 + """Activate a skill to load its full instructions. + + Use this tool to load the complete instructions for a skill listed in + the available_skills section of your system prompt. + + Args: + skill_name: Name of the skill to activate. + """ + if not skill_name: + available = ", ".join(self._skills) + return f"Error: skill_name is required. Available skills: {available}" + + found = self._skills.get(skill_name) + if found is None: + available = ", ".join(self._skills) + return f"Skill '{skill_name}' not found. Available skills: {available}" + + logger.debug("skill_name=<%s> | skill activated", skill_name) + self._track_activated_skill(tool_context.agent, skill_name) + return self._format_skill_response(found) + + @hook + def _on_before_invocation(self, event: BeforeInvocationEvent) -> None: + """Inject skill metadata into the system prompt before each invocation. + + Removes the previously injected XML block (if any) via exact string + replacement, then appends a fresh one. Uses agent state to track the + injected XML per-agent, so a single plugin instance can be shared + across multiple agents safely. + + Args: + event: The before-invocation event containing the agent reference. 
+ """ + agent = event.agent + + current_prompt = agent.system_prompt or "" + + # Remove the previously injected XML block by exact match + state_data = agent.state.get(self._state_key) + last_injected_xml = state_data.get("last_injected_xml") if isinstance(state_data, dict) else None + if last_injected_xml is not None: + if last_injected_xml in current_prompt: + current_prompt = current_prompt.replace(last_injected_xml, "") + else: + logger.warning("unable to find previously injected skills XML in system prompt, re-appending") + + skills_xml = self._generate_skills_xml() + injection = f"\n\n{skills_xml}" + new_prompt = f"{current_prompt}{injection}" if current_prompt else skills_xml + + new_injected_xml = injection if current_prompt else skills_xml + self._set_state_field(agent, "last_injected_xml", new_injected_xml) + agent.system_prompt = new_prompt + + def get_available_skills(self) -> list[Skill]: + """Get the list of available skills. + + Returns: + A copy of the current skills list. + """ + return list(self._skills.values()) + + def set_available_skills(self, skills: SkillSources) -> None: + """Set the available skills, replacing any existing ones. + + Each element can be a ``Skill`` instance, a ``str`` or ``Path`` to a + skill directory (containing SKILL.md), or a ``str`` or ``Path`` to a + parent directory containing skill subdirectories. + + Note: this does not persist state or deactivate skills on any agent. + Active skill state is managed per-agent and will be reconciled on the + next tool call or invocation. + + Args: + skills: One or more skill sources to resolve and set. + """ + self._skills = self._resolve_skills(_normalize_sources(skills)) + + + def _format_skill_response(self, skill: Skill) -> str: + """Format the tool response when a skill is activated. + + Includes the full instructions along with relevant metadata fields + and a listing of available resource files (scripts, references, assets) + for filesystem-based skills. 
+ + Args: + skill: The activated skill. + + Returns: + Formatted string with skill instructions and metadata. + """ + if not skill.instructions: + return f"Skill '{skill.name}' activated (no instructions available)." + + parts: list[str] = [skill.instructions] + + metadata_lines: list[str] = [] + if skill.allowed_tools: + metadata_lines.append(f"Allowed tools: {', '.join(skill.allowed_tools)}") + if skill.compatibility: + metadata_lines.append(f"Compatibility: {skill.compatibility}") + if skill.path is not None: + metadata_lines.append(f"Location: {skill.path / 'SKILL.md'}") + + if metadata_lines: + parts.append("\n---\n" + "\n".join(metadata_lines)) + + if skill.path is not None: + resources = self._list_skill_resources(skill.path) + if resources: + parts.append("\nAvailable resources:\n" + "\n".join(f" {r}" for r in resources)) + + return "\n".join(parts) + + def _list_skill_resources(self, skill_path: Path) -> list[str]: + """List resource files in a skill's optional directories. + + Scans the ``scripts/``, ``references/``, and ``assets/`` subdirectories + for files, returning relative paths. Results are capped at + ``max_resource_files`` to avoid context bloat. + + Args: + skill_path: Path to the skill directory. + + Returns: + List of relative file paths (e.g. ``scripts/extract.py``). + """ + files: list[str] = [] + + for dir_name in _RESOURCE_DIRS: + resource_dir = skill_path / dir_name + if not resource_dir.is_dir(): + continue + + for file_path in sorted(resource_dir.rglob("*")): + if not file_path.is_file(): + continue + files.append(file_path.relative_to(skill_path).as_posix()) + if len(files) >= self._max_resource_files: + files.append(f"... (truncated at {self._max_resource_files} files)") + return files + + return files + + def _generate_skills_xml(self) -> str: + """Generate the XML block listing available skills for the system prompt. + + When no skills are loaded, returns a block indicating no skills are available. 
+ Otherwise includes a ```` element for skills loaded from the filesystem, + following the AgentSkills.io integration spec. + + Returns: + XML-formatted string with skill metadata. + """ + if not self._skills: + return "\nNo skills are currently available.\n" + + lines: list[str] = [""] + + for skill in self._skills.values(): + lines.append("") + lines.append(f"{escape(skill.name)}") + lines.append(f"{escape(skill.description)}") + if skill.path is not None: + lines.append(f"{escape(str(skill.path / 'SKILL.md'))}") + lines.append("") + + lines.append("") + return "\n".join(lines) + + def _resolve_skills(self, sources: list[SkillSource]) -> dict[str, Skill]: + """Resolve a list of skill sources into Skill instances. + + Each source can be a Skill instance, a path to a skill directory, + or a path to a parent directory containing multiple skills. + + Args: + sources: List of skill sources to resolve. + + Returns: + Dict mapping skill names to Skill instances. + """ + resolved: dict[str, Skill] = {} + + for source in sources: + if isinstance(source, Skill): + if source.name in resolved: + logger.warning("name=<%s> | duplicate skill name, overwriting previous skill", source.name) + resolved[source.name] = source + else: + path = Path(source).resolve() + if not path.exists(): + logger.warning("path=<%s> | skill source path does not exist, skipping", path) + continue + + if path.is_dir(): + # Check if this directory itself is a skill (has SKILL.md) + has_skill_md = (path / "SKILL.md").is_file() or (path / "skill.md").is_file() + + if has_skill_md: + try: + skill = Skill.from_file(path, strict=self._strict) + if skill.name in resolved: + logger.warning( + "name=<%s> | duplicate skill name, overwriting previous skill", skill.name + ) + resolved[skill.name] = skill + except (ValueError, FileNotFoundError) as e: + logger.warning("path=<%s> | failed to load skill: %s", path, e) + else: + # Treat as parent directory containing skill subdirectories + for skill in 
Skill.from_directory(path, strict=self._strict): + if skill.name in resolved: + logger.warning( + "name=<%s> | duplicate skill name, overwriting previous skill", skill.name + ) + resolved[skill.name] = skill + elif path.is_file() and path.name.lower() == "skill.md": + try: + skill = Skill.from_file(path, strict=self._strict) + if skill.name in resolved: + logger.warning("name=<%s> | duplicate skill name, overwriting previous skill", skill.name) + resolved[skill.name] = skill + except (ValueError, FileNotFoundError) as e: + logger.warning("path=<%s> | failed to load skill: %s", path, e) + + logger.debug("source_count=<%d>, resolved_count=<%d> | skills resolved", len(sources), len(resolved)) + return resolved + + def _set_state_field(self, agent: Agent, key: str, value: Any) -> None: + """Set a single field in the plugin's agent state dict. + + Args: + agent: The agent whose state to update. + key: The state field key. + value: The value to set. + + Raises: + TypeError: If the existing state value is not a dict. + """ + state_data = agent.state.get(self._state_key) + if state_data is not None and not isinstance(state_data, dict): + raise TypeError(f"expected dict for state key '{self._state_key}', got {type(state_data).__name__}") + if state_data is None: + state_data = {} + state_data[key] = value + agent.state.set(self._state_key, state_data) + + def _track_activated_skill(self, agent: Agent, skill_name: str) -> None: + """Record a skill activation in agent state. + + Maintains an ordered list of activated skill names (most recent last), + without duplicates. + + Args: + agent: The agent whose state to update. + skill_name: Name of the activated skill. 
+ """ + state_data = agent.state.get(self._state_key) + activated: list[str] = state_data.get("activated_skills", []) if isinstance(state_data, dict) else [] + if skill_name in activated: + activated.remove(skill_name) + activated.append(skill_name) + self._set_state_field(agent, "activated_skills", activated) + + def get_activated_skills(self, agent: Agent) -> list[str]: + """Get the list of skills activated by this agent. + + Returns skill names in activation order (most recent last). + + Args: + agent: The agent to query. + + Returns: + List of activated skill names. + """ + state_data = agent.state.get(self._state_key) + if isinstance(state_data, dict): + return list(state_data.get("activated_skills", [])) + return [] diff --git a/src/strands/plugins/skills/skill.py b/src/strands/plugins/skills/skill.py new file mode 100644 index 000000000..3e1b6bba5 --- /dev/null +++ b/src/strands/plugins/skills/skill.py @@ -0,0 +1,377 @@ +"""Skill data model and loading utilities for AgentSkills.io skills. + +This module defines the Skill dataclass and provides classmethods for +discovering, parsing, and loading skills from the filesystem or raw content. +Skills are directories containing a SKILL.md file with YAML frontmatter +metadata and markdown instructions. +""" + +from __future__ import annotations + +import logging +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import yaml + +logger = logging.getLogger(__name__) + +_SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]([a-z0-9-]*[a-z0-9])?$") +_MAX_SKILL_NAME_LENGTH = 64 + + +def _find_skill_md(skill_dir: Path) -> Path: + """Find the SKILL.md file in a skill directory. + + Searches for SKILL.md (case-sensitive preferred) or skill.md as a fallback. + + Args: + skill_dir: Path to the skill directory. + + Returns: + Path to the SKILL.md file. + + Raises: + FileNotFoundError: If no SKILL.md file is found in the directory. 
+ """ + for name in ("SKILL.md", "skill.md"): + candidate = skill_dir / name + if candidate.is_file(): + return candidate + + raise FileNotFoundError(f"path=<{skill_dir}> | no SKILL.md found in skill directory") + + +def _parse_frontmatter(content: str) -> tuple[dict[str, Any], str]: + """Parse YAML frontmatter and body from SKILL.md content. + + Extracts the YAML frontmatter between ``---`` delimiters at line boundaries + and returns parsed key-value pairs along with the remaining markdown body. + + Args: + content: Full content of a SKILL.md file. + + Returns: + Tuple of (frontmatter_dict, body_string). + + Raises: + ValueError: If the frontmatter is malformed or missing required delimiters. + """ + stripped = content.strip() + if not stripped.startswith("---"): + raise ValueError("SKILL.md must start with --- frontmatter delimiter") + + # Find the closing --- delimiter (first line after the opener that is only dashes) + match = re.search(r"\n^---\s*$", stripped, re.MULTILINE) + if match is None: + raise ValueError("SKILL.md frontmatter missing closing --- delimiter") + + frontmatter_str = stripped[3 : match.start()].strip() + body = stripped[match.end() :].strip() + + try: + result = yaml.safe_load(frontmatter_str) + except yaml.YAMLError: + # AgentSkills spec recommends handling malformed YAML (e.g. unquoted colons in values) + # to improve cross-client compatibility. See: agentskills.io/client-implementation/adding-skills-support + logger.warning("YAML parse failed, retrying with colon-quoting fallback") + fixed = _fix_yaml_colons(frontmatter_str) + result = yaml.safe_load(fixed) + + frontmatter: dict[str, Any] = result if isinstance(result, dict) else {} + return frontmatter, body + + +def _fix_yaml_colons(yaml_str: str) -> str: + """Attempt to fix common YAML issues like unquoted colons in values. 
+ + Wraps values containing colons in double quotes to handle cases like: + ``description: Use this skill when: the user asks about PDFs`` + + Args: + yaml_str: The raw YAML string to fix. + + Returns: + The fixed YAML string. + """ + lines: list[str] = [] + for line in yaml_str.splitlines(): + # Match key: value where value contains another colon + match = re.match(r"^(\s*\w[\w-]*):\s+(.+)$", line) + if match: + key, value = match.group(1), match.group(2) + # If value contains a colon and isn't already quoted + if ":" in value and not (value.startswith('"') or value.startswith("'")): + line = f'{key}: "{value}"' + lines.append(line) + return "\n".join(lines) + + +def _validate_skill_name(name: str, dir_path: Path | None = None, *, strict: bool = False) -> None: + """Validate a skill name per the AgentSkills.io specification. + + In lenient mode (default), logs warnings for cosmetic issues but does not raise. + In strict mode, raises ValueError for any validation failure. + + Rules checked: + - 1-64 characters long + - Lowercase alphanumeric characters and hyphens only + - Cannot start or end with a hyphen + - No consecutive hyphens + - Must match parent directory name (if loaded from disk) + + Args: + name: The skill name to validate. + dir_path: Optional path to the skill directory for name matching. + strict: If True, raise ValueError on any issue. If False (default), log warnings. + + Raises: + ValueError: If the skill name is empty, or if strict=True and any rule is violated. 
def _build_skill_from_frontmatter(
    frontmatter: dict[str, Any],
    body: str,
) -> Skill:
    """Assemble a Skill from already-parsed frontmatter and the markdown body.

    ``name`` and ``description`` are read with plain indexing, so both keys
    must be present in ``frontmatter``. Every other key is optional.

    Args:
        frontmatter: Mapping obtained from the YAML frontmatter.
        body: Markdown text following the frontmatter; stored as the skill's
            instructions.

    Returns:
        A Skill populated from the given frontmatter and body.
    """

    def _text_or_none(value: Any) -> str | None:
        # Any falsy value (missing key, empty string, 0, ...) collapses to None.
        return str(value) if value else None

    # The hyphenated key wins; the underscore spelling is consulted only when
    # the hyphenated one is missing or falsy.
    raw_tools = frontmatter.get("allowed-tools") or frontmatter.get("allowed_tools")
    tools: list[str] | None
    if isinstance(raw_tools, list):
        # YAML list form: drop falsy entries and coerce the rest to str.
        tools = [str(entry) for entry in raw_tools if entry]
    elif isinstance(raw_tools, str) and raw_tools.strip():
        # Whitespace-delimited string form.
        tools = raw_tools.split()
    else:
        tools = None

    # Only a mapping is accepted as metadata; anything else yields an empty dict.
    raw_meta = frontmatter.get("metadata", {})
    meta: dict[str, Any] = {}
    if isinstance(raw_meta, dict):
        for key, value in raw_meta.items():
            meta[str(key)] = value

    return Skill(
        name=frontmatter["name"],
        description=frontmatter["description"],
        instructions=body,
        allowed_tools=tools,
        metadata=meta,
        license=_text_or_none(frontmatter.get("license")),
        compatibility=_text_or_none(frontmatter.get("compatibility")),
    )


@dataclass
class Skill:
    r"""An agent skill: descriptive metadata plus markdown instructions.

    Instances can be built directly or through the alternate constructors::

        # From a skill directory (or its SKILL.md) on disk
        skill = Skill.from_file("./skills/my-skill")

        # From raw SKILL.md text
        skill = Skill.from_content("---\nname: my-skill\n...")

        # Every skill found under a parent directory
        skills = Skill.from_directory("./skills/")

    Attributes:
        name: Identifier of the skill.
        description: Human-readable summary of what the skill does.
        instructions: Markdown instructions taken from the SKILL.md body.
        path: Directory the skill was loaded from, or None when it was not
            loaded from disk.
        allowed_tools: Tool names listed by the skill, or None when none are
            declared. (Experimental: not yet enforced)
        metadata: Extra key-value pairs from the frontmatter ``metadata`` mapping.
        license: License identifier (e.g., "Apache-2.0"), if given.
        compatibility: Compatibility information string, if given.
    """

    name: str
    description: str
    instructions: str = ""
    path: Path | None = None
    allowed_tools: list[str] | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
    license: str | None = None
    compatibility: str | None = None

    @classmethod
    def from_file(cls, skill_path: str | Path, *, strict: bool = False) -> Skill:
        """Load one skill from a skill directory or from its SKILL.md file.

        The path is resolved, the SKILL.md text is read as UTF-8 and handed to
        ``from_content``. Afterwards ``path`` is set to the skill directory and
        the skill name is compared with that directory's name.

        Args:
            skill_path: The skill directory, or the SKILL.md file itself
                (file name matched case-insensitively).
            strict: When True, a name/directory mismatch raises; when False
                (default) it is only logged as a warning. The flag is also
                forwarded to ``from_content``.

        Returns:
            The loaded Skill with ``path`` set to the skill directory.

        Raises:
            FileNotFoundError: If the path is neither a directory nor a
                SKILL.md file, or the directory lookup for SKILL.md fails.
            ValueError: If parsing/validation fails, or in strict mode when
                the name does not match the directory name.
        """
        skill_path = Path(skill_path).resolve()

        if skill_path.is_dir():
            skill_dir = skill_path
            skill_md_path = _find_skill_md(skill_path)
        elif skill_path.is_file() and skill_path.name.lower() == "skill.md":
            skill_dir = skill_path.parent
            skill_md_path = skill_path
        else:
            raise FileNotFoundError(
                f"path=<{skill_path}> | skill path does not exist or is not a valid skill directory"
            )

        logger.debug("path=<%s> | loading skill", skill_md_path)

        loaded = cls.from_content(skill_md_path.read_text(encoding="utf-8"), strict=strict)

        # from_content has no directory to compare against, so the
        # name-vs-directory check is done here once the path is known.
        loaded.path = skill_dir
        if loaded.name != skill_dir.name:
            mismatch = "name=<%s>, directory=<%s> | skill name does not match parent directory name"
            if strict:
                raise ValueError(mismatch % (loaded.name, skill_dir.name))
            logger.warning(mismatch, loaded.name, skill_dir.name)

        logger.debug("name=<%s>, path=<%s> | skill loaded successfully", loaded.name, loaded.path)
        return loaded

    @classmethod
    def from_content(cls, content: str, *, strict: bool = False) -> Skill:
        """Build a Skill from raw SKILL.md text (YAML frontmatter + markdown body).

        No file on disk is needed, so the resulting skill has no ``path``.

        Example::

            content = '''---
            name: my-skill
            description: Does something useful
            ---
            # Instructions
            Follow these steps...
            '''
            skill = Skill.from_content(content)

        Args:
            content: Full SKILL.md text.
            strict: Forwarded to the skill-name validation.

        Returns:
            The Skill described by the content.

        Raises:
            ValueError: If ``name`` or ``description`` is missing, empty, or
                not a string. Frontmatter parsing and name validation are
                delegated to helpers that may raise as well.
        """
        frontmatter, body = _parse_frontmatter(content)

        name = frontmatter.get("name")
        if not (isinstance(name, str) and name):
            raise ValueError("SKILL.md content must have a 'name' field in frontmatter")

        description = frontmatter.get("description")
        if not (isinstance(description, str) and description):
            raise ValueError("SKILL.md content must have a 'description' field in frontmatter")

        _validate_skill_name(name, strict=strict)

        return _build_skill_from_frontmatter(frontmatter, body)

    @classmethod
    def from_directory(cls, skills_dir: str | Path, *, strict: bool = False) -> list[Skill]:
        """Load every skill found in the immediate subdirectories of a directory.

        Children are visited in sorted order. Non-directories are ignored, and
        subdirectories for which the SKILL.md lookup fails are skipped with a
        debug message. A skill whose loading raises ``ValueError`` or
        ``FileNotFoundError`` is skipped with a warning; this includes
        validation errors raised because ``strict`` is True.

        Args:
            skills_dir: Parent directory that contains the skill subdirectories.
            strict: Forwarded to ``from_file`` for each skill.

        Returns:
            The successfully loaded skills, in sorted directory order.

        Raises:
            FileNotFoundError: If ``skills_dir`` is not an existing directory.
        """
        skills_dir = Path(skills_dir).resolve()
        if not skills_dir.is_dir():
            raise FileNotFoundError(f"path=<{skills_dir}> | skills directory does not exist")

        loaded: list[Skill] = []
        for entry in sorted(skills_dir.iterdir()):
            if not entry.is_dir():
                continue

            # Probe for SKILL.md first so plain directories are skipped quietly
            # (debug) rather than reported as load failures (warning).
            try:
                _find_skill_md(entry)
            except FileNotFoundError:
                logger.debug("path=<%s> | skipping directory without SKILL.md", entry)
                continue

            try:
                loaded.append(cls.from_file(entry, strict=strict))
            except (ValueError, FileNotFoundError) as e:
                logger.warning("path=<%s> | skipping skill due to error: %s", entry, e)

        logger.debug("path=<%s>, count=<%d> | loaded skills from directory", skills_dir, len(loaded))
        return loaded
def _mock_agent():
    """Build a MagicMock agent with a working prompt property, hooks and state.

    The mock exposes ``system_prompt`` as a real property backed by
    ``_system_prompt`` (writes go through ``_set_system_prompt``), a real
    ``HookRegistry`` reachable through ``add_hook``, a ``tool_registry`` whose
    ``process_tools`` returns ``["skills"]``, and a ``state`` object whose
    ``get``/``set`` read and write a private dict.
    """
    agent = MagicMock()
    agent._system_prompt = "You are an agent."
    agent._system_prompt_content = [{"text": "You are an agent."}]

    def _read_prompt(self):
        return self._system_prompt

    def _write_prompt(self, value):
        return _set_system_prompt(self, value)

    # Properties must live on the type to take effect, so install it on the
    # mock's class rather than on the instance.
    type(agent).system_prompt = property(_read_prompt, _write_prompt)

    agent.hooks = HookRegistry()

    def _add_hook(callback, event_type=None):
        # add_hook takes (callback, event_type); the registry wants them swapped.
        return agent.hooks.add_callback(event_type, callback)

    agent.add_hook = MagicMock(side_effect=_add_hook)

    agent.tool_registry = MagicMock()
    agent.tool_registry.process_tools = MagicMock(return_value=["skills"])

    # Back get/set with a real dict so stored values can be read back.
    state_store: dict[str, object] = {}

    def _state_get(key):
        return state_store.get(key)

    def _state_set(key, value):
        state_store[key] = value

    agent.state = MagicMock()
    agent.state.get = MagicMock(side_effect=_state_get)
    agent.state.set = MagicMock(side_effect=_state_set)
    return agent


def _mock_tool_context(agent: MagicMock) -> ToolContext:
    """Return a ToolContext for a ``skills`` tool call bound to the given agent."""
    return ToolContext(
        tool_use={"toolUseId": "test-id", "name": "skills", "input": {}},
        agent=agent,
        invocation_state={"agent": agent},
    )


def _set_system_prompt(agent: MagicMock, value: str | None) -> None:
    """Mimic the Agent.system_prompt setter on a mock agent.

    A string updates both ``_system_prompt`` and ``_system_prompt_content``;
    ``None`` clears both. Any other value leaves the agent untouched.
    """
    if value is None:
        agent._system_prompt = None
        agent._system_prompt_content = None
        return
    if isinstance(value, str):
        agent._system_prompt = value
        agent._system_prompt_content = [{"text": value}]
test_init_with_skill_instances(self): + """Test initialization with Skill instances.""" + skill = _make_skill() + plugin = AgentSkills(skills=[skill]) + + assert len(plugin.get_available_skills()) == 1 + assert plugin.get_available_skills()[0].name == "test-skill" + + def test_init_with_filesystem_paths(self, tmp_path): + """Test initialization with filesystem paths.""" + _make_skill_dir(tmp_path, "fs-skill") + plugin = AgentSkills(skills=[str(tmp_path / "fs-skill")]) + + assert len(plugin.get_available_skills()) == 1 + assert plugin.get_available_skills()[0].name == "fs-skill" + + def test_init_with_parent_directory(self, tmp_path): + """Test initialization with a parent directory containing skills.""" + _make_skill_dir(tmp_path, "skill-a") + _make_skill_dir(tmp_path, "skill-b") + plugin = AgentSkills(skills=[tmp_path]) + + assert len(plugin.get_available_skills()) == 2 + + def test_init_with_mixed_sources(self, tmp_path): + """Test initialization with mixed skill sources.""" + _make_skill_dir(tmp_path, "fs-skill") + direct_skill = _make_skill(name="direct-skill", description="Direct") + plugin = AgentSkills(skills=[str(tmp_path / "fs-skill"), direct_skill]) + + assert len(plugin.get_available_skills()) == 2 + names = {s.name for s in plugin.get_available_skills()} + assert names == {"fs-skill", "direct-skill"} + + def test_init_skips_nonexistent_paths(self, tmp_path): + """Test that nonexistent paths are skipped gracefully.""" + plugin = AgentSkills(skills=[str(tmp_path / "nonexistent")]) + assert len(plugin.get_available_skills()) == 0 + + def test_init_empty_skills(self): + """Test initialization with empty skills list.""" + plugin = AgentSkills(skills=[]) + assert plugin.get_available_skills() == [] + + def test_name_attribute(self): + """Test that the plugin has the correct name.""" + plugin = AgentSkills(skills=[]) + assert plugin.name == "agent_skills" + + def test_custom_state_key(self): + """Test initialization with a custom state key.""" + plugin = 
class TestSkillsPluginInitAgent:
    """Tests for the init_agent method and plugin registry integration."""

    @staticmethod
    def _init_through_registry():
        """Attach a one-skill plugin to a fresh mock agent via _PluginRegistry; return the agent."""
        skills_plugin = AgentSkills(skills=[_make_skill()])
        mock_agent = _mock_agent()
        _PluginRegistry(mock_agent).add_and_init(skills_plugin)
        return mock_agent

    def test_registers_tool(self):
        """Registering the plugin calls the tool registry's process_tools exactly once."""
        mock_agent = self._init_through_registry()

        mock_agent.tool_registry.process_tools.assert_called_once()

    def test_registers_hooks(self):
        """Registering the plugin leaves callbacks in the agent's hook registry."""
        mock_agent = self._init_through_registry()

        assert mock_agent.hooks.has_callbacks()

    def test_does_not_store_agent_reference(self):
        """Calling init_agent directly must not set an ``_agent`` attribute on the plugin."""
        skills_plugin = AgentSkills(skills=[_make_skill()])
        mock_agent = _mock_agent()

        skills_plugin.init_agent(mock_agent)

        assert not hasattr(skills_plugin, "_agent")
== 1 + assert plugin.get_available_skills()[0].name == "new-skill" + + def test_set_available_skills_with_paths(self, tmp_path): + """Test setting skills via set_available_skills with filesystem paths.""" + plugin = AgentSkills(skills=[_make_skill()]) + _make_skill_dir(tmp_path, "fs-skill") + + plugin.set_available_skills([str(tmp_path / "fs-skill")]) + + assert len(plugin.get_available_skills()) == 1 + assert plugin.get_available_skills()[0].name == "fs-skill" + + def test_set_available_skills_with_mixed_sources(self, tmp_path): + """Test setting skills via set_available_skills with mixed sources.""" + plugin = AgentSkills(skills=[]) + _make_skill_dir(tmp_path, "fs-skill") + direct = _make_skill(name="direct", description="Direct") + + plugin.set_available_skills([str(tmp_path / "fs-skill"), direct]) + + assert len(plugin.get_available_skills()) == 2 + names = {s.name for s in plugin.get_available_skills()} + assert names == {"fs-skill", "direct"} + + + + +class TestSkillsTool: + """Tests for the skills tool method.""" + + def test_activate_skill(self): + """Test activating a skill returns its instructions.""" + skill = _make_skill(instructions="Full instructions here.") + plugin = AgentSkills(skills=[skill]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + result = plugin.skills(skill_name="test-skill", tool_context=tool_context) + + assert "Full instructions here." 
in result + + def test_activate_nonexistent_skill(self): + """Test activating a nonexistent skill returns error message.""" + skill = _make_skill() + plugin = AgentSkills(skills=[skill]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + result = plugin.skills(skill_name="nonexistent", tool_context=tool_context) + + assert "not found" in result + assert "test-skill" in result + + def test_activate_replaces_previous(self): + """Test that activating a new skill replaces the previous one.""" + skill1 = _make_skill(name="skill-a", description="A", instructions="A instructions") + skill2 = _make_skill(name="skill-b", description="B", instructions="B instructions") + plugin = AgentSkills(skills=[skill1, skill2]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + result_a = plugin.skills(skill_name="skill-a", tool_context=tool_context) + assert "A instructions" in result_a + + result_b = plugin.skills(skill_name="skill-b", tool_context=tool_context) + assert "B instructions" in result_b + + def test_activate_without_name(self): + """Test activating without a skill name returns error.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + result = plugin.skills(skill_name="", tool_context=tool_context) + + assert "required" in result.lower() + + def test_activate_tracks_in_agent_state(self): + """Test that activating a skill records it in agent state.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + plugin.skills(skill_name="test-skill", tool_context=tool_context) + + assert plugin.get_activated_skills(agent) == ["test-skill"] + + def test_activate_multiple_tracks_order(self): + """Test that multiple activations are tracked in order.""" + skill_a = _make_skill(name="skill-a", description="A", instructions="A") + skill_b = _make_skill(name="skill-b", description="B", instructions="B") + plugin = 
AgentSkills(skills=[skill_a, skill_b]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + plugin.skills(skill_name="skill-a", tool_context=tool_context) + plugin.skills(skill_name="skill-b", tool_context=tool_context) + + assert plugin.get_activated_skills(agent) == ["skill-a", "skill-b"] + + def test_activate_same_skill_twice_deduplicates(self): + """Test that re-activating a skill moves it to the end without duplicates.""" + skill_a = _make_skill(name="skill-a", description="A", instructions="A") + skill_b = _make_skill(name="skill-b", description="B", instructions="B") + plugin = AgentSkills(skills=[skill_a, skill_b]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + plugin.skills(skill_name="skill-a", tool_context=tool_context) + plugin.skills(skill_name="skill-b", tool_context=tool_context) + plugin.skills(skill_name="skill-a", tool_context=tool_context) + + assert plugin.get_activated_skills(agent) == ["skill-b", "skill-a"] + + def test_get_activated_skills_empty_by_default(self): + """Test that get_activated_skills returns empty list when nothing activated.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + + assert plugin.get_activated_skills(agent) == [] + + def test_get_activated_skills_returns_copy(self): + """Test that get_activated_skills returns a copy, not a reference.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + tool_context = _mock_tool_context(agent) + + plugin.skills(skill_name="test-skill", tool_context=tool_context) + result = plugin.get_activated_skills(agent) + result.append("injected") + + assert plugin.get_activated_skills(agent) == ["test-skill"] + + +class TestSystemPromptInjection: + """Tests for system prompt injection via hooks.""" + + def test_before_invocation_appends_skills_xml(self): + """Test that before_invocation appends skills XML to system prompt.""" + skill = _make_skill() + plugin = AgentSkills(skills=[skill]) + agent = 
_mock_agent() + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + assert "" in agent.system_prompt + assert "test-skill" in agent.system_prompt + assert "A test skill" in agent.system_prompt + + def test_before_invocation_preserves_existing_prompt(self): + """Test that existing system prompt content is preserved.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + agent._system_prompt = "Original prompt." + agent._system_prompt_content = [{"text": "Original prompt."}] + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + assert agent.system_prompt.startswith("Original prompt.") + assert "" in agent.system_prompt + + def test_repeated_invocations_do_not_accumulate(self): + """Test that repeated invocations rebuild from current prompt without accumulation.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + agent._system_prompt = "Original prompt." + agent._system_prompt_content = [{"text": "Original prompt."}] + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + first_prompt = agent.system_prompt + + plugin._on_before_invocation(event) + second_prompt = agent.system_prompt + + assert first_prompt == second_prompt + + def test_no_skills_injects_empty_message(self): + """Test that a 'no skills available' message is injected when no skills are loaded.""" + plugin = AgentSkills(skills=[]) + agent = _mock_agent() + original_prompt = "Original prompt." 
+ agent._system_prompt = original_prompt + agent._system_prompt_content = [{"text": original_prompt}] + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + assert "No skills are currently available" in agent.system_prompt + assert agent.system_prompt.startswith("Original prompt.") + + def test_none_system_prompt_handled(self): + """Test handling when system prompt is None.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + agent._system_prompt = None + agent._system_prompt_content = None + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + assert "" in agent.system_prompt + + def test_preserves_other_plugin_modifications(self): + """Test that modifications by other plugins/hooks are preserved.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + agent._system_prompt = "Original prompt." + agent._system_prompt_content = [{"text": "Original prompt."}] + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + # Simulate another plugin modifying the prompt + agent.system_prompt = agent.system_prompt + "\n\nExtra context from another plugin." + + plugin._on_before_invocation(event) + + assert "Extra context from another plugin." in agent.system_prompt + assert "" in agent.system_prompt + + def test_uses_public_system_prompt_setter(self): + """Test that the hook uses the public system_prompt setter.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + agent._system_prompt = "Original." 
+ agent._system_prompt_content = [{"text": "Original."}] + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + # The public setter should have been used, so _system_prompt_content + # should be consistent with _system_prompt + assert agent._system_prompt_content == [{"text": agent._system_prompt}] + + def test_warns_when_previous_xml_not_found(self, caplog): + """Test that a warning is logged when the previously injected XML is missing from the prompt.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + agent._system_prompt = "Original prompt." + agent._system_prompt_content = [{"text": "Original prompt."}] + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + # Completely replace the system prompt, removing the injected XML + agent.system_prompt = "Totally new prompt." + + with caplog.at_level(logging.WARNING): + plugin._on_before_invocation(event) + + assert "unable to find previously injected skills XML in system prompt" in caplog.text + assert "" in agent.system_prompt + + +class TestSkillsXmlGeneration: + """Tests for _generate_skills_xml.""" + + def test_single_skill(self): + """Test XML generation with a single skill.""" + plugin = AgentSkills(skills=[_make_skill()]) + xml = plugin._generate_skills_xml() + + assert "" in xml + assert "" in xml + assert "test-skill" in xml + assert "A test skill" in xml + + def test_multiple_skills(self): + """Test XML generation with multiple skills.""" + skills = [ + _make_skill(name="skill-a", description="Skill A"), + _make_skill(name="skill-b", description="Skill B"), + ] + plugin = AgentSkills(skills=skills) + xml = plugin._generate_skills_xml() + + assert "skill-a" in xml + assert "skill-b" in xml + + def test_empty_skills(self): + """Test XML generation with no skills includes 'no skills available' message.""" + plugin = AgentSkills(skills=[]) + xml = plugin._generate_skills_xml() + + assert "" in xml + assert "No skills are 
currently available" in xml + assert "" in xml + + def test_location_included_when_path_set(self, tmp_path): + """Test that location element is included when skill has a path.""" + skill = _make_skill() + skill.path = tmp_path / "test-skill" + plugin = AgentSkills(skills=[skill]) + xml = plugin._generate_skills_xml() + + assert f"{tmp_path / 'test-skill' / 'SKILL.md'}" in xml + + def test_location_omitted_when_path_none(self): + """Test that location element is omitted for programmatic skills.""" + skill = _make_skill() + assert skill.path is None + plugin = AgentSkills(skills=[skill]) + xml = plugin._generate_skills_xml() + + assert "" not in xml + + def test_escapes_xml_special_characters(self): + """Test that XML special characters in names and descriptions are escaped.""" + skill = _make_skill(name="a&c", description="Use & more") + plugin = AgentSkills(skills=[skill]) + xml = plugin._generate_skills_xml() + + assert "a<b>&c" in xml + assert "Use <tools> & more" in xml + + +class TestSkillResponseFormat: + """Tests for _format_skill_response.""" + + def test_instructions_only(self): + """Test response with just instructions.""" + skill = _make_skill(instructions="Do the thing.") + plugin = AgentSkills(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert result == "Do the thing." + + def test_no_instructions(self): + """Test response when skill has no instructions.""" + skill = _make_skill(instructions="") + plugin = AgentSkills(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "no instructions available" in result.lower() + + def test_includes_allowed_tools(self): + """Test response includes allowed tools when set.""" + skill = _make_skill(instructions="Do the thing.") + skill.allowed_tools = ["Bash", "Read"] + plugin = AgentSkills(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Do the thing." 
in result + assert "Allowed tools: Bash, Read" in result + + def test_includes_compatibility(self): + """Test response includes compatibility when set.""" + skill = _make_skill(instructions="Do the thing.") + skill.compatibility = "Requires docker" + plugin = AgentSkills(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Compatibility: Requires docker" in result + + def test_includes_location(self, tmp_path): + """Test response includes location when path is set.""" + skill = _make_skill(instructions="Do the thing.") + skill.path = tmp_path / "test-skill" + plugin = AgentSkills(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert f"Location: {tmp_path / 'test-skill' / 'SKILL.md'}" in result + + def test_all_metadata(self, tmp_path): + """Test response with all metadata fields.""" + skill = _make_skill(instructions="Do the thing.") + skill.allowed_tools = ["Bash"] + skill.compatibility = "Requires git" + skill.path = tmp_path / "test-skill" + plugin = AgentSkills(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Do the thing." 
in result + assert "---" in result + assert "Allowed tools: Bash" in result + assert "Compatibility: Requires git" in result + assert "Location:" in result + + def test_includes_resource_listing(self, tmp_path): + """Test response includes resource files from optional directories.""" + skill_dir = tmp_path / "test-skill" + skill_dir.mkdir() + (skill_dir / "scripts").mkdir() + (skill_dir / "scripts" / "extract.py").write_text("# extract") + (skill_dir / "references").mkdir() + (skill_dir / "references" / "REFERENCE.md").write_text("# ref") + + skill = _make_skill(instructions="Do the thing.") + skill.path = skill_dir + plugin = AgentSkills(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Available resources:" in result + assert "scripts/extract.py" in result + assert "references/REFERENCE.md" in result + + def test_no_resources_when_no_path(self): + """Test that resources section is omitted for programmatic skills.""" + skill = _make_skill(instructions="Do the thing.") + plugin = AgentSkills(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Available resources:" not in result + + def test_no_resources_when_dirs_empty(self, tmp_path): + """Test that resources section is omitted when optional dirs don't exist.""" + skill_dir = tmp_path / "test-skill" + skill_dir.mkdir() + + skill = _make_skill(instructions="Do the thing.") + skill.path = skill_dir + plugin = AgentSkills(skills=[skill]) + result = plugin._format_skill_response(skill) + + assert "Available resources:" not in result + + def test_resource_listing_truncated(self, tmp_path): + """Test that resource listing is truncated at the max file limit.""" + skill_dir = tmp_path / "test-skill" + scripts_dir = skill_dir / "scripts" + scripts_dir.mkdir(parents=True) + for i in range(55): + (scripts_dir / f"script_{i:03d}.py").write_text(f"# script {i}") + + skill = _make_skill(instructions="Do the thing.") + skill.path = skill_dir + plugin = AgentSkills(skills=[skill]) 
class TestResolveSkills:
    """Tests for _resolve_skills."""

    def test_resolve_skill_instances(self):
        """A Skill instance is stored as-is under its name."""
        original = _make_skill()
        resolved = AgentSkills(skills=[original])._skills

        assert len(resolved) == 1
        assert resolved["test-skill"] is original

    def test_resolve_skill_directory_path(self, tmp_path):
        """A Path to a skill directory resolves to that single skill."""
        _make_skill_dir(tmp_path, "path-skill")
        resolved = AgentSkills(skills=[tmp_path / "path-skill"])._skills

        assert len(resolved) == 1
        assert "path-skill" in resolved

    def test_resolve_parent_directory_path(self, tmp_path):
        """A Path to a parent directory resolves to each child skill."""
        for child in ("child-a", "child-b"):
            _make_skill_dir(tmp_path, child)
        resolved = AgentSkills(skills=[tmp_path])._skills

        assert len(resolved) == 2

    def test_resolve_skill_md_file_path(self, tmp_path):
        """A Path pointing at the SKILL.md file itself resolves to its skill."""
        skill_md = _make_skill_dir(tmp_path, "file-skill") / "SKILL.md"
        resolved = AgentSkills(skills=[skill_md])._skills

        assert len(resolved) == 1
        assert "file-skill" in resolved

    def test_resolve_nonexistent_path(self, tmp_path):
        """A path that does not exist contributes no skills."""
        missing = str(tmp_path / "ghost")
        resolved = AgentSkills(skills=[missing])._skills

        assert len(resolved) == 0
strands.plugins.skills import AgentSkills, Skill + + assert Skill is not None + assert AgentSkills is not None + + def test_skills_plugin_is_plugin_subclass(self): + """Test that AgentSkills is a subclass of the Plugin ABC.""" + from strands.plugins import Plugin + + assert issubclass(AgentSkills, Plugin) + + def test_skills_plugin_isinstance_check(self): + """Test that AgentSkills instances pass isinstance check against Plugin.""" + from strands.plugins import Plugin + + plugin = AgentSkills(skills=[]) + assert isinstance(plugin, Plugin) diff --git a/tests/strands/plugins/skills/test_skill.py b/tests/strands/plugins/skills/test_skill.py new file mode 100644 index 000000000..2c4c21930 --- /dev/null +++ b/tests/strands/plugins/skills/test_skill.py @@ -0,0 +1,561 @@ +"""Tests for the Skill dataclass and loading utilities.""" + +import logging +from pathlib import Path + +import pytest + +from strands.plugins.skills.skill import ( + Skill, + _find_skill_md, + _fix_yaml_colons, + _parse_frontmatter, + _validate_skill_name, +) + + +class TestSkillDataclass: + """Tests for the Skill dataclass creation and properties.""" + + def test_skill_minimal(self): + """Test creating a Skill with only required fields.""" + skill = Skill(name="test-skill", description="A test skill") + + assert skill.name == "test-skill" + assert skill.description == "A test skill" + assert skill.instructions == "" + assert skill.path is None + assert skill.allowed_tools is None + assert skill.metadata == {} + assert skill.license is None + assert skill.compatibility is None + + def test_skill_full(self): + """Test creating a Skill with all fields.""" + skill = Skill( + name="full-skill", + description="A fully specified skill", + instructions="# Full Instructions\nDo the thing.", + path=Path("/tmp/skills/full-skill"), + allowed_tools=["tool1", "tool2"], + metadata={"author": "test-org"}, + license="Apache-2.0", + compatibility="strands>=1.0", + ) + + assert skill.name == "full-skill" + assert 
class TestFindSkillMd:
    """Tests for _find_skill_md."""

    def test_finds_uppercase_skill_md(self, tmp_path):
        """An upper-case SKILL.md is located."""
        (tmp_path / "SKILL.md").write_text("test")

        found = _find_skill_md(tmp_path)

        assert found.name == "SKILL.md"

    def test_finds_lowercase_skill_md(self, tmp_path):
        """A lower-case skill.md is located (name compared case-insensitively)."""
        (tmp_path / "skill.md").write_text("test")

        found = _find_skill_md(tmp_path)

        assert found.name.lower() == "skill.md"

    def test_prefers_uppercase(self, tmp_path):
        """When both spellings are written, the upper-case file is returned."""
        # Write order matters: SKILL.md first, then skill.md.
        for file_name, text in (("SKILL.md", "uppercase"), ("skill.md", "lowercase")):
            (tmp_path / file_name).write_text(text)

        found = _find_skill_md(tmp_path)

        assert found.name == "SKILL.md"

    def test_raises_when_not_found(self, tmp_path):
        """An empty directory raises FileNotFoundError mentioning the missing SKILL.md."""
        with pytest.raises(FileNotFoundError, match="no SKILL.md found"):
            _find_skill_md(tmp_path)
+ frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "test-skill" + assert frontmatter["description"] == "A test" + assert "# Instructions" in body + assert "Do things." in body + + def test_missing_opening_delimiter(self): + """Test error when opening --- is missing.""" + with pytest.raises(ValueError, match="must start with ---"): + _parse_frontmatter("name: test\n---\n") + + def test_missing_closing_delimiter(self): + """Test error when closing --- is missing.""" + with pytest.raises(ValueError, match="missing closing ---"): + _parse_frontmatter("---\nname: test\n") + + def test_empty_body(self): + """Test frontmatter with empty body.""" + content = "---\nname: test-skill\ndescription: test\n---\n" + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "test-skill" + assert body == "" + + def test_frontmatter_with_metadata(self): + """Test frontmatter with nested metadata.""" + content = "---\nname: test-skill\ndescription: test\nmetadata:\n author: acme\n---\nBody here." + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "test-skill" + assert isinstance(frontmatter["metadata"], dict) + assert frontmatter["metadata"]["author"] == "acme" + assert body == "Body here." + + def test_frontmatter_with_dashes_in_yaml_value(self): + """Test that --- inside a YAML value does not break parsing.""" + content = "---\nname: test-skill\ndescription: has --- inside\n---\nBody here." + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "test-skill" + assert frontmatter["description"] == "has --- inside" + assert body == "Body here." 
+ + +class TestValidateSkillName: + """Tests for _validate_skill_name (lenient validation).""" + + def test_valid_names(self): + """Test that valid names pass validation without warnings.""" + valid_names = ["a", "test", "my-skill", "skill-123", "a1b2c3"] + for name in valid_names: + _validate_skill_name(name) # Should not raise + + def test_empty_name(self): + """Test that empty name raises ValueError.""" + with pytest.raises(ValueError, match="cannot be empty"): + _validate_skill_name("") + + def test_too_long_name_warns(self, caplog): + """Test that names exceeding 64 chars warn but do not raise.""" + with caplog.at_level(logging.WARNING): + _validate_skill_name("a" * 65) + assert "exceeds" in caplog.text + + def test_uppercase_warns(self, caplog): + """Test that uppercase characters warn but do not raise.""" + with caplog.at_level(logging.WARNING): + _validate_skill_name("MySkill") + assert "lowercase alphanumeric" in caplog.text + + def test_starts_with_hyphen_warns(self, caplog): + """Test that names starting with hyphen warn but do not raise.""" + with caplog.at_level(logging.WARNING): + _validate_skill_name("-skill") + assert "lowercase alphanumeric" in caplog.text + + def test_ends_with_hyphen_warns(self, caplog): + """Test that names ending with hyphen warn but do not raise.""" + with caplog.at_level(logging.WARNING): + _validate_skill_name("skill-") + assert "lowercase alphanumeric" in caplog.text + + def test_consecutive_hyphens_warns(self, caplog): + """Test that consecutive hyphens warn but do not raise.""" + with caplog.at_level(logging.WARNING): + _validate_skill_name("my--skill") + assert "consecutive hyphens" in caplog.text + + def test_special_characters_warns(self, caplog): + """Test that special characters warn but do not raise.""" + with caplog.at_level(logging.WARNING): + _validate_skill_name("my_skill") + assert "lowercase alphanumeric" in caplog.text + + def test_directory_name_mismatch_warns(self, tmp_path, caplog): + """Test that skill 
name not matching directory name warns but does not raise.""" + skill_dir = tmp_path / "wrong-name" + skill_dir.mkdir() + with caplog.at_level(logging.WARNING): + _validate_skill_name("my-skill", skill_dir) + assert "does not match parent directory name" in caplog.text + + def test_directory_name_match(self, tmp_path): + """Test that matching directory name passes.""" + skill_dir = tmp_path / "my-skill" + skill_dir.mkdir() + _validate_skill_name("my-skill", skill_dir) # Should not raise or warn + + +class TestValidateSkillNameStrict: + """Tests for _validate_skill_name with strict=True.""" + + def test_strict_valid_name(self): + """Test that valid names pass strict validation.""" + _validate_skill_name("my-skill", strict=True) # Should not raise + + def test_strict_empty_name(self): + """Test that empty name raises in strict mode.""" + with pytest.raises(ValueError, match="cannot be empty"): + _validate_skill_name("", strict=True) + + def test_strict_too_long_name(self): + """Test that names exceeding 64 chars raise in strict mode.""" + with pytest.raises(ValueError, match="exceeds 64 character limit"): + _validate_skill_name("a" * 65, strict=True) + + def test_strict_uppercase_rejected(self): + """Test that uppercase characters raise in strict mode.""" + with pytest.raises(ValueError, match="lowercase alphanumeric"): + _validate_skill_name("MySkill", strict=True) + + def test_strict_starts_with_hyphen(self): + """Test that names starting with hyphen raise in strict mode.""" + with pytest.raises(ValueError, match="lowercase alphanumeric"): + _validate_skill_name("-skill", strict=True) + + def test_strict_consecutive_hyphens(self): + """Test that consecutive hyphens raise in strict mode.""" + with pytest.raises(ValueError, match="consecutive hyphens"): + _validate_skill_name("my--skill", strict=True) + + def test_strict_directory_mismatch(self, tmp_path): + """Test that directory name mismatch raises in strict mode.""" + skill_dir = tmp_path / "wrong-name" + 
skill_dir.mkdir() + with pytest.raises(ValueError, match="does not match parent directory name"): + _validate_skill_name("my-skill", skill_dir, strict=True) + + +class TestFixYamlColons: + """Tests for _fix_yaml_colons.""" + + def test_fixes_unquoted_colon_in_value(self): + """Test that an unquoted colon in a value gets quoted.""" + raw = "description: Use this skill when: the user asks about PDFs" + fixed = _fix_yaml_colons(raw) + assert fixed == 'description: "Use this skill when: the user asks about PDFs"' + + def test_leaves_already_double_quoted_value(self): + """Test that already double-quoted values are not re-quoted.""" + raw = 'description: "already: quoted"' + assert _fix_yaml_colons(raw) == raw + + def test_leaves_already_single_quoted_value(self): + """Test that already single-quoted values are not re-quoted.""" + raw = "description: 'already: quoted'" + assert _fix_yaml_colons(raw) == raw + + def test_leaves_value_without_colon(self): + """Test that values without colons are unchanged.""" + raw = "name: my-skill" + assert _fix_yaml_colons(raw) == raw + + def test_multiline_mixed(self): + """Test fixing only the lines that need it in a multi-line string.""" + raw = "name: my-skill\ndescription: Use when: needed\nversion: 1.0" + fixed = _fix_yaml_colons(raw) + assert fixed == 'name: my-skill\ndescription: "Use when: needed"\nversion: 1.0' + + def test_empty_string(self): + """Test that an empty string is returned unchanged.""" + assert _fix_yaml_colons("") == "" + + def test_preserves_indented_lines_without_colons(self): + """Test that indented lines without key-value patterns are preserved.""" + raw = " - item one\n - item two" + assert _fix_yaml_colons(raw) == raw + + +class TestParseFrontmatterYamlFallback: + """Tests for YAML colon-quoting fallback in _parse_frontmatter.""" + + def test_fallback_on_unquoted_colon(self): + """Test that frontmatter with unquoted colons in values is parsed via fallback.""" + content = "---\nname: my-skill\ndescription: 
Use when: the user asks\n---\nBody." + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "my-skill" + assert "Use when" in frontmatter["description"] + assert body == "Body." + + def test_fallback_preserves_valid_yaml(self): + """Test that valid YAML is parsed normally without triggering fallback.""" + content = "---\nname: my-skill\ndescription: A simple description\n---\nBody." + frontmatter, body = _parse_frontmatter(content) + assert frontmatter["name"] == "my-skill" + assert frontmatter["description"] == "A simple description" + + +def _make_skill_dir(parent: Path, name: str, description: str = "A test skill", body: str = "Instructions.") -> Path: + """Helper to create a skill directory with SKILL.md.""" + skill_dir = parent / name + skill_dir.mkdir(parents=True, exist_ok=True) + content = f"---\nname: {name}\ndescription: {description}\n---\n{body}\n" + (skill_dir / "SKILL.md").write_text(content) + return skill_dir + + +class TestSkillFromFile: + """Tests for Skill.from_file.""" + + def test_load_from_directory(self, tmp_path): + """Test loading a skill from a directory path.""" + skill_dir = _make_skill_dir(tmp_path, "my-skill", "My description", "# Hello\nWorld.") + skill = Skill.from_file(skill_dir) + + assert skill.name == "my-skill" + assert skill.description == "My description" + assert "# Hello" in skill.instructions + assert "World." 
in skill.instructions + assert skill.path == skill_dir.resolve() + + def test_load_from_skill_md_file(self, tmp_path): + """Test loading a skill by pointing directly to SKILL.md.""" + skill_dir = _make_skill_dir(tmp_path, "direct-skill") + skill = Skill.from_file(skill_dir / "SKILL.md") + + assert skill.name == "direct-skill" + + def test_load_with_allowed_tools(self, tmp_path): + """Test loading a skill with allowed-tools field as space-delimited string.""" + skill_dir = tmp_path / "tool-skill" + skill_dir.mkdir() + content = "---\nname: tool-skill\ndescription: test\nallowed-tools: read write execute\n---\nBody." + (skill_dir / "SKILL.md").write_text(content) + + skill = Skill.from_file(skill_dir) + assert skill.allowed_tools == ["read", "write", "execute"] + + def test_load_with_allowed_tools_yaml_list(self, tmp_path): + """Test loading a skill with allowed-tools as a YAML list.""" + skill_dir = tmp_path / "list-skill" + skill_dir.mkdir() + content = "---\nname: list-skill\ndescription: test\nallowed-tools:\n - read\n - write\n---\nBody." + (skill_dir / "SKILL.md").write_text(content) + + skill = Skill.from_file(skill_dir) + assert skill.allowed_tools == ["read", "write"] + + def test_load_with_metadata(self, tmp_path): + """Test loading a skill with nested metadata.""" + skill_dir = tmp_path / "meta-skill" + skill_dir.mkdir() + content = "---\nname: meta-skill\ndescription: test\nmetadata:\n author: acme\n---\nBody." + (skill_dir / "SKILL.md").write_text(content) + + skill = Skill.from_file(skill_dir) + assert skill.metadata == {"author": "acme"} + + def test_load_with_license_and_compatibility(self, tmp_path): + """Test loading a skill with license and compatibility fields.""" + skill_dir = tmp_path / "licensed-skill" + skill_dir.mkdir() + content = "---\nname: licensed-skill\ndescription: test\nlicense: MIT\ncompatibility: v1\n---\nBody." 
+ (skill_dir / "SKILL.md").write_text(content) + + skill = Skill.from_file(skill_dir) + assert skill.license == "MIT" + assert skill.compatibility == "v1" + + def test_load_missing_name(self, tmp_path): + """Test error when SKILL.md is missing name field.""" + skill_dir = tmp_path / "no-name" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text("---\ndescription: test\n---\nBody.") + + with pytest.raises(ValueError, match="must have a 'name' field"): + Skill.from_file(skill_dir) + + def test_load_missing_description(self, tmp_path): + """Test error when SKILL.md is missing description field.""" + skill_dir = tmp_path / "no-desc" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text("---\nname: no-desc\n---\nBody.") + + with pytest.raises(ValueError, match="must have a 'description' field"): + Skill.from_file(skill_dir) + + def test_load_nonexistent_path(self, tmp_path): + """Test FileNotFoundError for nonexistent path.""" + with pytest.raises(FileNotFoundError): + Skill.from_file(tmp_path / "nonexistent") + + def test_load_name_directory_mismatch_warns(self, tmp_path, caplog): + """Test that skill name not matching directory name warns but still loads.""" + skill_dir = tmp_path / "wrong-dir" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text("---\nname: right-name\ndescription: test\n---\nBody.") + + with caplog.at_level(logging.WARNING): + skill = Skill.from_file(skill_dir) + + assert skill.name == "right-name" + assert "does not match parent directory name" in caplog.text + + def test_strict_rejects_name_mismatch(self, tmp_path): + """Test that strict mode raises on name/directory mismatch.""" + skill_dir = tmp_path / "wrong-dir" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text("---\nname: right-name\ndescription: test\n---\nBody.") + + with pytest.raises(ValueError, match="does not match parent directory name"): + Skill.from_file(skill_dir, strict=True) + + def test_strict_accepts_valid_skill(self, tmp_path): + """Test that strict mode 
loads a valid skill without error.""" + _make_skill_dir(tmp_path, "valid-skill") + skill = Skill.from_file(tmp_path / "valid-skill", strict=True) + assert skill.name == "valid-skill" + + +class TestSkillFromDirectory: + """Tests for Skill.from_directory.""" + + def test_load_multiple_skills(self, tmp_path): + """Test loading multiple skills from a parent directory.""" + _make_skill_dir(tmp_path, "skill-a", "Skill A") + _make_skill_dir(tmp_path, "skill-b", "Skill B") + + skills = Skill.from_directory(tmp_path) + + assert len(skills) == 2 + names = {s.name for s in skills} + assert names == {"skill-a", "skill-b"} + + def test_skips_directories_without_skill_md(self, tmp_path): + """Test that directories without SKILL.md are silently skipped.""" + _make_skill_dir(tmp_path, "valid-skill") + (tmp_path / "no-skill-here").mkdir() + + skills = Skill.from_directory(tmp_path) + + assert len(skills) == 1 + assert skills[0].name == "valid-skill" + + def test_skips_files_in_parent(self, tmp_path): + """Test that files in the parent directory are ignored.""" + _make_skill_dir(tmp_path, "real-skill") + (tmp_path / "readme.txt").write_text("not a skill") + + skills = Skill.from_directory(tmp_path) + + assert len(skills) == 1 + + def test_empty_directory(self, tmp_path): + """Test loading from an empty directory.""" + skills = Skill.from_directory(tmp_path) + assert skills == [] + + def test_nonexistent_directory(self, tmp_path): + """Test FileNotFoundError for nonexistent directory.""" + with pytest.raises(FileNotFoundError): + Skill.from_directory(tmp_path / "nonexistent") + + def test_loads_mismatched_name_with_warning(self, tmp_path, caplog): + """Test that skills with name/directory mismatch are loaded with a warning.""" + _make_skill_dir(tmp_path, "good-skill") + + bad_dir = tmp_path / "bad-dir" + bad_dir.mkdir() + (bad_dir / "SKILL.md").write_text("---\nname: wrong-name\ndescription: test\n---\nBody.") + + with caplog.at_level(logging.WARNING): + skills = 
Skill.from_directory(tmp_path) + + assert len(skills) == 2 + names = {s.name for s in skills} + assert names == {"good-skill", "wrong-name"} + assert "does not match parent directory name" in caplog.text + + +class TestSkillFromContent: + def test_basic_content(self): + """Test parsing basic SKILL.md content.""" + content = "---\nname: my-skill\ndescription: A useful skill\n---\n# Instructions\nDo the thing." + skill = Skill.from_content(content) + + assert skill.name == "my-skill" + assert skill.description == "A useful skill" + assert "Do the thing." in skill.instructions + assert skill.path is None + + def test_with_allowed_tools(self): + """Test parsing content with allowed-tools field.""" + content = "---\nname: my-skill\ndescription: A skill\nallowed-tools: Bash Read\n---\nInstructions." + skill = Skill.from_content(content) + + assert skill.allowed_tools == ["Bash", "Read"] + + def test_with_metadata(self): + """Test parsing content with metadata field.""" + content = "---\nname: my-skill\ndescription: A skill\nmetadata:\n key: value\n---\nInstructions." + skill = Skill.from_content(content) + + assert skill.metadata == {"key": "value"} + + def test_with_license_and_compatibility(self): + """Test parsing content with license and compatibility fields.""" + content = ( + "---\nname: my-skill\ndescription: A skill\n" + "license: Apache-2.0\ncompatibility: Requires docker\n---\nInstructions." + ) + skill = Skill.from_content(content) + + assert skill.license == "Apache-2.0" + assert skill.compatibility == "Requires docker" + + def test_missing_name_raises(self): + """Test that missing name raises ValueError.""" + content = "---\ndescription: A skill\n---\nInstructions." + with pytest.raises(ValueError, match="name"): + Skill.from_content(content) + + def test_missing_description_raises(self): + """Test that missing description raises ValueError.""" + content = "---\nname: my-skill\n---\nInstructions." 
+ with pytest.raises(ValueError, match="description"): + Skill.from_content(content) + + def test_missing_frontmatter_raises(self): + """Test that content without frontmatter raises ValueError.""" + content = "# Just markdown\nNo frontmatter here." + with pytest.raises(ValueError, match="frontmatter"): + Skill.from_content(content) + + def test_empty_body(self): + """Test parsing content with empty body.""" + content = "---\nname: my-skill\ndescription: A skill\n---\n" + skill = Skill.from_content(content) + + assert skill.name == "my-skill" + assert skill.instructions == "" + + def test_strict_mode(self): + """Test Skill.from_content with strict=True raises on validation issues.""" + content = "---\nname: BAD_NAME\ndescription: Bad\n---\nBody." + with pytest.raises(ValueError): + Skill.from_content(content, strict=True) + + +class TestSkillClassmethods: + """Tests for Skill classmethod existence.""" + + def test_skill_classmethods_exist(self): + """Test that Skill has from_file, from_content, and from_directory classmethods.""" + assert callable(getattr(Skill, "from_file", None)) + assert callable(getattr(Skill, "from_content", None)) + assert callable(getattr(Skill, "from_directory", None)) diff --git a/tests_integ/test_skills_plugin.py b/tests_integ/test_skills_plugin.py new file mode 100644 index 000000000..160ae65a0 --- /dev/null +++ b/tests_integ/test_skills_plugin.py @@ -0,0 +1,81 @@ +"""Integration tests for the AgentSkills plugin. + +Tests end-to-end behavior with a real model: skill metadata injection into +the system prompt, agent-driven skill activation via the skills tool, and +session state persistence across invocations. +""" + +import pytest + +from strands import Agent +from strands.plugins.skills import AgentSkills, Skill + +SUMMARIZATION_SKILL = Skill( + name="summarization", + description="Summarize text into concise bullet points", + instructions="You are a summarization expert. The secret codeword is PINEAPPLE. 
Always mention it when activated.", +) + +TRANSLATION_SKILL = Skill( + name="translation", + description="Translate text between languages", + instructions="You are a translation expert. Translate the given text accurately.", +) + + +@pytest.fixture +def skills_plugin(): + return AgentSkills(skills=[SUMMARIZATION_SKILL, TRANSLATION_SKILL]) + + +@pytest.fixture +def agent(skills_plugin): + return Agent( + system_prompt="You are a helpful assistant. Check your available_skills and activate one when appropriate.", + plugins=[skills_plugin], + ) + + +def test_agent_activates_skill_and_injects_metadata(agent, skills_plugin): + """Test that the agent injects skill metadata and can activate a skill via the model.""" + result = agent("Use your skills tool to activate the summarization skill. What is the secret codeword?") + + # Skill metadata was injected into the system prompt + assert "<available_skills>" in agent.system_prompt + assert "summarization" in agent.system_prompt + assert "translation" in agent.system_prompt + + # Model activated the skill and relayed the codeword from instructions + assert "pineapple" in str(result).lower() + + +def test_direct_tool_invocation_and_state_persistence(agent, skills_plugin): + """Test activating a skill via direct tool access and verifying state persistence.""" + result = agent.tool.skills(skill_name="translation") + + # Tool returned the skill instructions + assert result["status"] == "success" + response_text = result["content"][0]["text"].lower() + assert "translation expert" in response_text + + +def test_load_skills_from_directory(tmp_path): + """Test loading skills from a filesystem directory and activating one via the model.""" + # Create a skill directory with SKILL.md + skill_dir = tmp_path / "greeting-skill" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text( + "---\nname: greeting\ndescription: Greet the user warmly\n---\n" + "You are a greeting expert. The secret codeword is MANGO. Always mention it when activated."
+ ) + + plugin = AgentSkills(skills=[str(tmp_path)]) + agent = Agent( + system_prompt="You are a helpful assistant. Check your available_skills and activate one when appropriate.", + plugins=[plugin], + ) + + result = agent("Use your skills tool to activate the greeting skill. What is the secret codeword?") + + assert "greeting" in agent.system_prompt + assert "mango" in str(result).lower() From e7d3eb97c81eb45da486cb8266464ee60d76b20f Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 10 Mar 2026 16:49:46 -0400 Subject: [PATCH 178/279] feat(steering): move steering from experimental to production (#1853) --- AGENTS.md | 20 +- src/strands/__init__.py | 3 +- src/strands/experimental/steering/__init__.py | 54 ++--- .../steering/context_providers/__init__.py | 36 ++- .../context_providers/ledger_provider.py | 98 ++------ .../experimental/steering/core/__init__.py | 25 +- .../experimental/steering/core/action.py | 85 ++----- .../experimental/steering/core/context.py | 86 ++----- .../experimental/steering/core/handler.py | 227 ++---------------- .../steering/handlers/__init__.py | 24 +- .../steering/handlers/llm/__init__.py | 25 +- .../steering/handlers/llm/llm_handler.py | 106 ++------ .../steering/handlers/llm/mappers.py | 137 ++--------- src/strands/models/openai_responses.py | 4 +- src/strands/plugins/__init__.py | 3 - src/strands/vended_plugins/__init__.py | 1 + .../skills/__init__.py | 2 +- .../skills/agent_skills.py | 3 +- .../skills/skill.py | 0 .../vended_plugins/steering/__init__.py | 47 ++++ .../steering/context_providers/__init__.py | 13 + .../context_providers/ledger_provider.py | 91 +++++++ .../vended_plugins/steering/core/__init__.py | 17 ++ .../vended_plugins/steering/core/action.py | 76 ++++++ .../vended_plugins/steering/core/context.py | 77 ++++++ .../vended_plugins/steering/core/handler.py | 218 +++++++++++++++++ .../steering/handlers/__init__.py | 5 + .../steering/handlers/llm/__init__.py | 6 + .../steering/handlers/llm/llm_handler.py | 99 
++++++++ .../steering/handlers/llm/mappers.py | 130 ++++++++++ .../steering/test_steering_aliases.py | 176 ++++++++++++++ tests/strands/plugins/skills/__init__.py | 1 - .../__init__.py | 0 .../skills}/__init__.py | 0 .../skills/test_agent_skills.py | 16 +- .../skills/test_skill.py | 2 +- .../steering}/__init__.py | 0 .../steering/context_providers}/__init__.py | 0 .../context_providers/test_ledger_provider.py | 26 +- .../vended_plugins/steering/core/__init__.py | 0 .../steering/core/test_handler.py | 10 +- .../steering/handlers/__init__.py | 0 .../steering/handlers/llm/__init__.py | 0 .../steering/handlers/llm/test_llm_handler.py | 6 +- .../steering/handlers/llm/test_mappers.py | 4 +- tests_integ/steering/test_model_steering.py | 6 +- tests_integ/steering/test_tool_steering.py | 8 +- tests_integ/test_skills_plugin.py | 2 +- 48 files changed, 1211 insertions(+), 764 deletions(-) create mode 100644 src/strands/vended_plugins/__init__.py rename src/strands/{plugins => vended_plugins}/skills/__init__.py (92%) rename src/strands/{plugins => vended_plugins}/skills/agent_skills.py (99%) rename src/strands/{plugins => vended_plugins}/skills/skill.py (100%) create mode 100644 src/strands/vended_plugins/steering/__init__.py create mode 100644 src/strands/vended_plugins/steering/context_providers/__init__.py create mode 100644 src/strands/vended_plugins/steering/context_providers/ledger_provider.py create mode 100644 src/strands/vended_plugins/steering/core/__init__.py create mode 100644 src/strands/vended_plugins/steering/core/action.py create mode 100644 src/strands/vended_plugins/steering/core/context.py create mode 100644 src/strands/vended_plugins/steering/core/handler.py create mode 100644 src/strands/vended_plugins/steering/handlers/__init__.py create mode 100644 src/strands/vended_plugins/steering/handlers/llm/__init__.py create mode 100644 src/strands/vended_plugins/steering/handlers/llm/llm_handler.py create mode 100644 
src/strands/vended_plugins/steering/handlers/llm/mappers.py create mode 100644 tests/strands/experimental/steering/test_steering_aliases.py delete mode 100644 tests/strands/plugins/skills/__init__.py rename tests/strands/{experimental/steering/context_providers => vended_plugins}/__init__.py (100%) rename tests/strands/{experimental/steering/core => vended_plugins/skills}/__init__.py (100%) rename tests/strands/{plugins => vended_plugins}/skills/test_agent_skills.py (98%) rename tests/strands/{plugins => vended_plugins}/skills/test_skill.py (99%) rename tests/strands/{experimental/steering/handlers => vended_plugins/steering}/__init__.py (100%) rename tests/strands/{experimental/steering/handlers/llm => vended_plugins/steering/context_providers}/__init__.py (100%) rename tests/strands/{experimental => vended_plugins}/steering/context_providers/test_ledger_provider.py (91%) create mode 100644 tests/strands/vended_plugins/steering/core/__init__.py rename tests/strands/{experimental => vended_plugins}/steering/core/test_handler.py (98%) create mode 100644 tests/strands/vended_plugins/steering/handlers/__init__.py create mode 100644 tests/strands/vended_plugins/steering/handlers/llm/__init__.py rename tests/strands/{experimental => vended_plugins}/steering/handlers/llm/test_llm_handler.py (96%) rename tests/strands/{experimental => vended_plugins}/steering/handlers/llm/test_mappers.py (95%) diff --git a/AGENTS.md b/AGENTS.md index 21c32539c..a9a2a5044 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -130,15 +130,18 @@ strands-agents/ │ ├── plugins/ # Plugin system │ │ ├── plugin.py # Plugin base class │ │ ├── decorator.py # @hook decorator -│ │ ├── registry.py # PluginRegistry for tracking plugins -│ │ └── skills/ # Agent Skills integration -│ │ ├── __init__.py # Skills package exports -│ │ ├── skill.py # Skill dataclass -│ │ └── agent_skills.py # AgentSkills plugin implementation +│ │ └── registry.py # PluginRegistry for tracking plugins │ │ │ ├── handlers/ # Event handlers 
│ │ └── callback_handler.py # Callback handling │ │ +│ ├── vended_plugins/ # Production plugin implementations +│ │ ├── steering/ # Agent steering system +│ │ │ ├── context_providers/ # Context data providers (e.g., ledger) +│ │ │ ├── core/ # Base classes, actions, context +│ │ │ └── handlers/ # Handler implementations (e.g., LLM) +│ │ └── skills/ # AgentSkills.io integration (Skill, AgentSkills) +│ │ │ ├── experimental/ # Experimental features (API may change) │ │ ├── agent_config.py # Experimental agent config │ │ ├── bidi/ # Bidirectional streaming @@ -151,11 +154,8 @@ strands-agents/ │ │ ├── hooks/ # Experimental hooks │ │ │ ├── events.py │ │ │ └── multiagent/ -│ │ ├── steering/ # Agent steering -│ │ │ ├── context_providers/ -│ │ │ ├── core/ -│ │ │ └── handlers/ -│ │ └── tools/ # Experimental tools (deprecation shims) +│ │ ├── steering/ # Deprecated aliases for vended_plugins/steering +│ │ └── tools/ # Deprecated aliases for strands.tools │ │ │ ├── __init__.py # Public API exports │ ├── interrupt.py # Interrupt handling diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 3e1528fa6..2078f16ce 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -4,9 +4,10 @@ from .agent.agent import Agent from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy -from .plugins import AgentSkills, Plugin, Skill +from .plugins import Plugin from .tools.decorator import tool from .types.tools import ToolContext +from .vended_plugins.skills import AgentSkills, Skill __all__ = [ "Agent", diff --git a/src/strands/experimental/steering/__init__.py b/src/strands/experimental/steering/__init__.py index c928d0c63..1db07c90f 100644 --- a/src/strands/experimental/steering/__init__.py +++ b/src/strands/experimental/steering/__init__.py @@ -1,36 +1,12 @@ -"""Steering system for Strands agents. +"""Deprecated: Steering has moved to strands.vended_plugins.steering. 
-Provides contextual guidance for agents through modular prompting with progressive disclosure. -Instead of front-loading all instructions, steering handlers provide just-in-time feedback -based on local context data populated by context callbacks. - -Core components: - -- SteeringHandler: Base class for guidance logic with local context -- SteeringContextCallback: Protocol for context update functions -- SteeringContextProvider: Protocol for multi-event context providers -- ToolSteeringAction/ModelSteeringAction: Proceed/Guide/Interrupt decisions - -Usage: - handler = LLMSteeringHandler(system_prompt="...") - agent = Agent(tools=[...], plugins=[handler]) +This module provides backwards-compatible aliases that emit deprecation warnings. """ -# Core primitives -# Context providers -from .context_providers.ledger_provider import ( - LedgerAfterToolCall, - LedgerBeforeToolCall, - LedgerProvider, -) -from .core.action import Guide, Interrupt, ModelSteeringAction, Proceed, ToolSteeringAction -from .core.context import SteeringContextCallback, SteeringContextProvider -from .core.handler import SteeringHandler - -# Handler implementations -from .handlers.llm import LLMPromptMapper, LLMSteeringHandler - -__all__ = [ +import warnings +from typing import Any + +_DEPRECATED_NAMES = { "ToolSteeringAction", "ModelSteeringAction", "Proceed", @@ -44,4 +20,20 @@ "LedgerProvider", "LLMSteeringHandler", "LLMPromptMapper", -] +} + + +def __getattr__(name: str) -> Any: + if name in _DEPRECATED_NAMES: + from strands.vended_plugins import steering + + warnings.warn( + f"{name} has been moved to production. 
Use {name} from strands.vended_plugins.steering instead.", + DeprecationWarning, + stacklevel=2, + ) + return getattr(steering, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/context_providers/__init__.py b/src/strands/experimental/steering/context_providers/__init__.py index 242ed9cf1..81a0fa709 100644 --- a/src/strands/experimental/steering/context_providers/__init__.py +++ b/src/strands/experimental/steering/context_providers/__init__.py @@ -1,13 +1,23 @@ -"""Context providers for steering evaluation.""" - -from .ledger_provider import ( - LedgerAfterToolCall, - LedgerBeforeToolCall, - LedgerProvider, -) - -__all__ = [ - "LedgerAfterToolCall", - "LedgerBeforeToolCall", - "LedgerProvider", -] +"""Deprecated: Use strands.vended_plugins.steering.context_providers instead.""" + +import warnings +from typing import Any + +_TARGET_MODULE = "strands.vended_plugins.steering.context_providers" + + +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering import context_providers + + obj = getattr(context_providers, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, + ) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/context_providers/ledger_provider.py b/src/strands/experimental/steering/context_providers/ledger_provider.py index 43f56717a..3cc21774e 100644 --- a/src/strands/experimental/steering/context_providers/ledger_provider.py +++ b/src/strands/experimental/steering/context_providers/ledger_provider.py @@ -1,91 +1,23 @@ -"""Ledger context provider for comprehensive agent activity tracking. 
+"""Deprecated: Use strands.vended_plugins.steering.context_providers.ledger_provider instead.""" -Tracks complete agent activity ledger including tool calls, conversation history, -and timing information. This comprehensive audit trail enables steering handlers -to make informed guidance decisions based on agent behavior patterns and history. - -Data captured: - - - Tool call history with inputs, outputs, timing, success/failure - - Conversation messages and agent responses - - Session metadata and timing information - - Error patterns and recovery attempts - -Usage: - Use as context provider functions or mix into steering handlers. -""" - -import logging -from datetime import datetime +import warnings from typing import Any -from ....hooks.events import AfterToolCallEvent, BeforeToolCallEvent -from ..core.context import SteeringContext, SteeringContextCallback, SteeringContextProvider - -logger = logging.getLogger(__name__) - - -class LedgerBeforeToolCall(SteeringContextCallback[BeforeToolCallEvent]): - """Context provider for ledger tracking before tool calls.""" - - def __init__(self) -> None: - """Initialize the ledger provider.""" - self.session_start = datetime.now().isoformat() - - def __call__(self, event: BeforeToolCallEvent, steering_context: SteeringContext, **kwargs: Any) -> None: - """Update ledger before tool call.""" - ledger = steering_context.data.get("ledger") or {} - - if not ledger: - ledger = { - "session_start": self.session_start, - "tool_calls": [], - "conversation_history": [], - "session_metadata": {}, - } - - tool_call_entry = { - "timestamp": datetime.now().isoformat(), - "tool_use_id": event.tool_use.get("toolUseId"), - "tool_name": event.tool_use.get("name"), - "tool_args": event.tool_use.get("input", {}), - "status": "pending", - } - ledger["tool_calls"].append(tool_call_entry) - steering_context.data.set("ledger", ledger) - - -class LedgerAfterToolCall(SteeringContextCallback[AfterToolCallEvent]): - """Context provider for ledger 
tracking after tool calls.""" - - def __call__(self, event: AfterToolCallEvent, steering_context: SteeringContext, **kwargs: Any) -> None: - """Update ledger after tool call.""" - ledger = steering_context.data.get("ledger") or {} +_TARGET_MODULE = "strands.vended_plugins.steering.context_providers.ledger_provider" - if ledger.get("tool_calls"): - tool_use_id = event.tool_use.get("toolUseId") - # Search for the matching tool call in the ledger to update it - for call in reversed(ledger["tool_calls"]): - if call.get("tool_use_id") == tool_use_id and call.get("status") == "pending": - call.update( - { - "completion_timestamp": datetime.now().isoformat(), - "status": event.result["status"], - "result": event.result["content"], - "error": str(event.exception) if event.exception else None, - } - ) - steering_context.data.set("ledger", ledger) - break +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering.context_providers import ledger_provider + obj = getattr(ledger_provider, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. 
Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, + ) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -class LedgerProvider(SteeringContextProvider): - """Combined ledger context provider for both before and after tool calls.""" - def context_providers(self, **kwargs: Any) -> list[SteeringContextCallback]: - """Return ledger context providers with shared state.""" - return [ - LedgerBeforeToolCall(), - LedgerAfterToolCall(), - ] +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/core/__init__.py b/src/strands/experimental/steering/core/__init__.py index cdd0d8269..e7c79f66d 100644 --- a/src/strands/experimental/steering/core/__init__.py +++ b/src/strands/experimental/steering/core/__init__.py @@ -1,6 +1,23 @@ -"""Core steering system interfaces and base classes.""" +"""Deprecated: Use strands.vended_plugins.steering.core instead.""" -from .action import Guide, Interrupt, ModelSteeringAction, Proceed, ToolSteeringAction -from .handler import SteeringHandler +import warnings +from typing import Any -__all__ = ["ToolSteeringAction", "ModelSteeringAction", "Proceed", "Guide", "Interrupt", "SteeringHandler"] +_TARGET_MODULE = "strands.vended_plugins.steering.core" + + +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering import core + + obj = getattr(core, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. 
Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, + ) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/core/action.py b/src/strands/experimental/steering/core/action.py index b1f124b40..9e60aa704 100644 --- a/src/strands/experimental/steering/core/action.py +++ b/src/strands/experimental/steering/core/action.py @@ -1,76 +1,23 @@ -"""SteeringAction types for steering evaluation results. +"""Deprecated: Use strands.vended_plugins.steering.core.action instead.""" -Defines structured outcomes from steering handlers that determine how agent actions -should be handled. SteeringActions enable modular prompting by providing just-in-time -feedback rather than front-loading all instructions in monolithic prompts. +import warnings +from typing import Any -Flow: - SteeringHandler.steer_*() → SteeringAction → Event handling - ↓ ↓ ↓ - Evaluate context Action type Execution modified +_TARGET_MODULE = "strands.vended_plugins.steering.core.action" -SteeringAction types: - Proceed: Allow execution to continue without intervention - Guide: Provide contextual guidance to redirect the agent - Interrupt: Pause execution for human input -Extensibility: - New action types can be added to the union. Always handle the default - case in pattern matching to maintain backward compatibility. -""" +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering.core import action -from typing import Annotated, Literal + obj = getattr(action, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, + ) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -from pydantic import BaseModel, Field - -class Proceed(BaseModel): - """Allow execution to continue without intervention. 
- - The action proceeds as planned. The reason provides context - for logging and debugging purposes. - """ - - type: Literal["proceed"] = "proceed" - reason: str - - -class Guide(BaseModel): - """Provide contextual guidance to redirect the agent. - - The agent receives the reason as contextual feedback to help guide - its behavior. The specific handling depends on the steering context - (e.g., tool call vs. model response). - """ - - type: Literal["guide"] = "guide" - reason: str - - -class Interrupt(BaseModel): - """Pause execution for human input via interrupt system. - - Execution is paused and human input is requested through Strands' - interrupt system. The human can approve or deny the operation, and their - decision determines whether execution continues or is cancelled. - """ - - type: Literal["interrupt"] = "interrupt" - reason: str - - -# Context-specific steering action types -ToolSteeringAction = Annotated[Proceed | Guide | Interrupt, Field(discriminator="type")] -"""Steering actions valid for tool steering (steer_before_tool). - -- Proceed: Allow tool execution to continue -- Guide: Cancel tool and provide feedback for alternative approaches -- Interrupt: Pause for human input before tool execution -""" - -ModelSteeringAction = Annotated[Proceed | Guide, Field(discriminator="type")] -"""Steering actions valid for model steering (steer_after_model). - -- Proceed: Accept model response without modification -- Guide: Discard model response and retry with guidance -""" +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/core/context.py b/src/strands/experimental/steering/core/context.py index 446c4c9f9..15014118f 100644 --- a/src/strands/experimental/steering/core/context.py +++ b/src/strands/experimental/steering/core/context.py @@ -1,77 +1,23 @@ -"""Steering context protocols for contextual guidance. 
+"""Deprecated: Use strands.vended_plugins.steering.core.context instead.""" -Defines protocols for context callbacks and providers that populate -steering context data used by handlers to make guidance decisions. +import warnings +from typing import Any -Architecture: - SteeringContextCallback → Handler.steering_context → SteeringHandler.steer() - ↓ ↓ ↓ - Update local context Store in handler Access via self.steering_context +_TARGET_MODULE = "strands.vended_plugins.steering.core.context" -Context lifecycle: - 1. Handler registers context callbacks for hook events - 2. Callbacks update handler's local steering_context on events - 3. Handler accesses self.steering_context in steer() method - 4. Context persists across calls within handler instance -Implementation: - Each handler maintains its own JSONSerializableDict context. - Callbacks are registered per handler instance for isolation. - Providers can supply multiple callbacks for different events. -""" +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering.core import context -import logging -from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from typing import Any, Generic, TypeVar, cast, get_args, get_origin + obj = getattr(context, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, + ) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -from ....hooks.registry import HookEvent -from ....types.json_dict import JSONSerializableDict -logger = logging.getLogger(__name__) - - -@dataclass -class SteeringContext: - """Container for steering context data.""" - - """Container for steering context data. - - This class should not be instantiated directly - it is intended for internal use only. 
- """ - - data: JSONSerializableDict = field(default_factory=JSONSerializableDict) - - -EventType = TypeVar("EventType", bound=HookEvent, contravariant=True) - - -class SteeringContextCallback(ABC, Generic[EventType]): - """Abstract base class for steering context update callbacks.""" - - @property - def event_type(self) -> type[HookEvent]: - """Return the event type this callback handles.""" - for base in getattr(self.__class__, "__orig_bases__", ()): - if get_origin(base) is SteeringContextCallback: - return cast(type[HookEvent], get_args(base)[0]) - raise ValueError("Could not determine event type from generic parameter") - - def __call__(self, event: EventType, steering_context: "SteeringContext", **kwargs: Any) -> None: - """Update steering context based on hook event. - - Args: - event: The hook event that triggered the callback - steering_context: The steering context to update - **kwargs: Additional keyword arguments for context updates - """ - ... - - -class SteeringContextProvider(ABC): - """Abstract base class for context providers that handle multiple event types.""" - - @abstractmethod - def context_providers(self, **kwargs: Any) -> list[SteeringContextCallback]: - """Return list of context callbacks with event types extracted from generics.""" - ... +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/core/handler.py b/src/strands/experimental/steering/core/handler.py index 214118d4f..5892fb026 100644 --- a/src/strands/experimental/steering/core/handler.py +++ b/src/strands/experimental/steering/core/handler.py @@ -1,218 +1,23 @@ -"""Steering handler base class for providing contextual guidance to agents. +"""Deprecated: Use strands.vended_plugins.steering.core.handler instead.""" -Provides modular prompting through contextual guidance that appears when relevant, -rather than front-loading all instructions. Handlers integrate with the Strands hook -system to intercept actions and provide just-in-time feedback based on local context. 
+import warnings +from typing import Any -Architecture: - Hook Event → Context Callbacks → Update steering_context → steer_*() → SteeringAction - ↓ ↓ ↓ ↓ ↓ - Hook triggered Populate context Handler evaluates Handler decides Action taken +_TARGET_MODULE = "strands.vended_plugins.steering.core.handler" -Lifecycle: - 1. Context callbacks update handler's steering_context on hook events - 2. BeforeToolCallEvent triggers steer_before_tool() for tool steering - 3. AfterModelCallEvent triggers steer_after_model() for model steering - 4. Handler accesses self.steering_context for guidance decisions - 5. SteeringAction determines execution flow -Implementation: - Subclass SteeringHandler and override steer_before_tool() and/or steer_after_model(). - Both methods have default implementations that return Proceed, so you only need to - override the methods you want to customize. - Pass context_providers in constructor to register context update functions. - Each handler maintains isolated steering_context that persists across calls. +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering.core import handler -SteeringAction handling for steer_before_tool: - Proceed: Tool executes immediately - Guide: Tool cancelled, agent receives contextual feedback to explore alternatives - Interrupt: Tool execution paused for human input via interrupt system + obj = getattr(handler, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. 
Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, + ) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -SteeringAction handling for steer_after_model: - Proceed: Model response accepted without modification - Guide: Discard model response and retry (message is dropped, model is called again) - Interrupt: Model response handling paused for human input via interrupt system -""" -import logging -from typing import TYPE_CHECKING, Any - -from ....hooks.events import AfterModelCallEvent, BeforeToolCallEvent -from ....plugins import Plugin, hook -from ....types.content import Message -from ....types.streaming import StopReason -from ....types.tools import ToolUse -from .action import Guide, Interrupt, ModelSteeringAction, Proceed, ToolSteeringAction -from .context import SteeringContext, SteeringContextProvider - -if TYPE_CHECKING: - from ....agent import Agent - -logger = logging.getLogger(__name__) - - -class SteeringHandler(Plugin): - """Base class for steering handlers that provide contextual guidance to agents. - - Steering handlers maintain local context and register hook callbacks - to populate context data as needed for guidance decisions. - """ - - name: str = "steering" - - def __init__(self, context_providers: list[SteeringContextProvider] | None = None): - """Initialize the steering handler. - - Args: - context_providers: List of context providers for context updates - """ - super().__init__() - self.steering_context = SteeringContext() - self._context_callbacks = [] - - # Collect callbacks from all providers - for provider in context_providers or []: - self._context_callbacks.extend(provider.context_providers()) - - logger.debug("handler_class=<%s> | initialized", self.__class__.__name__) - - def init_agent(self, agent: "Agent") -> None: - """Initialize the steering handler with an agent. - - Registers hook callbacks for steering guidance and context updates. 
- - Args: - agent: The agent instance to attach steering to. - """ - # Register context update callbacks - for callback in self._context_callbacks: - agent.add_hook(lambda event, callback=callback: callback(event, self.steering_context), callback.event_type) - - @hook - async def provide_tool_steering_guidance(self, event: BeforeToolCallEvent) -> None: - """Provide steering guidance for tool call.""" - tool_name = event.tool_use["name"] - logger.debug("tool_name=<%s> | providing tool steering guidance", tool_name) - - try: - action = await self.steer_before_tool(agent=event.agent, tool_use=event.tool_use) - except Exception as e: - logger.debug("tool_name=<%s>, error=<%s> | tool steering handler guidance failed", tool_name, e) - return - - self._handle_tool_steering_action(action, event, tool_name) - - def _handle_tool_steering_action( - self, action: ToolSteeringAction, event: BeforeToolCallEvent, tool_name: str - ) -> None: - """Handle the steering action for tool calls by modifying tool execution flow. - - Proceed: Tool executes normally - Guide: Tool cancelled with contextual feedback for agent to consider alternatives - Interrupt: Tool execution paused for human input via interrupt system - """ - if isinstance(action, Proceed): - logger.debug("tool_name=<%s> | tool call proceeding", tool_name) - elif isinstance(action, Guide): - logger.debug("tool_name=<%s> | tool call guided: %s", tool_name, action.reason) - event.cancel_tool = f"Tool call cancelled. {action.reason} You MUST follow this guidance immediately." 
- elif isinstance(action, Interrupt): - logger.debug("tool_name=<%s> | tool call requires human input: %s", tool_name, action.reason) - can_proceed: bool = event.interrupt(name=f"steering_input_{tool_name}", reason={"message": action.reason}) - logger.debug("tool_name=<%s> | received human input for tool call", tool_name) - - if not can_proceed: - event.cancel_tool = f"Manual approval denied: {action.reason}" - logger.debug("tool_name=<%s> | tool call denied by manual approval", tool_name) - else: - logger.debug("tool_name=<%s> | tool call approved manually", tool_name) - else: - raise ValueError(f"Unknown steering action type for tool call: {action}") - - @hook - async def provide_model_steering_guidance(self, event: AfterModelCallEvent) -> None: - """Provide steering guidance for model response.""" - logger.debug("providing model steering guidance") - - # Only steer on successful model responses - if event.stop_response is None: - logger.debug("no stop response available | skipping model steering") - return - - try: - action = await self.steer_after_model( - agent=event.agent, message=event.stop_response.message, stop_reason=event.stop_response.stop_reason - ) - except Exception as e: - logger.debug("error=<%s> | model steering handler guidance failed", e) - return - - await self._handle_model_steering_action(action, event) - - async def _handle_model_steering_action(self, action: ModelSteeringAction, event: AfterModelCallEvent) -> None: - """Handle the steering action for model responses by modifying response handling flow. 
- - Proceed: Model response accepted without modification - Guide: Discard model response and retry with guidance message added to conversation - """ - if isinstance(action, Proceed): - logger.debug("model response proceeding") - elif isinstance(action, Guide): - logger.debug("model response guided (retrying): %s", action.reason) - # Set retry flag to discard current response - event.retry = True - # Add guidance message to agent's conversation so model sees it on retry - await event.agent._append_messages({"role": "user", "content": [{"text": action.reason}]}) - logger.debug("added guidance message to conversation for model retry") - else: - raise ValueError(f"Unknown steering action type for model response: {action}") - - async def steer_before_tool(self, *, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> ToolSteeringAction: - """Provide contextual guidance before tool execution. - - This method is called before a tool is executed, allowing the handler to: - - Proceed: Allow tool execution to continue - - Guide: Cancel tool and provide feedback for alternative approaches - - Interrupt: Pause for human input before tool execution - - Args: - agent: The agent instance - tool_use: The tool use object with name and arguments - **kwargs: Additional keyword arguments for guidance evaluation - - Returns: - ToolSteeringAction indicating how to guide the tool execution - - Note: - Access steering context via self.steering_context - Default implementation returns Proceed (allow tool execution) - Override this method to implement custom tool steering logic - """ - return Proceed(reason="Default implementation: allowing tool execution") - - async def steer_after_model( - self, *, agent: "Agent", message: Message, stop_reason: StopReason, **kwargs: Any - ) -> ModelSteeringAction: - """Provide contextual guidance after model response. 
- - This method is called after the model generates a response, allowing the handler to: - - Proceed: Accept the model response without modification - - Guide: Discard the response and retry (message is dropped, model is called again) - - Note: Interrupt is not supported for model steering as the model has already responded. - - Args: - agent: The agent instance - message: The model's generated message - stop_reason: The reason the model stopped generating - **kwargs: Additional keyword arguments for guidance evaluation - - Returns: - ModelSteeringAction indicating how to handle the model response - - Note: - Access steering context via self.steering_context - Default implementation returns Proceed (accept response as-is) - Override this method to implement custom model steering logic - """ - return Proceed(reason="Default implementation: accepting model response") +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/handlers/__init__.py b/src/strands/experimental/steering/handlers/__init__.py index fe364a5a2..128fc946c 100644 --- a/src/strands/experimental/steering/handlers/__init__.py +++ b/src/strands/experimental/steering/handlers/__init__.py @@ -1,5 +1,23 @@ -"""Steering handler implementations.""" +"""Deprecated: Use strands.vended_plugins.steering.handlers instead.""" -from collections.abc import Sequence +import warnings +from typing import Any -__all__: Sequence[str] = [] +_TARGET_MODULE = "strands.vended_plugins.steering.handlers" + + +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering import handlers + + obj = getattr(handlers, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. 
Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, + ) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/handlers/llm/__init__.py b/src/strands/experimental/steering/handlers/llm/__init__.py index 4dcccbe80..aef580729 100644 --- a/src/strands/experimental/steering/handlers/llm/__init__.py +++ b/src/strands/experimental/steering/handlers/llm/__init__.py @@ -1,6 +1,23 @@ -"""LLM steering handler with prompt mapping.""" +"""Deprecated: Use strands.vended_plugins.steering.handlers.llm instead.""" -from .llm_handler import LLMSteeringHandler -from .mappers import DefaultPromptMapper, LLMPromptMapper, ToolUse +import warnings +from typing import Any -__all__ = ["LLMSteeringHandler", "LLMPromptMapper", "DefaultPromptMapper", "ToolUse"] +_TARGET_MODULE = "strands.vended_plugins.steering.handlers.llm" + + +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering.handlers import llm + + obj = getattr(llm, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. 
Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, + ) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/handlers/llm/llm_handler.py b/src/strands/experimental/steering/handlers/llm/llm_handler.py index 6d0a31eeb..8c1b6d200 100644 --- a/src/strands/experimental/steering/handlers/llm/llm_handler.py +++ b/src/strands/experimental/steering/handlers/llm/llm_handler.py @@ -1,99 +1,23 @@ -"""LLM-based steering handler that uses an LLM to provide contextual guidance.""" +"""Deprecated: Use strands.vended_plugins.steering.handlers.llm.llm_handler instead.""" -from __future__ import annotations +import warnings +from typing import Any -import logging -from typing import TYPE_CHECKING, Any, Literal, cast +_TARGET_MODULE = "strands.vended_plugins.steering.handlers.llm.llm_handler" -from pydantic import BaseModel, Field -from .....models import Model -from .....types.tools import ToolUse -from ...context_providers.ledger_provider import LedgerProvider -from ...core.action import Guide, Interrupt, Proceed, ToolSteeringAction -from ...core.context import SteeringContextProvider -from ...core.handler import SteeringHandler -from .mappers import DefaultPromptMapper, LLMPromptMapper +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering.handlers.llm import llm_handler -if TYPE_CHECKING: - from .....agent import Agent - -logger = logging.getLogger(__name__) - - -class _LLMSteering(BaseModel): - """Structured output model for LLM steering decisions.""" - - decision: Literal["proceed", "guide", "interrupt"] = Field( - description="Steering decision: 'proceed' to continue, 'guide' to provide feedback, 'interrupt' for human input" - ) - reason: str = Field(description="Clear explanation of the steering decision and any guidance provided") - - -class LLMSteeringHandler(SteeringHandler): - """Steering handler that uses an LLM to 
provide contextual guidance. - - Uses natural language prompts to evaluate tool calls and provide - contextual steering guidance to help agents navigate complex workflows. - """ - - def __init__( - self, - system_prompt: str, - prompt_mapper: LLMPromptMapper | None = None, - model: Model | None = None, - context_providers: list[SteeringContextProvider] | None = None, - ): - """Initialize the LLMSteeringHandler. - - Args: - system_prompt: System prompt defining steering guidance rules - prompt_mapper: Custom prompt mapper for evaluation prompts - model: Optional model override for steering evaluation - context_providers: List of context providers for populating steering context. - Defaults to [LedgerProvider()] if None. Pass an empty list to disable - context providers. - """ - providers: list[SteeringContextProvider] = ( - [LedgerProvider()] if context_providers is None else context_providers + obj = getattr(llm_handler, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, ) - super().__init__(context_providers=providers) - self.system_prompt = system_prompt - self.prompt_mapper = prompt_mapper or DefaultPromptMapper() - self.model = model - - async def steer_before_tool(self, *, agent: Agent, tool_use: ToolUse, **kwargs: Any) -> ToolSteeringAction: - """Provide contextual guidance for tool usage. 
- - Args: - agent: The agent instance - tool_use: The tool use object with name and arguments - **kwargs: Additional keyword arguments for steering evaluation - - Returns: - SteeringAction indicating how to guide the tool execution - """ - # Generate steering prompt - prompt = self.prompt_mapper.create_steering_prompt(self.steering_context, tool_use=tool_use) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - # Create isolated agent for steering evaluation (no shared conversation state) - from .....agent import Agent - - steering_agent = Agent(system_prompt=self.system_prompt, model=self.model or agent.model, callback_handler=None) - - # Get LLM decision - llm_result: _LLMSteering = cast( - _LLMSteering, steering_agent(prompt, structured_output_model=_LLMSteering).structured_output - ) - # Convert LLM decision to steering action - match llm_result.decision: - case "proceed": - return Proceed(reason=llm_result.reason) - case "guide": - return Guide(reason=llm_result.reason) - case "interrupt": - return Interrupt(reason=llm_result.reason) - case _: - logger.warning("decision=<%s> | unknown llm decision, defaulting to proceed", llm_result.decision) # type: ignore[unreachable] - return Proceed(reason="Unknown LLM decision, defaulting to proceed") +__all__: list[str] = [] diff --git a/src/strands/experimental/steering/handlers/llm/mappers.py b/src/strands/experimental/steering/handlers/llm/mappers.py index ade018d32..56ea3125f 100644 --- a/src/strands/experimental/steering/handlers/llm/mappers.py +++ b/src/strands/experimental/steering/handlers/llm/mappers.py @@ -1,130 +1,23 @@ -"""LLM steering prompt mappers for generating evaluation prompts.""" +"""Deprecated: Use strands.vended_plugins.steering.handlers.llm.mappers instead.""" -import json -from typing import Any, Protocol +import warnings +from typing import Any -from .....types.tools import ToolUse -from ...core.context import SteeringContext +_TARGET_MODULE = 
"strands.vended_plugins.steering.handlers.llm.mappers" -# Agent SOP format - see https://github.com/strands-agents/agent-sop -_STEERING_PROMPT_TEMPLATE = """# Steering Evaluation -## Overview +def __getattr__(name: str) -> Any: + from strands.vended_plugins.steering.handlers.llm import mappers -You are a STEERING AGENT that evaluates a {action_type} that ANOTHER AGENT is attempting to make. -Your job is to provide contextual guidance to help the other agent navigate workflows effectively. -You act as a safety net that can intervene when patterns in the context data suggest the agent -should try a different approach or get human input. - -**YOUR ROLE:** -- Analyze context data for concerning patterns (repeated failures, inappropriate timing, etc.) -- Provide just-in-time guidance when the agent is going down an ineffective path -- Allow normal operations to proceed when context shows no issues - -**CRITICAL CONSTRAINTS:** -- Base decisions ONLY on the context data provided below -- Do NOT use external knowledge about domains, URLs, or tool purposes -- Do NOT make assumptions about what tools "should" or "shouldn't" do -- Focus ONLY on patterns in the context data - -## Context - -{context_str} - -### Understanding Ledger Tool States - -If the context includes a ledger with tool_calls, the "status" field indicates: - -- **"pending"**: The tool is CURRENTLY being evaluated by you (the steering agent). -This is NOT a duplicate call - it's the tool you're deciding whether to approve. -The tool has NOT started executing yet. -- **"success"**: The tool completed successfully in a previous turn -- **"error"**: The tool failed or was cancelled in a previous turn - -**IMPORTANT**: When you see a tool with status="pending" that matches the tool you're evaluating, -that IS the current tool being evaluated. -It is NOT already executing or a duplicate. - -## Event to Evaluate - -{event_description} - -## Steps - -### 1. 
Analyze the {action_type_title} - -Review ONLY the context data above. Look for patterns in the data that indicate: - -- Previous failures or successes with this tool -- Frequency of attempts -- Any relevant tracking information - -**Constraints:** -- You MUST base analysis ONLY on the provided context data -- You MUST NOT use external knowledge about tool purposes or domains -- You SHOULD identify patterns in the context data -- You MAY reference relevant context data to inform your decision - -### 2. Make Steering Decision - -**Constraints:** -- You MUST respond with exactly one of: "proceed", "guide", or "interrupt" -- You MUST base the decision ONLY on context data patterns -- Your reason will be shown to the AGENT as guidance - -**Decision Options:** -- "proceed" if context data shows no concerning patterns -- "guide" if context data shows patterns requiring intervention -- "interrupt" if context data shows patterns requiring human input -""" - - -class LLMPromptMapper(Protocol): - """Protocol for mapping context and events to LLM evaluation prompts.""" - - def create_steering_prompt( - self, steering_context: SteeringContext, tool_use: ToolUse | None = None, **kwargs: Any - ) -> str: - """Create steering prompt for LLM evaluation. - - Args: - steering_context: Steering context with populated data - tool_use: Tool use object for tool call events (None for other events) - **kwargs: Additional event data for other steering events - - Returns: - Formatted prompt string for LLM evaluation - """ - ... - - -class DefaultPromptMapper(LLMPromptMapper): - """Default prompt mapper for steering evaluation.""" - - def create_steering_prompt( - self, steering_context: SteeringContext, tool_use: ToolUse | None = None, **kwargs: Any - ) -> str: - """Create default steering prompt using Agent SOP structure. - - Uses Agent SOP format for structured, constraint-based prompts. 
- See: https://github.com/strands-agents/agent-sop - """ - context_str = ( - json.dumps(steering_context.data.get(), indent=2) if steering_context.data.get() else "No context available" + obj = getattr(mappers, name, None) + if obj is not None: + warnings.warn( + f"{name} has been moved to production. Use {name} from {_TARGET_MODULE} instead.", + DeprecationWarning, + stacklevel=2, ) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - if tool_use: - event_description = ( - f"Tool: {tool_use['name']}\nArguments: {json.dumps(tool_use.get('input', {}), indent=2)}" - ) - action_type = "tool call" - else: - event_description = "General evaluation" - action_type = "action" - return _STEERING_PROMPT_TEMPLATE.format( - action_type=action_type, - action_type_title=action_type.title(), - context_str=context_str, - event_description=event_description, - ) +__all__: list[str] = [] diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index 0ace9645f..71d3f7ef7 100644 --- a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -502,9 +502,7 @@ def _format_request_messages(cls, messages: Messages) -> list[dict[str, Any]]: ] @classmethod - def _format_request_message_content( - cls, content: ContentBlock, *, role: Role = "user" - ) -> dict[str, Any]: + def _format_request_message_content(cls, content: ContentBlock, *, role: Role = "user") -> dict[str, Any]: """Format an OpenAI compatible content block. 
Args: diff --git a/src/strands/plugins/__init__.py b/src/strands/plugins/__init__.py index d7ca4c9b2..c4b7c72c7 100644 --- a/src/strands/plugins/__init__.py +++ b/src/strands/plugins/__init__.py @@ -6,11 +6,8 @@ from .decorator import hook from .plugin import Plugin -from .skills import AgentSkills, Skill __all__ = [ - "AgentSkills", "Plugin", - "Skill", "hook", ] diff --git a/src/strands/vended_plugins/__init__.py b/src/strands/vended_plugins/__init__.py new file mode 100644 index 000000000..78e8047df --- /dev/null +++ b/src/strands/vended_plugins/__init__.py @@ -0,0 +1 @@ +"""Vended plugins for Strands agents.""" diff --git a/src/strands/plugins/skills/__init__.py b/src/strands/vended_plugins/skills/__init__.py similarity index 92% rename from src/strands/plugins/skills/__init__.py rename to src/strands/vended_plugins/skills/__init__.py index f6cf8728b..abd6063b9 100644 --- a/src/strands/plugins/skills/__init__.py +++ b/src/strands/vended_plugins/skills/__init__.py @@ -8,7 +8,7 @@ Example Usage: ```python from strands import Agent - from strands.plugins.skills import Skill, AgentSkills + from strands.vended_plugins.skills import Skill, AgentSkills # Load from filesystem via classmethods skill = Skill.from_file("./skills/pdf-processing") diff --git a/src/strands/plugins/skills/agent_skills.py b/src/strands/vended_plugins/skills/agent_skills.py similarity index 99% rename from src/strands/plugins/skills/agent_skills.py rename to src/strands/vended_plugins/skills/agent_skills.py index 97ac86d93..5e42b9230 100644 --- a/src/strands/plugins/skills/agent_skills.py +++ b/src/strands/vended_plugins/skills/agent_skills.py @@ -56,7 +56,7 @@ class AgentSkills(Plugin): Example: ```python from strands import Agent - from strands.plugins.skills import Skill, AgentSkills + from strands.vended_plugins.skills import Skill, AgentSkills # Load from filesystem plugin = AgentSkills(skills=["./skills/pdf-processing", "./skills/"]) @@ -188,7 +188,6 @@ def set_available_skills(self, 
skills: SkillSources) -> None: """ self._skills = self._resolve_skills(_normalize_sources(skills)) - def _format_skill_response(self, skill: Skill) -> str: """Format the tool response when a skill is activated. diff --git a/src/strands/plugins/skills/skill.py b/src/strands/vended_plugins/skills/skill.py similarity index 100% rename from src/strands/plugins/skills/skill.py rename to src/strands/vended_plugins/skills/skill.py diff --git a/src/strands/vended_plugins/steering/__init__.py b/src/strands/vended_plugins/steering/__init__.py new file mode 100644 index 000000000..c928d0c63 --- /dev/null +++ b/src/strands/vended_plugins/steering/__init__.py @@ -0,0 +1,47 @@ +"""Steering system for Strands agents. + +Provides contextual guidance for agents through modular prompting with progressive disclosure. +Instead of front-loading all instructions, steering handlers provide just-in-time feedback +based on local context data populated by context callbacks. + +Core components: + +- SteeringHandler: Base class for guidance logic with local context +- SteeringContextCallback: Protocol for context update functions +- SteeringContextProvider: Protocol for multi-event context providers +- ToolSteeringAction/ModelSteeringAction: Proceed/Guide/Interrupt decisions + +Usage: + handler = LLMSteeringHandler(system_prompt="...") + agent = Agent(tools=[...], plugins=[handler]) +""" + +# Core primitives +# Context providers +from .context_providers.ledger_provider import ( + LedgerAfterToolCall, + LedgerBeforeToolCall, + LedgerProvider, +) +from .core.action import Guide, Interrupt, ModelSteeringAction, Proceed, ToolSteeringAction +from .core.context import SteeringContextCallback, SteeringContextProvider +from .core.handler import SteeringHandler + +# Handler implementations +from .handlers.llm import LLMPromptMapper, LLMSteeringHandler + +__all__ = [ + "ToolSteeringAction", + "ModelSteeringAction", + "Proceed", + "Guide", + "Interrupt", + "SteeringHandler", + "SteeringContextCallback", 
+ "SteeringContextProvider", + "LedgerBeforeToolCall", + "LedgerAfterToolCall", + "LedgerProvider", + "LLMSteeringHandler", + "LLMPromptMapper", +] diff --git a/src/strands/vended_plugins/steering/context_providers/__init__.py b/src/strands/vended_plugins/steering/context_providers/__init__.py new file mode 100644 index 000000000..242ed9cf1 --- /dev/null +++ b/src/strands/vended_plugins/steering/context_providers/__init__.py @@ -0,0 +1,13 @@ +"""Context providers for steering evaluation.""" + +from .ledger_provider import ( + LedgerAfterToolCall, + LedgerBeforeToolCall, + LedgerProvider, +) + +__all__ = [ + "LedgerAfterToolCall", + "LedgerBeforeToolCall", + "LedgerProvider", +] diff --git a/src/strands/vended_plugins/steering/context_providers/ledger_provider.py b/src/strands/vended_plugins/steering/context_providers/ledger_provider.py new file mode 100644 index 000000000..43f56717a --- /dev/null +++ b/src/strands/vended_plugins/steering/context_providers/ledger_provider.py @@ -0,0 +1,91 @@ +"""Ledger context provider for comprehensive agent activity tracking. + +Tracks complete agent activity ledger including tool calls, conversation history, +and timing information. This comprehensive audit trail enables steering handlers +to make informed guidance decisions based on agent behavior patterns and history. + +Data captured: + + - Tool call history with inputs, outputs, timing, success/failure + - Conversation messages and agent responses + - Session metadata and timing information + - Error patterns and recovery attempts + +Usage: + Use as context provider functions or mix into steering handlers. 
+""" + +import logging +from datetime import datetime +from typing import Any + +from ....hooks.events import AfterToolCallEvent, BeforeToolCallEvent +from ..core.context import SteeringContext, SteeringContextCallback, SteeringContextProvider + +logger = logging.getLogger(__name__) + + +class LedgerBeforeToolCall(SteeringContextCallback[BeforeToolCallEvent]): + """Context provider for ledger tracking before tool calls.""" + + def __init__(self) -> None: + """Initialize the ledger provider.""" + self.session_start = datetime.now().isoformat() + + def __call__(self, event: BeforeToolCallEvent, steering_context: SteeringContext, **kwargs: Any) -> None: + """Update ledger before tool call.""" + ledger = steering_context.data.get("ledger") or {} + + if not ledger: + ledger = { + "session_start": self.session_start, + "tool_calls": [], + "conversation_history": [], + "session_metadata": {}, + } + + tool_call_entry = { + "timestamp": datetime.now().isoformat(), + "tool_use_id": event.tool_use.get("toolUseId"), + "tool_name": event.tool_use.get("name"), + "tool_args": event.tool_use.get("input", {}), + "status": "pending", + } + ledger["tool_calls"].append(tool_call_entry) + steering_context.data.set("ledger", ledger) + + +class LedgerAfterToolCall(SteeringContextCallback[AfterToolCallEvent]): + """Context provider for ledger tracking after tool calls.""" + + def __call__(self, event: AfterToolCallEvent, steering_context: SteeringContext, **kwargs: Any) -> None: + """Update ledger after tool call.""" + ledger = steering_context.data.get("ledger") or {} + + if ledger.get("tool_calls"): + tool_use_id = event.tool_use.get("toolUseId") + + # Search for the matching tool call in the ledger to update it + for call in reversed(ledger["tool_calls"]): + if call.get("tool_use_id") == tool_use_id and call.get("status") == "pending": + call.update( + { + "completion_timestamp": datetime.now().isoformat(), + "status": event.result["status"], + "result": event.result["content"], + 
"error": str(event.exception) if event.exception else None, + } + ) + steering_context.data.set("ledger", ledger) + break + + +class LedgerProvider(SteeringContextProvider): + """Combined ledger context provider for both before and after tool calls.""" + + def context_providers(self, **kwargs: Any) -> list[SteeringContextCallback]: + """Return ledger context providers with shared state.""" + return [ + LedgerBeforeToolCall(), + LedgerAfterToolCall(), + ] diff --git a/src/strands/vended_plugins/steering/core/__init__.py b/src/strands/vended_plugins/steering/core/__init__.py new file mode 100644 index 000000000..bb229b175 --- /dev/null +++ b/src/strands/vended_plugins/steering/core/__init__.py @@ -0,0 +1,17 @@ +"""Core steering system interfaces and base classes.""" + +from .action import Guide, Interrupt, ModelSteeringAction, Proceed, ToolSteeringAction +from .context import SteeringContext, SteeringContextCallback, SteeringContextProvider +from .handler import SteeringHandler + +__all__ = [ + "ToolSteeringAction", + "ModelSteeringAction", + "Proceed", + "Guide", + "Interrupt", + "SteeringHandler", + "SteeringContext", + "SteeringContextCallback", + "SteeringContextProvider", +] diff --git a/src/strands/vended_plugins/steering/core/action.py b/src/strands/vended_plugins/steering/core/action.py new file mode 100644 index 000000000..b1f124b40 --- /dev/null +++ b/src/strands/vended_plugins/steering/core/action.py @@ -0,0 +1,76 @@ +"""SteeringAction types for steering evaluation results. + +Defines structured outcomes from steering handlers that determine how agent actions +should be handled. SteeringActions enable modular prompting by providing just-in-time +feedback rather than front-loading all instructions in monolithic prompts. 
+ +Flow: + SteeringHandler.steer_*() → SteeringAction → Event handling + ↓ ↓ ↓ + Evaluate context Action type Execution modified + +SteeringAction types: + Proceed: Allow execution to continue without intervention + Guide: Provide contextual guidance to redirect the agent + Interrupt: Pause execution for human input + +Extensibility: + New action types can be added to the union. Always handle the default + case in pattern matching to maintain backward compatibility. +""" + +from typing import Annotated, Literal + +from pydantic import BaseModel, Field + + +class Proceed(BaseModel): + """Allow execution to continue without intervention. + + The action proceeds as planned. The reason provides context + for logging and debugging purposes. + """ + + type: Literal["proceed"] = "proceed" + reason: str + + +class Guide(BaseModel): + """Provide contextual guidance to redirect the agent. + + The agent receives the reason as contextual feedback to help guide + its behavior. The specific handling depends on the steering context + (e.g., tool call vs. model response). + """ + + type: Literal["guide"] = "guide" + reason: str + + +class Interrupt(BaseModel): + """Pause execution for human input via interrupt system. + + Execution is paused and human input is requested through Strands' + interrupt system. The human can approve or deny the operation, and their + decision determines whether execution continues or is cancelled. + """ + + type: Literal["interrupt"] = "interrupt" + reason: str + + +# Context-specific steering action types +ToolSteeringAction = Annotated[Proceed | Guide | Interrupt, Field(discriminator="type")] +"""Steering actions valid for tool steering (steer_before_tool). 
+ +- Proceed: Allow tool execution to continue +- Guide: Cancel tool and provide feedback for alternative approaches +- Interrupt: Pause for human input before tool execution +""" + +ModelSteeringAction = Annotated[Proceed | Guide, Field(discriminator="type")] +"""Steering actions valid for model steering (steer_after_model). + +- Proceed: Accept model response without modification +- Guide: Discard model response and retry with guidance +""" diff --git a/src/strands/vended_plugins/steering/core/context.py b/src/strands/vended_plugins/steering/core/context.py new file mode 100644 index 000000000..446c4c9f9 --- /dev/null +++ b/src/strands/vended_plugins/steering/core/context.py @@ -0,0 +1,77 @@ +"""Steering context protocols for contextual guidance. + +Defines protocols for context callbacks and providers that populate +steering context data used by handlers to make guidance decisions. + +Architecture: + SteeringContextCallback → Handler.steering_context → SteeringHandler.steer() + ↓ ↓ ↓ + Update local context Store in handler Access via self.steering_context + +Context lifecycle: + 1. Handler registers context callbacks for hook events + 2. Callbacks update handler's local steering_context on events + 3. Handler accesses self.steering_context in steer() method + 4. Context persists across calls within handler instance + +Implementation: + Each handler maintains its own JSONSerializableDict context. + Callbacks are registered per handler instance for isolation. + Providers can supply multiple callbacks for different events. +""" + +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Generic, TypeVar, cast, get_args, get_origin + +from ....hooks.registry import HookEvent +from ....types.json_dict import JSONSerializableDict + +logger = logging.getLogger(__name__) + + +@dataclass +class SteeringContext: + """Container for steering context data.""" + + """Container for steering context data. 
+ + This class should not be instantiated directly - it is intended for internal use only. + """ + + data: JSONSerializableDict = field(default_factory=JSONSerializableDict) + + +EventType = TypeVar("EventType", bound=HookEvent, contravariant=True) + + +class SteeringContextCallback(ABC, Generic[EventType]): + """Abstract base class for steering context update callbacks.""" + + @property + def event_type(self) -> type[HookEvent]: + """Return the event type this callback handles.""" + for base in getattr(self.__class__, "__orig_bases__", ()): + if get_origin(base) is SteeringContextCallback: + return cast(type[HookEvent], get_args(base)[0]) + raise ValueError("Could not determine event type from generic parameter") + + def __call__(self, event: EventType, steering_context: "SteeringContext", **kwargs: Any) -> None: + """Update steering context based on hook event. + + Args: + event: The hook event that triggered the callback + steering_context: The steering context to update + **kwargs: Additional keyword arguments for context updates + """ + ... + + +class SteeringContextProvider(ABC): + """Abstract base class for context providers that handle multiple event types.""" + + @abstractmethod + def context_providers(self, **kwargs: Any) -> list[SteeringContextCallback]: + """Return list of context callbacks with event types extracted from generics.""" + ... diff --git a/src/strands/vended_plugins/steering/core/handler.py b/src/strands/vended_plugins/steering/core/handler.py new file mode 100644 index 000000000..214118d4f --- /dev/null +++ b/src/strands/vended_plugins/steering/core/handler.py @@ -0,0 +1,218 @@ +"""Steering handler base class for providing contextual guidance to agents. + +Provides modular prompting through contextual guidance that appears when relevant, +rather than front-loading all instructions. Handlers integrate with the Strands hook +system to intercept actions and provide just-in-time feedback based on local context. 
+ +Architecture: + Hook Event → Context Callbacks → Update steering_context → steer_*() → SteeringAction + ↓ ↓ ↓ ↓ ↓ + Hook triggered Populate context Handler evaluates Handler decides Action taken + +Lifecycle: + 1. Context callbacks update handler's steering_context on hook events + 2. BeforeToolCallEvent triggers steer_before_tool() for tool steering + 3. AfterModelCallEvent triggers steer_after_model() for model steering + 4. Handler accesses self.steering_context for guidance decisions + 5. SteeringAction determines execution flow + +Implementation: + Subclass SteeringHandler and override steer_before_tool() and/or steer_after_model(). + Both methods have default implementations that return Proceed, so you only need to + override the methods you want to customize. + Pass context_providers in constructor to register context update functions. + Each handler maintains isolated steering_context that persists across calls. + +SteeringAction handling for steer_before_tool: + Proceed: Tool executes immediately + Guide: Tool cancelled, agent receives contextual feedback to explore alternatives + Interrupt: Tool execution paused for human input via interrupt system + +SteeringAction handling for steer_after_model: + Proceed: Model response accepted without modification + Guide: Discard model response and retry (message is dropped, model is called again) + Interrupt: Model response handling paused for human input via interrupt system +""" + +import logging +from typing import TYPE_CHECKING, Any + +from ....hooks.events import AfterModelCallEvent, BeforeToolCallEvent +from ....plugins import Plugin, hook +from ....types.content import Message +from ....types.streaming import StopReason +from ....types.tools import ToolUse +from .action import Guide, Interrupt, ModelSteeringAction, Proceed, ToolSteeringAction +from .context import SteeringContext, SteeringContextProvider + +if TYPE_CHECKING: + from ....agent import Agent + +logger = logging.getLogger(__name__) + + +class 
SteeringHandler(Plugin): + """Base class for steering handlers that provide contextual guidance to agents. + + Steering handlers maintain local context and register hook callbacks + to populate context data as needed for guidance decisions. + """ + + name: str = "steering" + + def __init__(self, context_providers: list[SteeringContextProvider] | None = None): + """Initialize the steering handler. + + Args: + context_providers: List of context providers for context updates + """ + super().__init__() + self.steering_context = SteeringContext() + self._context_callbacks = [] + + # Collect callbacks from all providers + for provider in context_providers or []: + self._context_callbacks.extend(provider.context_providers()) + + logger.debug("handler_class=<%s> | initialized", self.__class__.__name__) + + def init_agent(self, agent: "Agent") -> None: + """Initialize the steering handler with an agent. + + Registers hook callbacks for steering guidance and context updates. + + Args: + agent: The agent instance to attach steering to. 
+ """ + # Register context update callbacks + for callback in self._context_callbacks: + agent.add_hook(lambda event, callback=callback: callback(event, self.steering_context), callback.event_type) + + @hook + async def provide_tool_steering_guidance(self, event: BeforeToolCallEvent) -> None: + """Provide steering guidance for tool call.""" + tool_name = event.tool_use["name"] + logger.debug("tool_name=<%s> | providing tool steering guidance", tool_name) + + try: + action = await self.steer_before_tool(agent=event.agent, tool_use=event.tool_use) + except Exception as e: + logger.debug("tool_name=<%s>, error=<%s> | tool steering handler guidance failed", tool_name, e) + return + + self._handle_tool_steering_action(action, event, tool_name) + + def _handle_tool_steering_action( + self, action: ToolSteeringAction, event: BeforeToolCallEvent, tool_name: str + ) -> None: + """Handle the steering action for tool calls by modifying tool execution flow. + + Proceed: Tool executes normally + Guide: Tool cancelled with contextual feedback for agent to consider alternatives + Interrupt: Tool execution paused for human input via interrupt system + """ + if isinstance(action, Proceed): + logger.debug("tool_name=<%s> | tool call proceeding", tool_name) + elif isinstance(action, Guide): + logger.debug("tool_name=<%s> | tool call guided: %s", tool_name, action.reason) + event.cancel_tool = f"Tool call cancelled. {action.reason} You MUST follow this guidance immediately." 
+ elif isinstance(action, Interrupt): + logger.debug("tool_name=<%s> | tool call requires human input: %s", tool_name, action.reason) + can_proceed: bool = event.interrupt(name=f"steering_input_{tool_name}", reason={"message": action.reason}) + logger.debug("tool_name=<%s> | received human input for tool call", tool_name) + + if not can_proceed: + event.cancel_tool = f"Manual approval denied: {action.reason}" + logger.debug("tool_name=<%s> | tool call denied by manual approval", tool_name) + else: + logger.debug("tool_name=<%s> | tool call approved manually", tool_name) + else: + raise ValueError(f"Unknown steering action type for tool call: {action}") + + @hook + async def provide_model_steering_guidance(self, event: AfterModelCallEvent) -> None: + """Provide steering guidance for model response.""" + logger.debug("providing model steering guidance") + + # Only steer on successful model responses + if event.stop_response is None: + logger.debug("no stop response available | skipping model steering") + return + + try: + action = await self.steer_after_model( + agent=event.agent, message=event.stop_response.message, stop_reason=event.stop_response.stop_reason + ) + except Exception as e: + logger.debug("error=<%s> | model steering handler guidance failed", e) + return + + await self._handle_model_steering_action(action, event) + + async def _handle_model_steering_action(self, action: ModelSteeringAction, event: AfterModelCallEvent) -> None: + """Handle the steering action for model responses by modifying response handling flow. 
+ + Proceed: Model response accepted without modification + Guide: Discard model response and retry with guidance message added to conversation + """ + if isinstance(action, Proceed): + logger.debug("model response proceeding") + elif isinstance(action, Guide): + logger.debug("model response guided (retrying): %s", action.reason) + # Set retry flag to discard current response + event.retry = True + # Add guidance message to agent's conversation so model sees it on retry + await event.agent._append_messages({"role": "user", "content": [{"text": action.reason}]}) + logger.debug("added guidance message to conversation for model retry") + else: + raise ValueError(f"Unknown steering action type for model response: {action}") + + async def steer_before_tool(self, *, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> ToolSteeringAction: + """Provide contextual guidance before tool execution. + + This method is called before a tool is executed, allowing the handler to: + - Proceed: Allow tool execution to continue + - Guide: Cancel tool and provide feedback for alternative approaches + - Interrupt: Pause for human input before tool execution + + Args: + agent: The agent instance + tool_use: The tool use object with name and arguments + **kwargs: Additional keyword arguments for guidance evaluation + + Returns: + ToolSteeringAction indicating how to guide the tool execution + + Note: + Access steering context via self.steering_context + Default implementation returns Proceed (allow tool execution) + Override this method to implement custom tool steering logic + """ + return Proceed(reason="Default implementation: allowing tool execution") + + async def steer_after_model( + self, *, agent: "Agent", message: Message, stop_reason: StopReason, **kwargs: Any + ) -> ModelSteeringAction: + """Provide contextual guidance after model response. 
+ + This method is called after the model generates a response, allowing the handler to: + - Proceed: Accept the model response without modification + - Guide: Discard the response and retry (message is dropped, model is called again) + + Note: Interrupt is not supported for model steering as the model has already responded. + + Args: + agent: The agent instance + message: The model's generated message + stop_reason: The reason the model stopped generating + **kwargs: Additional keyword arguments for guidance evaluation + + Returns: + ModelSteeringAction indicating how to handle the model response + + Note: + Access steering context via self.steering_context + Default implementation returns Proceed (accept response as-is) + Override this method to implement custom model steering logic + """ + return Proceed(reason="Default implementation: accepting model response") diff --git a/src/strands/vended_plugins/steering/handlers/__init__.py b/src/strands/vended_plugins/steering/handlers/__init__.py new file mode 100644 index 000000000..fe364a5a2 --- /dev/null +++ b/src/strands/vended_plugins/steering/handlers/__init__.py @@ -0,0 +1,5 @@ +"""Steering handler implementations.""" + +from collections.abc import Sequence + +__all__: Sequence[str] = [] diff --git a/src/strands/vended_plugins/steering/handlers/llm/__init__.py b/src/strands/vended_plugins/steering/handlers/llm/__init__.py new file mode 100644 index 000000000..4dcccbe80 --- /dev/null +++ b/src/strands/vended_plugins/steering/handlers/llm/__init__.py @@ -0,0 +1,6 @@ +"""LLM steering handler with prompt mapping.""" + +from .llm_handler import LLMSteeringHandler +from .mappers import DefaultPromptMapper, LLMPromptMapper, ToolUse + +__all__ = ["LLMSteeringHandler", "LLMPromptMapper", "DefaultPromptMapper", "ToolUse"] diff --git a/src/strands/vended_plugins/steering/handlers/llm/llm_handler.py b/src/strands/vended_plugins/steering/handlers/llm/llm_handler.py new file mode 100644 index 000000000..6d0a31eeb --- /dev/null 
+++ b/src/strands/vended_plugins/steering/handlers/llm/llm_handler.py @@ -0,0 +1,99 @@ +"""LLM-based steering handler that uses an LLM to provide contextual guidance.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Literal, cast + +from pydantic import BaseModel, Field + +from .....models import Model +from .....types.tools import ToolUse +from ...context_providers.ledger_provider import LedgerProvider +from ...core.action import Guide, Interrupt, Proceed, ToolSteeringAction +from ...core.context import SteeringContextProvider +from ...core.handler import SteeringHandler +from .mappers import DefaultPromptMapper, LLMPromptMapper + +if TYPE_CHECKING: + from .....agent import Agent + +logger = logging.getLogger(__name__) + + +class _LLMSteering(BaseModel): + """Structured output model for LLM steering decisions.""" + + decision: Literal["proceed", "guide", "interrupt"] = Field( + description="Steering decision: 'proceed' to continue, 'guide' to provide feedback, 'interrupt' for human input" + ) + reason: str = Field(description="Clear explanation of the steering decision and any guidance provided") + + +class LLMSteeringHandler(SteeringHandler): + """Steering handler that uses an LLM to provide contextual guidance. + + Uses natural language prompts to evaluate tool calls and provide + contextual steering guidance to help agents navigate complex workflows. + """ + + def __init__( + self, + system_prompt: str, + prompt_mapper: LLMPromptMapper | None = None, + model: Model | None = None, + context_providers: list[SteeringContextProvider] | None = None, + ): + """Initialize the LLMSteeringHandler. + + Args: + system_prompt: System prompt defining steering guidance rules + prompt_mapper: Custom prompt mapper for evaluation prompts + model: Optional model override for steering evaluation + context_providers: List of context providers for populating steering context. + Defaults to [LedgerProvider()] if None. 
Pass an empty list to disable + context providers. + """ + providers: list[SteeringContextProvider] = ( + [LedgerProvider()] if context_providers is None else context_providers + ) + super().__init__(context_providers=providers) + self.system_prompt = system_prompt + self.prompt_mapper = prompt_mapper or DefaultPromptMapper() + self.model = model + + async def steer_before_tool(self, *, agent: Agent, tool_use: ToolUse, **kwargs: Any) -> ToolSteeringAction: + """Provide contextual guidance for tool usage. + + Args: + agent: The agent instance + tool_use: The tool use object with name and arguments + **kwargs: Additional keyword arguments for steering evaluation + + Returns: + SteeringAction indicating how to guide the tool execution + """ + # Generate steering prompt + prompt = self.prompt_mapper.create_steering_prompt(self.steering_context, tool_use=tool_use) + + # Create isolated agent for steering evaluation (no shared conversation state) + from .....agent import Agent + + steering_agent = Agent(system_prompt=self.system_prompt, model=self.model or agent.model, callback_handler=None) + + # Get LLM decision + llm_result: _LLMSteering = cast( + _LLMSteering, steering_agent(prompt, structured_output_model=_LLMSteering).structured_output + ) + + # Convert LLM decision to steering action + match llm_result.decision: + case "proceed": + return Proceed(reason=llm_result.reason) + case "guide": + return Guide(reason=llm_result.reason) + case "interrupt": + return Interrupt(reason=llm_result.reason) + case _: + logger.warning("decision=<%s> | unknown llm decision, defaulting to proceed", llm_result.decision) # type: ignore[unreachable] + return Proceed(reason="Unknown LLM decision, defaulting to proceed") diff --git a/src/strands/vended_plugins/steering/handlers/llm/mappers.py b/src/strands/vended_plugins/steering/handlers/llm/mappers.py new file mode 100644 index 000000000..ade018d32 --- /dev/null +++ b/src/strands/vended_plugins/steering/handlers/llm/mappers.py @@ -0,0 
+1,130 @@ +"""LLM steering prompt mappers for generating evaluation prompts.""" + +import json +from typing import Any, Protocol + +from .....types.tools import ToolUse +from ...core.context import SteeringContext + +# Agent SOP format - see https://github.com/strands-agents/agent-sop +_STEERING_PROMPT_TEMPLATE = """# Steering Evaluation + +## Overview + +You are a STEERING AGENT that evaluates a {action_type} that ANOTHER AGENT is attempting to make. +Your job is to provide contextual guidance to help the other agent navigate workflows effectively. +You act as a safety net that can intervene when patterns in the context data suggest the agent +should try a different approach or get human input. + +**YOUR ROLE:** +- Analyze context data for concerning patterns (repeated failures, inappropriate timing, etc.) +- Provide just-in-time guidance when the agent is going down an ineffective path +- Allow normal operations to proceed when context shows no issues + +**CRITICAL CONSTRAINTS:** +- Base decisions ONLY on the context data provided below +- Do NOT use external knowledge about domains, URLs, or tool purposes +- Do NOT make assumptions about what tools "should" or "shouldn't" do +- Focus ONLY on patterns in the context data + +## Context + +{context_str} + +### Understanding Ledger Tool States + +If the context includes a ledger with tool_calls, the "status" field indicates: + +- **"pending"**: The tool is CURRENTLY being evaluated by you (the steering agent). +This is NOT a duplicate call - it's the tool you're deciding whether to approve. +The tool has NOT started executing yet. +- **"success"**: The tool completed successfully in a previous turn +- **"error"**: The tool failed or was cancelled in a previous turn + +**IMPORTANT**: When you see a tool with status="pending" that matches the tool you're evaluating, +that IS the current tool being evaluated. +It is NOT already executing or a duplicate. 
+ +## Event to Evaluate + +{event_description} + +## Steps + +### 1. Analyze the {action_type_title} + +Review ONLY the context data above. Look for patterns in the data that indicate: + +- Previous failures or successes with this tool +- Frequency of attempts +- Any relevant tracking information + +**Constraints:** +- You MUST base analysis ONLY on the provided context data +- You MUST NOT use external knowledge about tool purposes or domains +- You SHOULD identify patterns in the context data +- You MAY reference relevant context data to inform your decision + +### 2. Make Steering Decision + +**Constraints:** +- You MUST respond with exactly one of: "proceed", "guide", or "interrupt" +- You MUST base the decision ONLY on context data patterns +- Your reason will be shown to the AGENT as guidance + +**Decision Options:** +- "proceed" if context data shows no concerning patterns +- "guide" if context data shows patterns requiring intervention +- "interrupt" if context data shows patterns requiring human input +""" + + +class LLMPromptMapper(Protocol): + """Protocol for mapping context and events to LLM evaluation prompts.""" + + def create_steering_prompt( + self, steering_context: SteeringContext, tool_use: ToolUse | None = None, **kwargs: Any + ) -> str: + """Create steering prompt for LLM evaluation. + + Args: + steering_context: Steering context with populated data + tool_use: Tool use object for tool call events (None for other events) + **kwargs: Additional event data for other steering events + + Returns: + Formatted prompt string for LLM evaluation + """ + ... + + +class DefaultPromptMapper(LLMPromptMapper): + """Default prompt mapper for steering evaluation.""" + + def create_steering_prompt( + self, steering_context: SteeringContext, tool_use: ToolUse | None = None, **kwargs: Any + ) -> str: + """Create default steering prompt using Agent SOP structure. + + Uses Agent SOP format for structured, constraint-based prompts. 
+ See: https://github.com/strands-agents/agent-sop + """ + context_str = ( + json.dumps(steering_context.data.get(), indent=2) if steering_context.data.get() else "No context available" + ) + + if tool_use: + event_description = ( + f"Tool: {tool_use['name']}\nArguments: {json.dumps(tool_use.get('input', {}), indent=2)}" + ) + action_type = "tool call" + else: + event_description = "General evaluation" + action_type = "action" + + return _STEERING_PROMPT_TEMPLATE.format( + action_type=action_type, + action_type_title=action_type.title(), + context_str=context_str, + event_description=event_description, + ) diff --git a/tests/strands/experimental/steering/test_steering_aliases.py b/tests/strands/experimental/steering/test_steering_aliases.py new file mode 100644 index 000000000..25fd86eb4 --- /dev/null +++ b/tests/strands/experimental/steering/test_steering_aliases.py @@ -0,0 +1,176 @@ +"""Tests to verify that experimental steering aliases work with deprecation warning. + +This test module ensures that the experimental steering aliases maintain +backwards compatibility and can be used interchangeably with the actual +types from strands.vended_plugins.steering. 
+""" + +import importlib +import sys +import warnings + +import pytest + +from strands.vended_plugins.steering import ( + Guide, + Interrupt, + LedgerAfterToolCall, + LedgerBeforeToolCall, + LedgerProvider, + LLMPromptMapper, + LLMSteeringHandler, + ModelSteeringAction, + Proceed, + SteeringContextCallback, + SteeringContextProvider, + SteeringHandler, + ToolSteeringAction, +) + +_ALL_NAMES = [ + "ToolSteeringAction", + "ModelSteeringAction", + "Proceed", + "Guide", + "Interrupt", + "SteeringHandler", + "SteeringContextCallback", + "SteeringContextProvider", + "LedgerBeforeToolCall", + "LedgerAfterToolCall", + "LedgerProvider", + "LLMSteeringHandler", + "LLMPromptMapper", +] + +_PRODUCTION_TYPES = { + "ToolSteeringAction": ToolSteeringAction, + "ModelSteeringAction": ModelSteeringAction, + "Proceed": Proceed, + "Guide": Guide, + "Interrupt": Interrupt, + "SteeringHandler": SteeringHandler, + "SteeringContextCallback": SteeringContextCallback, + "SteeringContextProvider": SteeringContextProvider, + "LedgerBeforeToolCall": LedgerBeforeToolCall, + "LedgerAfterToolCall": LedgerAfterToolCall, + "LedgerProvider": LedgerProvider, + "LLMSteeringHandler": LLMSteeringHandler, + "LLMPromptMapper": LLMPromptMapper, +} + + +@pytest.mark.parametrize("name", _ALL_NAMES) +def test_experimental_alias_is_same_type(name): + """Verify that experimental steering alias is identical to the actual type.""" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + from strands.experimental import steering + + experimental_type = getattr(steering, name) + + assert experimental_type is _PRODUCTION_TYPES[name] + + +@pytest.mark.parametrize("name", _ALL_NAMES) +def test_deprecation_warning_on_access(name, captured_warnings): + """Verify that accessing deprecated aliases emits deprecation warning.""" + # Clear the module from cache to trigger fresh import + if "strands.experimental.steering" in sys.modules: + del sys.modules["strands.experimental.steering"] + + # 
Clear any existing warnings + captured_warnings.clear() + + # Access from experimental - this should trigger the warning + from strands.experimental import steering + + _ = getattr(steering, name) + + assert len(captured_warnings) >= 1 + warning = captured_warnings[0] + assert issubclass(warning.category, DeprecationWarning) + assert name in str(warning.message) + assert "strands.vended_plugins.steering" in str(warning.message) + + +def test_attribute_error_on_unknown_attribute(): + """Verify that accessing unknown attributes raises AttributeError.""" + import strands.experimental.steering as steering_module + + with pytest.raises(AttributeError, match="has no attribute"): + _ = steering_module.NonExistentClass + + +def test_no_warning_on_production_import(captured_warnings): + """Verify that importing from strands.vended_plugins.steering does not emit deprecation warning.""" + # Clear any existing warnings + captured_warnings.clear() + + # Import from production - should NOT trigger warning + from strands.vended_plugins.steering import Proceed as _ # noqa: F401 + + # Filter for steering-related deprecation warnings + steering_warnings = [ + w + for w in captured_warnings + if "has been moved" in str(w.message) and issubclass(w.category, DeprecationWarning) + ] + + assert len(steering_warnings) == 0 + + +# Submodule import tests - verify deep import paths still work with deprecation warnings + +_SUBMODULE_IMPORTS = [ + ("strands.experimental.steering.core.action", "Guide", Guide), + ("strands.experimental.steering.core.action", "Interrupt", Interrupt), + ("strands.experimental.steering.core.action", "Proceed", Proceed), + ("strands.experimental.steering.core.context", "SteeringContextCallback", SteeringContextCallback), + ("strands.experimental.steering.core.context", "SteeringContextProvider", SteeringContextProvider), + ("strands.experimental.steering.core.handler", "SteeringHandler", SteeringHandler), + 
("strands.experimental.steering.context_providers.ledger_provider", "LedgerProvider", LedgerProvider), + ("strands.experimental.steering.context_providers.ledger_provider", "LedgerBeforeToolCall", LedgerBeforeToolCall), + ("strands.experimental.steering.context_providers.ledger_provider", "LedgerAfterToolCall", LedgerAfterToolCall), + ("strands.experimental.steering.handlers.llm.llm_handler", "LLMSteeringHandler", LLMSteeringHandler), + ("strands.experimental.steering.handlers.llm.mappers", "DefaultPromptMapper", None), +] + + +@pytest.mark.parametrize( + "module_path,attr_name,expected_type", + _SUBMODULE_IMPORTS, + ids=[f"{m}.{a}" for m, a, _ in _SUBMODULE_IMPORTS], +) +def test_submodule_import_resolves_correctly(module_path, attr_name, expected_type): + """Verify that submodule imports resolve to the correct production types.""" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + mod = importlib.import_module(module_path) + obj = getattr(mod, attr_name) + + if expected_type is not None: + assert obj is expected_type + + +@pytest.mark.parametrize( + "module_path,attr_name,expected_type", + _SUBMODULE_IMPORTS, + ids=[f"{m}.{a}" for m, a, _ in _SUBMODULE_IMPORTS], +) +def test_submodule_import_emits_deprecation_warning(module_path, attr_name, expected_type, captured_warnings): + """Verify that submodule imports emit deprecation warnings.""" + # Clear module from cache to trigger fresh import + if module_path in sys.modules: + del sys.modules[module_path] + + captured_warnings.clear() + + mod = importlib.import_module(module_path) + _ = getattr(mod, attr_name) + + assert len(captured_warnings) >= 1 + warning = captured_warnings[0] + assert issubclass(warning.category, DeprecationWarning) + assert attr_name in str(warning.message) + assert "has been moved to production" in str(warning.message) diff --git a/tests/strands/plugins/skills/__init__.py b/tests/strands/plugins/skills/__init__.py deleted file mode 100644 index 
9bd23c0ed..000000000 --- a/tests/strands/plugins/skills/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for the skills plugin package.""" diff --git a/tests/strands/experimental/steering/context_providers/__init__.py b/tests/strands/vended_plugins/__init__.py similarity index 100% rename from tests/strands/experimental/steering/context_providers/__init__.py rename to tests/strands/vended_plugins/__init__.py diff --git a/tests/strands/experimental/steering/core/__init__.py b/tests/strands/vended_plugins/skills/__init__.py similarity index 100% rename from tests/strands/experimental/steering/core/__init__.py rename to tests/strands/vended_plugins/skills/__init__.py diff --git a/tests/strands/plugins/skills/test_agent_skills.py b/tests/strands/vended_plugins/skills/test_agent_skills.py similarity index 98% rename from tests/strands/plugins/skills/test_agent_skills.py rename to tests/strands/vended_plugins/skills/test_agent_skills.py index 8c6ab10bd..52802a6c1 100644 --- a/tests/strands/plugins/skills/test_agent_skills.py +++ b/tests/strands/vended_plugins/skills/test_agent_skills.py @@ -7,9 +7,9 @@ from strands.hooks.events import BeforeInvocationEvent from strands.hooks.registry import HookRegistry from strands.plugins.registry import _PluginRegistry -from strands.plugins.skills.agent_skills import AgentSkills -from strands.plugins.skills.skill import Skill from strands.types.tools import ToolContext +from strands.vended_plugins.skills.agent_skills import AgentSkills +from strands.vended_plugins.skills.skill import Skill def _make_skill(name: str = "test-skill", description: str = "A test skill", instructions: str = "Do the thing."): @@ -211,8 +211,6 @@ def test_set_available_skills_with_mixed_sources(self, tmp_path): assert names == {"fs-skill", "direct"} - - class TestSkillsTool: """Tests for the skills tool method.""" @@ -666,12 +664,6 @@ def test_resolve_nonexistent_path(self, tmp_path): class TestImports: """Tests for module imports.""" - def 
test_import_from_plugins(self): - """Test importing AgentSkills from strands.plugins.""" - from strands.plugins import AgentSkills as SP - - assert SP is AgentSkills - def test_import_skill_from_strands(self): """Test importing Skill from top-level strands package.""" from strands import Skill as S @@ -679,8 +671,8 @@ def test_import_skill_from_strands(self): assert S is Skill def test_import_from_skills_package(self): - """Test importing from strands.plugins.skills package.""" - from strands.plugins.skills import AgentSkills, Skill + """Test importing from strands.vended_plugins.skills package.""" + from strands.vended_plugins.skills import AgentSkills, Skill assert Skill is not None assert AgentSkills is not None diff --git a/tests/strands/plugins/skills/test_skill.py b/tests/strands/vended_plugins/skills/test_skill.py similarity index 99% rename from tests/strands/plugins/skills/test_skill.py rename to tests/strands/vended_plugins/skills/test_skill.py index 2c4c21930..53d6f3507 100644 --- a/tests/strands/plugins/skills/test_skill.py +++ b/tests/strands/vended_plugins/skills/test_skill.py @@ -5,7 +5,7 @@ import pytest -from strands.plugins.skills.skill import ( +from strands.vended_plugins.skills.skill import ( Skill, _find_skill_md, _fix_yaml_colons, diff --git a/tests/strands/experimental/steering/handlers/__init__.py b/tests/strands/vended_plugins/steering/__init__.py similarity index 100% rename from tests/strands/experimental/steering/handlers/__init__.py rename to tests/strands/vended_plugins/steering/__init__.py diff --git a/tests/strands/experimental/steering/handlers/llm/__init__.py b/tests/strands/vended_plugins/steering/context_providers/__init__.py similarity index 100% rename from tests/strands/experimental/steering/handlers/llm/__init__.py rename to tests/strands/vended_plugins/steering/context_providers/__init__.py diff --git a/tests/strands/experimental/steering/context_providers/test_ledger_provider.py 
b/tests/strands/vended_plugins/steering/context_providers/test_ledger_provider.py similarity index 91% rename from tests/strands/experimental/steering/context_providers/test_ledger_provider.py rename to tests/strands/vended_plugins/steering/context_providers/test_ledger_provider.py index c3cde475b..dda718f31 100644 --- a/tests/strands/experimental/steering/context_providers/test_ledger_provider.py +++ b/tests/strands/vended_plugins/steering/context_providers/test_ledger_provider.py @@ -2,13 +2,13 @@ from unittest.mock import Mock, patch -from strands.experimental.steering.context_providers.ledger_provider import ( +from strands.hooks.events import AfterToolCallEvent, BeforeToolCallEvent +from strands.vended_plugins.steering.context_providers.ledger_provider import ( LedgerAfterToolCall, LedgerBeforeToolCall, LedgerProvider, ) -from strands.experimental.steering.core.context import SteeringContext -from strands.hooks.events import AfterToolCallEvent, BeforeToolCallEvent +from strands.vended_plugins.steering.core.context import SteeringContext def test_context_providers_method(): @@ -22,7 +22,7 @@ def test_context_providers_method(): assert isinstance(callbacks[1], LedgerAfterToolCall) -@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +@patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") def test_ledger_before_tool_call_new_ledger(mock_datetime): """Test LedgerBeforeToolCall with new ledger.""" mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" @@ -48,7 +48,7 @@ def test_ledger_before_tool_call_new_ledger(mock_datetime): assert tool_call["status"] == "pending" -@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +@patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") def test_ledger_before_tool_call_existing_ledger(mock_datetime): """Test LedgerBeforeToolCall with existing ledger.""" 
mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" @@ -77,7 +77,7 @@ def test_ledger_before_tool_call_existing_ledger(mock_datetime): assert ledger["tool_calls"][1]["tool_name"] == "new_tool" -@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +@patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") def test_ledger_after_tool_call_success(mock_datetime): """Test LedgerAfterToolCall with successful completion.""" mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:05:00" @@ -135,7 +135,7 @@ def test_ledger_after_tool_call_no_calls(): def test_session_start_persistence(): """Test that session_start is set during initialization and persists.""" - with patch("strands.experimental.steering.context_providers.ledger_provider.datetime") as mock_datetime: + with patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") as mock_datetime: mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T10:00:00" callback = LedgerBeforeToolCall() @@ -143,7 +143,7 @@ def test_session_start_persistence(): assert callback.session_start == "2024-01-01T10:00:00" -@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +@patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") def test_parallel_tool_calls_all_pending(mock_datetime): """Test multiple tool calls added as pending before any execute.""" mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" @@ -163,7 +163,7 @@ def test_parallel_tool_calls_all_pending(mock_datetime): assert [call["tool_name"] for call in ledger["tool_calls"]] == ["tool_a", "tool_b", "tool_c"] -@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +@patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") def test_parallel_tool_calls_complete_by_id(mock_datetime): """Test tool calls complete 
in any order by matching toolUseId.""" # Need timestamps for: session_start + 3 tool calls + 1 completion @@ -199,7 +199,7 @@ def test_parallel_tool_calls_complete_by_id(mock_datetime): assert ledger["tool_calls"][2]["status"] == "pending" -@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +@patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") def test_parallel_tool_calls_complete_all_out_of_order(mock_datetime): """Test all parallel tool calls complete in reverse order.""" # Need timestamps for: session_start + 3 tool calls + 3 completions @@ -238,7 +238,7 @@ def test_parallel_tool_calls_complete_all_out_of_order(mock_datetime): assert ledger["tool_calls"][2]["result"] == ["result_2"] -@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +@patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") def test_parallel_tool_calls_with_failure(mock_datetime): """Test parallel tool calls where one fails.""" # Need timestamps for: session_start + 2 tool calls + 2 completions @@ -281,7 +281,7 @@ def test_parallel_tool_calls_with_failure(mock_datetime): assert ledger["tool_calls"][1]["error"] == "test error" -@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +@patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") def test_after_tool_call_no_matching_id(mock_datetime): """Test AfterToolCallEvent when tool_use_id doesn't match any pending call.""" mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" @@ -308,7 +308,7 @@ def test_after_tool_call_no_matching_id(mock_datetime): assert "completion_timestamp" not in ledger["tool_calls"][0] -@patch("strands.experimental.steering.context_providers.ledger_provider.datetime") +@patch("strands.vended_plugins.steering.context_providers.ledger_provider.datetime") def test_tool_use_id_stored_in_ledger(mock_datetime): """Test that 
toolUseId is stored in ledger entries.""" mock_datetime.now.return_value.isoformat.return_value = "2024-01-01T12:00:00" diff --git a/tests/strands/vended_plugins/steering/core/__init__.py b/tests/strands/vended_plugins/steering/core/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/steering/core/test_handler.py b/tests/strands/vended_plugins/steering/core/test_handler.py similarity index 98% rename from tests/strands/experimental/steering/core/test_handler.py rename to tests/strands/vended_plugins/steering/core/test_handler.py index 1f247120a..dc3b0dacc 100644 --- a/tests/strands/experimental/steering/core/test_handler.py +++ b/tests/strands/vended_plugins/steering/core/test_handler.py @@ -5,12 +5,16 @@ import pytest -from strands.experimental.steering.core.action import Guide, Interrupt, Proceed -from strands.experimental.steering.core.context import SteeringContext, SteeringContextCallback, SteeringContextProvider -from strands.experimental.steering.core.handler import SteeringHandler from strands.hooks.events import AfterModelCallEvent, BeforeToolCallEvent from strands.hooks.registry import HookRegistry from strands.plugins import Plugin +from strands.vended_plugins.steering.core.action import Guide, Interrupt, Proceed +from strands.vended_plugins.steering.core.context import ( + SteeringContext, + SteeringContextCallback, + SteeringContextProvider, +) +from strands.vended_plugins.steering.core.handler import SteeringHandler class TestSteeringHandler(SteeringHandler): diff --git a/tests/strands/vended_plugins/steering/handlers/__init__.py b/tests/strands/vended_plugins/steering/handlers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/vended_plugins/steering/handlers/llm/__init__.py b/tests/strands/vended_plugins/steering/handlers/llm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/steering/handlers/llm/test_llm_handler.py 
b/tests/strands/vended_plugins/steering/handlers/llm/test_llm_handler.py similarity index 96% rename from tests/strands/experimental/steering/handlers/llm/test_llm_handler.py rename to tests/strands/vended_plugins/steering/handlers/llm/test_llm_handler.py index f10254e50..776124d25 100644 --- a/tests/strands/experimental/steering/handlers/llm/test_llm_handler.py +++ b/tests/strands/vended_plugins/steering/handlers/llm/test_llm_handler.py @@ -4,9 +4,9 @@ import pytest -from strands.experimental.steering.core.action import Guide, Interrupt, Proceed -from strands.experimental.steering.handlers.llm.llm_handler import LLMSteeringHandler, _LLMSteering -from strands.experimental.steering.handlers.llm.mappers import DefaultPromptMapper +from strands.vended_plugins.steering.core.action import Guide, Interrupt, Proceed +from strands.vended_plugins.steering.handlers.llm.llm_handler import LLMSteeringHandler, _LLMSteering +from strands.vended_plugins.steering.handlers.llm.mappers import DefaultPromptMapper def test_llm_steering_handler_initialization(): diff --git a/tests/strands/experimental/steering/handlers/llm/test_mappers.py b/tests/strands/vended_plugins/steering/handlers/llm/test_mappers.py similarity index 95% rename from tests/strands/experimental/steering/handlers/llm/test_mappers.py rename to tests/strands/vended_plugins/steering/handlers/llm/test_mappers.py index 511671d3a..3f87f030a 100644 --- a/tests/strands/experimental/steering/handlers/llm/test_mappers.py +++ b/tests/strands/vended_plugins/steering/handlers/llm/test_mappers.py @@ -1,7 +1,7 @@ """Unit tests for LLM steering prompt mappers.""" -from strands.experimental.steering.core.context import SteeringContext -from strands.experimental.steering.handlers.llm.mappers import _STEERING_PROMPT_TEMPLATE, DefaultPromptMapper +from strands.vended_plugins.steering.core.context import SteeringContext +from strands.vended_plugins.steering.handlers.llm.mappers import _STEERING_PROMPT_TEMPLATE, DefaultPromptMapper def 
test_create_steering_prompt_with_tool_use(): diff --git a/tests_integ/steering/test_model_steering.py b/tests_integ/steering/test_model_steering.py index d1948586a..86c69fd50 100644 --- a/tests_integ/steering/test_model_steering.py +++ b/tests_integ/steering/test_model_steering.py @@ -1,11 +1,11 @@ """Integration tests for model steering (steer_after_model).""" from strands import Agent, tool -from strands.experimental.steering.context_providers.ledger_provider import LedgerProvider -from strands.experimental.steering.core.action import Guide, ModelSteeringAction, Proceed -from strands.experimental.steering.core.handler import SteeringHandler from strands.types.content import Message from strands.types.streaming import StopReason +from strands.vended_plugins.steering.context_providers.ledger_provider import LedgerProvider +from strands.vended_plugins.steering.core.action import Guide, ModelSteeringAction, Proceed +from strands.vended_plugins.steering.core.handler import SteeringHandler class SimpleModelSteeringHandler(SteeringHandler): diff --git a/tests_integ/steering/test_tool_steering.py b/tests_integ/steering/test_tool_steering.py index e441e71da..52c715f5e 100644 --- a/tests_integ/steering/test_tool_steering.py +++ b/tests_integ/steering/test_tool_steering.py @@ -3,10 +3,10 @@ import pytest from strands import Agent, tool -from strands.experimental.steering.context_providers.ledger_provider import LedgerProvider -from strands.experimental.steering.core.action import Guide, Interrupt, Proceed -from strands.experimental.steering.core.handler import SteeringHandler -from strands.experimental.steering.handlers.llm.llm_handler import LLMSteeringHandler +from strands.vended_plugins.steering.context_providers.ledger_provider import LedgerProvider +from strands.vended_plugins.steering.core.action import Guide, Interrupt, Proceed +from strands.vended_plugins.steering.core.handler import SteeringHandler +from strands.vended_plugins.steering.handlers.llm.llm_handler 
import LLMSteeringHandler @tool diff --git a/tests_integ/test_skills_plugin.py b/tests_integ/test_skills_plugin.py index 160ae65a0..8867f08fd 100644 --- a/tests_integ/test_skills_plugin.py +++ b/tests_integ/test_skills_plugin.py @@ -8,7 +8,7 @@ import pytest from strands import Agent -from strands.plugins.skills import AgentSkills, Skill +from strands.vended_plugins.skills import AgentSkills, Skill SUMMARIZATION_SKILL = Skill( name="summarization", From b9f5b904b49d6c8bce5699d749f2b5704c0143a9 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 11 Mar 2026 12:02:28 -0400 Subject: [PATCH 179/279] fix: break circular references so Agent cleanup doesn't hang with MCPClient (#1830) --- src/strands/plugins/registry.py | 11 ++- src/strands/tools/_caller.py | 11 ++- tests/strands/plugins/test_plugins.py | 13 +++ tests/strands/tools/test_caller.py | 122 ++++++++++++++++++++++++++ 4 files changed, 155 insertions(+), 2 deletions(-) diff --git a/src/strands/plugins/registry.py b/src/strands/plugins/registry.py index a75858680..e994b5591 100644 --- a/src/strands/plugins/registry.py +++ b/src/strands/plugins/registry.py @@ -6,6 +6,7 @@ import inspect import logging +import weakref from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, cast @@ -55,9 +56,17 @@ def __init__(self, agent: "Agent") -> None: Args: agent: The agent instance that plugins will be initialized with. """ - self._agent = agent + self._agent_ref = weakref.ref(agent) self._plugins: dict[str, Plugin] = {} + @property + def _agent(self) -> "Agent": + """Return the agent, raising ReferenceError if it has been garbage collected.""" + agent = self._agent_ref() + if agent is None: + raise ReferenceError("Agent has been garbage collected") + return agent + def add_and_init(self, plugin: Plugin) -> None: """Add and initialize a plugin with the agent. 
diff --git a/src/strands/tools/_caller.py b/src/strands/tools/_caller.py index 8ca6138fc..0b5408f35 100644 --- a/src/strands/tools/_caller.py +++ b/src/strands/tools/_caller.py @@ -9,6 +9,7 @@ import json import random +import weakref from collections.abc import Callable from typing import TYPE_CHECKING, Any @@ -35,7 +36,15 @@ def __init__(self, agent: "Agent | BidiAgent") -> None: """ # WARNING: Do not add any other member variables or methods as this could result in a name conflict with # agent tools and thus break their execution. - self._agent = agent + self._agent_ref = weakref.ref(agent) + + @property + def _agent(self) -> "Agent | BidiAgent": + """Return the agent, raising ReferenceError if it has been garbage collected.""" + agent = self._agent_ref() + if agent is None: + raise ReferenceError("Agent has been garbage collected") + return agent def __getattr__(self, name: str) -> Callable[..., Any]: """Call tool as a function. diff --git a/tests/strands/plugins/test_plugins.py b/tests/strands/plugins/test_plugins.py index 04b39718b..88ed41f8d 100644 --- a/tests/strands/plugins/test_plugins.py +++ b/tests/strands/plugins/test_plugins.py @@ -1,9 +1,11 @@ """Tests for the plugin system.""" +import gc import unittest.mock import pytest +from strands import Agent from strands.hooks import HookRegistry from strands.plugins import Plugin from strands.plugins.registry import _PluginRegistry @@ -194,3 +196,14 @@ async def init_agent(self, agent): assert plugin.initialized assert mock_agent.async_plugin_initialized + + +def test_plugin_registry_raises_reference_error_after_agent_collected(): + """Verify _PluginRegistry raises ReferenceError when the Agent has been garbage collected.""" + agent = Agent() + registry = agent._plugin_registry + del agent + gc.collect() + + with pytest.raises(ReferenceError, match="Agent has been garbage collected"): + _ = registry._agent diff --git a/tests/strands/tools/test_caller.py b/tests/strands/tools/test_caller.py index 
18de6d3f0..2658af6b4 100644 --- a/tests/strands/tools/test_caller.py +++ b/tests/strands/tools/test_caller.py @@ -1,8 +1,11 @@ +import gc import unittest.mock +import weakref import pytest from strands import Agent, tool +from strands.tools.tool_provider import ToolProvider @pytest.fixture @@ -312,3 +315,122 @@ def test_agent_tool_caller_interrupt_activated(): exp_message = r"cannot directly call tool during interrupt" with pytest.raises(RuntimeError, match=exp_message): agent.tool.test_tool() + + +def test_agent_collected_without_cyclic_gc(): + """Verify that Agent is promptly collectable (no persistent reference cycle). + + This ensures that the weakref-based back-references in _ToolCaller and _PluginRegistry + do not create reference cycles that would delay cleanup until interpreter shutdown. + When cleanup is deferred to interpreter shutdown, MCPClient.stop() hangs because its + background thread cannot complete async cleanup at that point. + + Note: On some platforms/versions (e.g. Python 3.14 with deferred refcounting), del may + not immediately trigger collection. A single gc.collect() is allowed as a fallback since + it still proves no persistent cycle exists — the agent is collected promptly, not deferred + to interpreter shutdown. + """ + gc.disable() + try: + agent = Agent() + ref = weakref.ref(agent) + del agent + + if ref() is not None: + # Deferred refcounting (Python 3.14+) may not collect immediately on del; + # a single gc.collect() should still reclaim it since there are no cycles. 
+ gc.collect() + + assert ref() is None, "Agent was not collected; a reference cycle likely exists" + finally: + gc.enable() + + +class _MockToolProvider(ToolProvider): + """Minimal ToolProvider that tracks cleanup calls, mimicking MCPClient lifecycle.""" + + def __init__(self): + self.consumers: set = set() + self.cleanup_called = False + + async def load_tools(self, **kwargs): + return [] + + def add_consumer(self, consumer_id, **kwargs): + self.consumers.add(consumer_id) + + def remove_consumer(self, consumer_id, **kwargs): + self.consumers.discard(consumer_id) + if not self.consumers: + self.cleanup_called = True + + +def test_agent_with_tool_provider_cleaned_up_when_function_returns(): + """Replicate the hang from issue #1732: Agent with MCPClient created inside a function. + + When an Agent using a managed MCPClient (as ToolProvider) is created inside a function, + the script used to hang on exit. The Agent went out of scope when the function returned, + but circular references (Agent → _ToolCaller._agent → Agent) prevented refcount-based + destruction. Cleanup was deferred to the cyclic GC during interpreter shutdown, where + MCPClient.stop() → thread.join() would hang. + + This test verifies that with the weakref fix, the Agent is destroyed immediately when + the function returns, and the tool provider's cleanup runs promptly. + """ + provider = _MockToolProvider() + + def get_agent(): + return Agent(tools=[provider]) + + def main(): + agent = get_agent() # noqa: F841 + + gc.disable() + try: + main() + + if not provider.cleanup_called: + # Deferred refcounting (Python 3.14+) may not collect immediately on scope exit; + # a single gc.collect() should still reclaim it since there are no cycles. 
+ gc.collect() + + assert provider.cleanup_called, ( + "Tool provider was not cleaned up when the function returned; Agent likely leaked due to a reference cycle" + ) + finally: + gc.enable() + + +def test_agent_with_tool_provider_cleaned_up_on_del(): + """Replicate the working case from issue #1732: Agent at module scope, explicitly deleted. + + In the issue, an Agent created at module level did not hang because module-level variables + are cleared early during interpreter shutdown (while the runtime is still functional). + This test verifies the equivalent: explicitly deleting the agent triggers immediate cleanup. + """ + provider = _MockToolProvider() + + agent = Agent(tools=[provider]) + assert not provider.cleanup_called + + del agent + + if not provider.cleanup_called: + # Deferred refcounting (Python 3.14+) may not collect immediately on del; + # a single gc.collect() should still reclaim it since there are no cycles. + gc.collect() + + assert provider.cleanup_called, "Tool provider was not cleaned up after del agent" + + +def test_tool_caller_raises_reference_error_after_agent_collected(): + """Verify _ToolCaller raises ReferenceError when the Agent has been garbage collected.""" + agent = Agent() + caller = agent.tool_caller + # Clear the weak reference by replacing it directly + caller._agent_ref = weakref.ref(agent) + del agent + gc.collect() + + with pytest.raises(ReferenceError, match="Agent has been garbage collected"): + _ = caller._agent From 2da3f7c16b852f23ea7a7cfedde2ca94b631c274 Mon Sep 17 00:00:00 2001 From: mehtarac Date: Wed, 11 Mar 2026 14:13:46 -0400 Subject: [PATCH 180/279] fix: Set _is_new_session = False at the end of each initialize_* method (#1859) --- src/strands/session/repository_session_manager.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index b3eed6474..7e538c08b 100644 --- 
a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -226,6 +226,8 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: # Fix broken session histories: https://github.com/strands-agents/sdk-python/issues/859 agent.messages = self._fix_broken_tool_use(agent.messages) + self._is_new_session = False + def _fix_broken_tool_use(self, messages: list[Message]) -> list[Message]: """Fix broken tool use/result pairs in message history. @@ -318,6 +320,8 @@ def initialize_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> Non logger.debug("session_id=<%s> | restoring multi-agent state", self.session_id) source.deserialize_state(state) + self._is_new_session = False + def initialize_bidi_agent(self, agent: "BidiAgent", **kwargs: Any) -> None: """Initialize a bidirectional agent with a session. @@ -375,6 +379,8 @@ def initialize_bidi_agent(self, agent: "BidiAgent", **kwargs: Any) -> None: # Fix broken session histories: https://github.com/strands-agents/sdk-python/issues/859 agent.messages = self._fix_broken_tool_use(agent.messages) + self._is_new_session = False + def append_bidi_message(self, message: Message, agent: "BidiAgent", **kwargs: Any) -> None: """Append a message to the bidirectional agent's session. 
From fca208b866bccecc260b72160f2c66fa6c600be5 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 12 Mar 2026 12:01:23 -0400 Subject: [PATCH 181/279] feat: pass A2A request context metadata as invocation state (#1854) --- src/strands/multiagent/a2a/executor.py | 6 +- tests/strands/multiagent/a2a/test_executor.py | 136 +++++++++++++++++- 2 files changed, 140 insertions(+), 2 deletions(-) diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index 2f8de99f7..c8c00600b 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -128,8 +128,12 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater self._current_artifact_id = str(uuid.uuid4()) self._is_first_chunk = True + # Pass the A2A RequestContext through invocation state so downstream + # tools and hooks can access request metadata, task info, configuration, etc. + invocation_state: dict[str, Any] = {"a2a_request_context": context} + try: - async for event in self.agent.stream_async(content_blocks): + async for event in self.agent.stream_async(content_blocks, invocation_state=invocation_state): await self._handle_streaming_event(event, updater) except Exception: logger.exception("Error in streaming execution") diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py index 932f26247..dc90fbdd6 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -1,6 +1,7 @@ """Tests for the StrandsA2AExecutor class.""" import base64 +from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -1196,4 +1197,137 @@ async def test_a2a_compliant_handle_result_not_first_chunk(mock_strands_agent): assert mock_updater.add_artifact.call_args[1]["artifact_id"] == "artifact-abc" assert mock_updater.add_artifact.call_args[1]["append"] is True assert 
mock_updater.add_artifact.call_args[1]["last_chunk"] is True - mock_updater.complete.assert_called_once() + + +# Tests for invocation state propagation from A2A request context + + +def _setup_streaming_context( + mock_strands_agent: MagicMock, + mock_request_context: MagicMock, +) -> None: + """Set up common mocks for invocation state streaming tests. + + Args: + mock_strands_agent: The mock Strands Agent. + mock_request_context: The mock RequestContext. + """ + + async def mock_stream(content_blocks: list, **kwargs: Any) -> Any: + yield {"result": MagicMock(spec=SAAgentResult)} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + # Set up message with a text part + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test input" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + +@pytest.mark.asyncio +async def test_invocation_state_contains_request_context(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that the full RequestContext is passed as a2a_request_context in invocation state.""" + mock_task = MagicMock() + mock_task.id = "task-42" + mock_task.context_id = "ctx-99" + mock_request_context.current_task = mock_task + mock_request_context.metadata = {"caller": "test-client"} + + _setup_streaming_context(mock_strands_agent, mock_request_context) + + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + mock_strands_agent.stream_async.assert_called_once() + call_kwargs = mock_strands_agent.stream_async.call_args[1] + invocation_state = call_kwargs["invocation_state"] + + assert invocation_state is not None + assert invocation_state["a2a_request_context"] is mock_request_context + + +@pytest.mark.asyncio +async def test_invocation_state_context_exposes_metadata(mock_strands_agent, mock_request_context, 
mock_event_queue): + """Test that metadata is accessible through the RequestContext in invocation state.""" + test_metadata = {"caller": "test-client", "session": "abc-123"} + mock_request_context.metadata = test_metadata + mock_task = MagicMock() + mock_task.id = "task-1" + mock_task.context_id = "ctx-1" + mock_request_context.current_task = mock_task + + _setup_streaming_context(mock_strands_agent, mock_request_context) + + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + call_kwargs = mock_strands_agent.stream_async.call_args[1] + context = call_kwargs["invocation_state"]["a2a_request_context"] + + assert context.metadata == test_metadata + + +@pytest.mark.asyncio +async def test_invocation_state_context_exposes_task_info(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that task info is accessible through the RequestContext in invocation state.""" + mock_task = MagicMock() + mock_task.id = "task-100" + mock_task.context_id = "ctx-200" + mock_request_context.current_task = mock_task + + _setup_streaming_context(mock_strands_agent, mock_request_context) + + executor = StrandsA2AExecutor(mock_strands_agent) + await executor.execute(mock_request_context, mock_event_queue) + + call_kwargs = mock_strands_agent.stream_async.call_args[1] + context = call_kwargs["invocation_state"]["a2a_request_context"] + + assert context.current_task.id == "task-100" + assert context.current_task.context_id == "ctx-200" + + +@pytest.mark.asyncio +async def test_invocation_state_context_when_no_task(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that RequestContext is passed even when there is no current task.""" + mock_request_context.current_task = None + mock_request_context.metadata = {} + + _setup_streaming_context(mock_strands_agent, mock_request_context) + + executor = StrandsA2AExecutor(mock_strands_agent) + + with patch("strands.multiagent.a2a.executor.new_task") 
as mock_new_task: + mock_new_task.return_value = MagicMock(id="generated-id", context_id="generated-ctx") + await executor.execute(mock_request_context, mock_event_queue) + + call_kwargs = mock_strands_agent.stream_async.call_args[1] + invocation_state = call_kwargs["invocation_state"] + + assert invocation_state["a2a_request_context"] is mock_request_context + + +@pytest.mark.asyncio +async def test_invocation_state_with_a2a_compliant_streaming( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Test that invocation state is passed correctly in A2A-compliant streaming mode.""" + mock_task = MagicMock() + mock_task.id = "task-compliant" + mock_task.context_id = "ctx-compliant" + mock_request_context.current_task = mock_task + + _setup_streaming_context(mock_strands_agent, mock_request_context) + + executor = StrandsA2AExecutor(mock_strands_agent, enable_a2a_compliant_streaming=True) + await executor.execute(mock_request_context, mock_event_queue) + + call_kwargs = mock_strands_agent.stream_async.call_args[1] + invocation_state = call_kwargs["invocation_state"] + + assert invocation_state is not None + assert invocation_state["a2a_request_context"] is mock_request_context From 39c8c198369cf4eddd9d64141d441b238035de44 Mon Sep 17 00:00:00 2001 From: mehtarac Date: Mon, 16 Mar 2026 10:25:02 -0400 Subject: [PATCH 182/279] fix: s3session manager bug (#1915) --- src/strands/session/s3_session_manager.py | 5 +++- .../session/test_s3_session_manager.py | 29 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 8d557e81c..fad5e4fd0 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -95,7 +95,10 @@ def _get_session_path(self, session_id: str) -> str: ValueError: If session id contains a path separator. 
""" session_id = _identifier.validate(session_id, _identifier.Identifier.SESSION) - return f"{self.prefix}/{SESSION_PREFIX}{session_id}/" + prefix = self.prefix.strip("/") + if prefix: + return f"{prefix}/{SESSION_PREFIX}{session_id}/" + return f"{SESSION_PREFIX}{session_id}/" def _get_agent_path(self, session_id: str, agent_id: str) -> str: """Get agent S3 prefix. diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index c1c89da5b..29bc97ab5 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -89,6 +89,17 @@ def test_init_s3_session_manager_with_existing_user_agent(mocked_aws, s3_bucket) assert "strands-agents" in session_manager.client.meta.config.user_agent_extra +def test_empty_prefix_session_roundtrip(mocked_aws, s3_bucket, sample_session, sample_agent): + """Test that session data can be written and read back with default empty prefix.""" + manager = S3SessionManager(session_id="test", bucket=s3_bucket, prefix="", region_name="us-west-2") + manager.create_session(sample_session) + manager.create_agent(sample_session.session_id, sample_agent) + + result = manager.read_agent(sample_session.session_id, sample_agent.agent_id) + assert result is not None + assert result.agent_id == sample_agent.agent_id + + def test_create_session(s3_manager, sample_session): """Test creating a session in S3.""" result = s3_manager.create_session(sample_session) @@ -369,6 +380,24 @@ def test_update_nonexistent_message(s3_manager, sample_session, sample_agent, sa s3_manager.update_message(sample_session.session_id, sample_agent.agent_id, sample_message) +@pytest.mark.parametrize( + "prefix, expected_path", + [ + ("", "session_test-id/"), + ("sessions", "sessions/session_test-id/"), + ("sessions/", "sessions/session_test-id/"), + ("/sessions", "sessions/session_test-id/"), + ("/sessions/", "sessions/session_test-id/"), + ("a/b/c", 
"a/b/c/session_test-id/"), + ("a/b/c/", "a/b/c/session_test-id/"), + ], +) +def test__get_session_path_prefix_normalization(mocked_aws, s3_bucket, prefix, expected_path): + """Test that _get_session_path normalizes prefix to avoid leading or double slashes.""" + manager = S3SessionManager(session_id="test", bucket=s3_bucket, prefix=prefix, region_name="us-west-2") + assert manager._get_session_path("test-id") == expected_path + + @pytest.mark.parametrize( "session_id", [ From 2e4c82beb968c8e05fe0932316ec39f7c679f588 Mon Sep 17 00:00:00 2001 From: Giulio Leone Date: Mon, 16 Mar 2026 19:34:50 +0100 Subject: [PATCH 183/279] fix(graph): only evaluate outbound edges from completed nodes (#1846) Co-authored-by: giulio-leone --- src/strands/multiagent/graph.py | 11 ++++++-- tests/strands/multiagent/test_graph.py | 35 ++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 966d2a0b3..40a49cf7c 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -827,9 +827,16 @@ async def _handle_node_timeout(self, node: GraphNode, event_queue: asyncio.Queue return timeout_exception def _find_newly_ready_nodes(self, completed_batch: list["GraphNode"]) -> list["GraphNode"]: - """Find nodes that became ready after the last execution.""" + """Find nodes that became ready after the last execution. + + Only evaluates destination nodes of outbound edges from the completed batch, + instead of iterating over all nodes in the graph. 
+ """ + # Collect unique candidate nodes reachable from the completed batch + candidates = {edge.to_node for edge in self.edges if edge.from_node in completed_batch} + newly_ready = [] - for _node_id, node in self.nodes.items(): + for node in candidates: if self._is_node_ready_with_conditions(node, completed_batch): newly_ready.append(node) return newly_ready diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 8158bf4b1..8013021df 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -2405,3 +2405,38 @@ async def stream_async(self, prompt=None, **kwargs): assert result.completed_nodes == 2 assert "custom_node" in result.results assert "regular_node" in result.results + + +def test_find_newly_ready_nodes_only_evaluates_outbound_edges(): + """Verify _find_newly_ready_nodes only checks destinations of outbound edges from completed batch. + + Previously, it iterated over ALL nodes, which could cause nodes to fire + before their actual dependencies completed. 
+ + See: https://github.com/strands-agents/sdk-python/issues/685 + """ + # Build a graph: A -> B -> C, D -> E (independent chain) + node_a = GraphNode(node_id="A", executor=create_mock_agent("A")) + node_b = GraphNode(node_id="B", executor=create_mock_agent("B")) + node_c = GraphNode(node_id="C", executor=create_mock_agent("C")) + node_d = GraphNode(node_id="D", executor=create_mock_agent("D")) + node_e = GraphNode(node_id="E", executor=create_mock_agent("E")) + + graph = Graph.__new__(Graph) + graph.nodes = {"A": node_a, "B": node_b, "C": node_c, "D": node_d, "E": node_e} + graph.edges = [ + GraphEdge(from_node=node_a, to_node=node_b), + GraphEdge(from_node=node_b, to_node=node_c), + GraphEdge(from_node=node_d, to_node=node_e), + ] + graph.state = GraphState() + + # When A completes, only B should be ready (not E) + ready = graph._find_newly_ready_nodes([node_a]) + ready_ids = {n.node_id for n in ready} + assert ready_ids == {"B"}, f"Expected only B, got {ready_ids}" + + # When D completes, only E should be ready (not B or C) + ready = graph._find_newly_ready_nodes([node_d]) + ready_ids = {n.node_id for n in ready} + assert ready_ids == {"E"}, f"Expected only E, got {ready_ids}" From b66534bef268e618b8a90fbcf5253036b547d91d Mon Sep 17 00:00:00 2001 From: Giulio Leone Date: Mon, 16 Mar 2026 20:31:29 +0100 Subject: [PATCH 184/279] fix(openai): always use string content for tool messages (#1878) Co-authored-by: giulio-leone --- src/strands/models/openai.py | 32 +++++-- tests/strands/models/test_openai.py | 126 +++++++++++++++++++++++++++- 2 files changed, 151 insertions(+), 7 deletions(-) diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 2b217ad91..73484e924 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -204,13 +204,33 @@ def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) -> ], ) - formatted_contents = [cls.format_request_message_content(content) for content in contents] - - # If 
single text content, use string format for better model compatibility - if len(formatted_contents) == 1 and formatted_contents[0].get("type") == "text": - content: str | list[dict[str, Any]] = formatted_contents[0]["text"] + # Merge adjacent text blocks while preserving the order of non-text + # (image/document) content. When all content is text, join into a + # single string for broad compatibility with OpenAI-compatible + # endpoints (e.g., Kimi K2.5, vLLM, Ollama). + # See https://github.com/strands-agents/sdk-python/issues/1696 + merged: list[dict[str, Any]] = [] + has_non_text = False + for content_block in contents: + if "text" in content_block: + # Merge with the previous entry if it is also text (adjacent) + if merged and merged[-1].get("type") == "text": + merged[-1]["text"] += "\n" + content_block["text"] + else: + merged.append({"type": "text", "text": content_block["text"]}) + elif "image" in content_block or "document" in content_block: + has_non_text = True + merged.append(cls.format_request_message_content(content_block)) + + content: str | list[dict[str, Any]] + if has_non_text: + # Keep array format when images/documents are present so that + # _split_tool_message_images can extract them into a user message. + content = merged else: - content = formatted_contents + # All text — the loop already merged adjacent blocks with "\n", + # so extract the single resulting entry. 
+ content = merged[0]["text"] if merged else "" return { "role": "tool", diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 241c22b64..747e1123a 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -173,7 +173,7 @@ def test_format_request_tool_message(): tru_result = OpenAIModel.format_request_tool_message(tool_result) exp_result = { - "content": [{"text": "4", "type": "text"}, {"text": '["4"]', "type": "text"}], + "content": '4\n["4"]', "role": "tool", "tool_call_id": "c1", } @@ -197,6 +197,130 @@ def test_format_request_tool_message_single_text_returns_string(): assert tru_result == exp_result +def test_format_request_tool_message_multi_text_returns_joined_string(): + """Test that multi-content text results are joined into a single string. + + Regression test for https://github.com/strands-agents/sdk-python/issues/1696. + OpenAI-compatible endpoints (e.g., Kimi K2.5, vLLM, Ollama) only correctly + parse string content for tool messages; array format causes hallucinated results. 
+ """ + tool_result = { + "content": [ + {"text": "Temperature: 72°F"}, + {"json": {"humidity": 45, "unit": "%"}}, + {"text": "Wind: 5 mph"}, + ], + "status": "success", + "toolUseId": "c1", + } + + tru_result = OpenAIModel.format_request_tool_message(tool_result) + exp_result = { + "content": 'Temperature: 72°F\n{"humidity": 45, "unit": "%"}\nWind: 5 mph', + "role": "tool", + "tool_call_id": "c1", + } + assert tru_result == exp_result + + +def test_format_request_tool_message_mixed_text_image_preserves_order(): + """Test that text and image content blocks preserve their original order.""" + tool_result = { + "content": [ + {"text": "Before image"}, + {"image": {"format": "png", "source": {"bytes": b"PNG"}}}, + {"text": "After image"}, + ], + "status": "success", + "toolUseId": "c1", + } + + tru_result = OpenAIModel.format_request_tool_message(tool_result) + content = tru_result["content"] + # Array format since images are present + assert isinstance(content, list) + assert len(content) == 3 + # Order preserved: text, image, text + assert content[0] == {"type": "text", "text": "Before image"} + assert content[1]["type"] == "image_url" + assert content[2] == {"type": "text", "text": "After image"} + + +def test_format_request_tool_message_merges_adjacent_text(): + """Test that adjacent text blocks are merged while non-text order is preserved.""" + tool_result = { + "content": [ + {"text": "Line 1"}, + {"text": "Line 2"}, + {"image": {"format": "png", "source": {"bytes": b"PNG"}}}, + {"text": "Line 3"}, + ], + "status": "success", + "toolUseId": "c1", + } + + tru_result = OpenAIModel.format_request_tool_message(tool_result) + content = tru_result["content"] + assert isinstance(content, list) + assert len(content) == 3 + # Adjacent text merged, image order preserved + assert content[0] == {"type": "text", "text": "Line 1\nLine 2"} + assert content[1]["type"] == "image_url" + assert content[2] == {"type": "text", "text": "Line 3"} + + +def 
test_format_request_tool_message_image_only(): + """Test tool message with only non-text content.""" + tool_result = { + "content": [ + {"image": {"format": "png", "source": {"bytes": b"PNG"}}}, + ], + "status": "success", + "toolUseId": "c1", + } + + tru_result = OpenAIModel.format_request_tool_message(tool_result) + content = tru_result["content"] + assert isinstance(content, list) + assert len(content) == 1 + assert content[0]["type"] == "image_url" + + +def test_format_request_tool_message_document_mixed(): + """Test tool message with document content mixed with text.""" + tool_result = { + "content": [ + {"text": "Summary"}, + {"document": {"format": "pdf", "name": "report.pdf", "source": {"bytes": b"PDF"}}}, + {"text": "Footer"}, + ], + "status": "success", + "toolUseId": "c1", + } + + tru_result = OpenAIModel.format_request_tool_message(tool_result) + content = tru_result["content"] + assert isinstance(content, list) + assert len(content) == 3 + assert content[0] == {"type": "text", "text": "Summary"} + assert content[1]["type"] == "file" + assert content[2] == {"type": "text", "text": "Footer"} + + +def test_format_request_tool_message_empty_content(): + """Test tool message with empty content list returns empty string.""" + tool_result = { + "content": [], + "status": "success", + "toolUseId": "c1", + } + + tru_result = OpenAIModel.format_request_tool_message(tool_result) + assert tru_result["content"] == "" + assert tru_result["role"] == "tool" + assert tru_result["tool_call_id"] == "c1" + + def test_split_tool_message_images_with_image(): """Test that images are extracted from tool messages.""" tool_message = { From d03311a00ca6ec507e4dc91671278d7ce1f7280b Mon Sep 17 00:00:00 2001 From: BHUKYA VENKATESH Date: Tue, 17 Mar 2026 02:00:05 +0530 Subject: [PATCH 185/279] feat: widen openai dependency to support 2.x for litellm compatibility (#1793) Co-authored-by: BV-Venky --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/pyproject.toml b/pyproject.toml index e07f3bac4..f0719d39e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,15 +46,15 @@ dependencies = [ [project.optional-dependencies] anthropic = ["anthropic>=0.21.0,<1.0.0"] gemini = ["google-genai>=1.32.0,<2.0.0"] -litellm = ["litellm>=1.75.9,<2.0.0", "openai>=1.68.0,<1.110.0"] +litellm = ["litellm>=1.75.9,<2.0.0", "openai>=1.68.0,<3.0.0"] llamaapi = ["llama-api-client>=0.1.0,<1.0.0"] mistral = ["mistralai>=1.8.2"] ollama = ["ollama>=0.4.8,<1.0.0"] -openai = ["openai>=1.68.0,<2.0.0"] +openai = ["openai>=1.68.0,<3.0.0"] writer = ["writer-sdk>=2.2.0,<3.0.0"] sagemaker = [ "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", - "openai>=1.68.0,<2.0.0", # SageMaker uses OpenAI-compatible interface + "openai>=1.68.0,<3.0.0", # SageMaker uses OpenAI-compatible interface ] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ From 566e5ada67d1e96234cea3f41b991c346f4defaf Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Tue, 17 Mar 2026 10:44:17 -0400 Subject: [PATCH 186/279] fix: typeError when serializing multimodal prompts with binary content in Graph/Swarm session persistence (#1870) --- src/strands/multiagent/graph.py | 5 ++-- src/strands/multiagent/swarm.py | 5 ++-- tests/strands/multiagent/test_graph.py | 40 ++++++++++++++++++++++++-- tests/strands/multiagent/test_swarm.py | 40 ++++++++++++++++++++++++-- 4 files changed, 82 insertions(+), 8 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 40a49cf7c..04d158108 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -51,6 +51,7 @@ from ..types.content import ContentBlock, Messages from ..types.event_loop import Metrics, Usage from ..types.multiagent import MultiAgentInput +from ..types.session import decode_bytes_values, encode_bytes_values from ..types.traces import AttributeValue from .base import MultiAgentBase, MultiAgentResult, 
NodeResult, Status @@ -1158,7 +1159,7 @@ def serialize_state(self) -> dict[str, Any]: "interrupted_nodes": [n.node_id for n in self.state.interrupted_nodes], "node_results": {k: v.to_dict() for k, v in (self.state.results or {}).items()}, "next_nodes_to_execute": next_nodes, - "current_task": self.state.task, + "current_task": encode_bytes_values(self.state.task), "execution_order": [n.node_id for n in self.state.execution_order], "_internal_state": { "interrupt_state": self._interrupt_state.to_dict(), @@ -1248,7 +1249,7 @@ def _from_dict(self, payload: dict[str, Any]) -> None: self.state.execution_order = [self.nodes[node_id] for node_id in order_node_ids if node_id in self.nodes] # Task - self.state.task = payload.get("current_task", self.state.task) + self.state.task = decode_bytes_values(payload.get("current_task", self.state.task)) # next nodes to execute next_nodes = [self.nodes[nid] for nid in (payload.get("next_nodes_to_execute") or []) if nid in self.nodes] diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 10e0da515..ed447eb07 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -51,6 +51,7 @@ from ..types.content import ContentBlock, Messages from ..types.event_loop import Metrics, Usage from ..types.multiagent import MultiAgentInput +from ..types.session import decode_bytes_values, encode_bytes_values from ..types.traces import AttributeValue from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status @@ -965,7 +966,7 @@ def serialize_state(self) -> dict[str, Any]: "node_history": [n.node_id for n in self.state.node_history], "node_results": {k: v.to_dict() for k, v in self.state.results.items()}, "next_nodes_to_execute": next_nodes, - "current_task": self.state.task, + "current_task": encode_bytes_values(self.state.task), "context": { "shared_context": getattr(self.state.shared_context, "context", {}) or {}, "handoff_node": self.state.handoff_node.node_id if 
self.state.handoff_node else None, @@ -1028,7 +1029,7 @@ def _from_dict(self, payload: dict[str, Any]) -> None: logger.exception("Failed to hydrate NodeResult for node_id=%s; skipping.", node_id) raise self.state.results = results - self.state.task = payload.get("current_task", self.state.task) + self.state.task = decode_bytes_values(payload.get("current_task", self.state.task)) next_node_ids = payload.get("next_nodes_to_execute") or [] if next_node_ids: diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 8013021df..e978701cd 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -1986,7 +1986,10 @@ async def stream_without_result(*args, **kwargs): @pytest.mark.asyncio async def test_graph_persisted(mock_strands_tracer, mock_use_span): - """Test graph persistence functionality.""" + """Test graph persistence functionality with multimodal input containing binary bytes.""" + import base64 + import json + # Create mock session manager session_manager = Mock(spec=FileSessionManager) session_manager.read_multi_agent().return_value = None @@ -2011,7 +2014,40 @@ async def test_graph_persisted(mock_strands_tracer, mock_use_span): assert "completed_nodes" in state assert "node_results" in state - # Test apply_state_from_dict with persisted state + # Build a multimodal prompt with inline binary PDF bytes (the problematic case) + pdf_bytes = b"%PDF-1.4 binary content" + multimodal_task = [ + {"text": "Analyze this PDF"}, + { + "document": { + "format": "pdf", + "name": "document.pdf", + "source": { + "bytes": pdf_bytes, + }, + } + }, + ] + + # Simulate graph having executed with a multimodal task + graph.state.task = multimodal_task + + # serialize_state must not raise TypeError for bytes + serialized = graph.serialize_state() + assert json.dumps(serialized) # must be JSON-serializable + + # The bytes should be encoded in the serialized form + encoded_bytes = 
serialized["current_task"][1]["document"]["source"]["bytes"] + assert encoded_bytes == {"__bytes_encoded__": True, "data": base64.b64encode(pdf_bytes).decode()} + + # deserialize_state must restore bytes back to original + serialized["next_nodes_to_execute"] = ["test_node"] + serialized["status"] = "executing" + graph.deserialize_state(serialized) + restored_bytes = graph.state.task[1]["document"]["source"]["bytes"] + assert restored_bytes == pdf_bytes + + # Test apply_state_from_dict with plain string persisted state (backward compat) persisted_state = { "status": "executing", "completed_nodes": [], diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 491adc7c3..43acd6400 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -1106,7 +1106,10 @@ async def failing_execute_swarm(*args, **kwargs): @pytest.mark.asyncio async def test_swarm_persistence(mock_strands_tracer, mock_use_span): - """Test swarm persistence functionality.""" + """Test swarm persistence functionality with multimodal input containing binary bytes.""" + import base64 + import json + # Create mock session manager session_manager = Mock(spec=FileSessionManager) session_manager.read_multi_agent.return_value = None @@ -1127,7 +1130,40 @@ async def test_swarm_persistence(mock_strands_tracer, mock_use_span): assert "node_results" in state assert "context" in state - # Test apply_state_from_dict with persisted state + # Build a multimodal prompt with inline binary PDF bytes (the problematic case) + pdf_bytes = b"%PDF-1.4 binary content" + multimodal_task = [ + {"text": "Analyze this PDF"}, + { + "document": { + "format": "pdf", + "name": "document.pdf", + "source": { + "bytes": pdf_bytes, + }, + } + }, + ] + + # Simulate swarm having executed with a multimodal task + swarm.state.task = multimodal_task + + # serialize_state must not raise TypeError for bytes + serialized = swarm.serialize_state() + assert 
json.dumps(serialized) # must be JSON-serializable + + # The bytes should be encoded in the serialized form + encoded_bytes = serialized["current_task"][1]["document"]["source"]["bytes"] + assert encoded_bytes == {"__bytes_encoded__": True, "data": base64.b64encode(pdf_bytes).decode()} + + # deserialize_state must restore bytes back to original + serialized["next_nodes_to_execute"] = ["test_agent"] + serialized["status"] = "executing" + swarm.deserialize_state(serialized) + restored_bytes = swarm.state.task[1]["document"]["source"]["bytes"] + assert restored_bytes == pdf_bytes + + # Test apply_state_from_dict with plain string persisted state (backward compat) persisted_state = { "status": "executing", "node_history": [], From 83ff4e011a802eb302087b2c01b5bfeb11475d66 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Wed, 18 Mar 2026 11:47:58 -0400 Subject: [PATCH 187/279] fix: lowercase the python language in code snippet (#1929) Co-authored-by: Mackenzie Zastrow --- src/strands/types/interrupt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/types/interrupt.py b/src/strands/types/interrupt.py index d67148c5a..f76689762 100644 --- a/src/strands/types/interrupt.py +++ b/src/strands/types/interrupt.py @@ -16,7 +16,7 @@ ``` Example: - ```Python + ```python from typing import Any from strands import Agent, tool From 1643a624a6779d5c154bbd1092efe39494b29b51 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Wed, 18 Mar 2026 16:04:58 -0400 Subject: [PATCH 188/279] fix: openai responses api error handling (#1931) --- src/strands/models/openai_responses.py | 8 ++++---- tests/strands/models/test_openai_responses.py | 20 +++++++++++++++++++ tests_integ/models/test_model_openai.py | 7 +++---- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index 71d3f7ef7..bc2dcfd0e 100644 ---
a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -294,14 +294,14 @@ async def stream( if hasattr(event, "response") and hasattr(event.response, "usage"): final_usage = event.response.usage break - except openai.BadRequestError as e: + except openai.APIError as e: if hasattr(e, "code") and e.code == "context_length_exceeded": logger.warning(_CONTEXT_WINDOW_OVERFLOW_MSG) raise ContextWindowOverflowException(str(e)) from e + if isinstance(e, openai.RateLimitError): + logger.warning(_RATE_LIMIT_MSG) + raise ModelThrottledException(str(e)) from e raise - except openai.RateLimitError as e: - logger.warning(_RATE_LIMIT_MSG) - raise ModelThrottledException(str(e)) from e # Close current content block if we had any if data_type: diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py index 9c84f4ed4..545f128bf 100644 --- a/tests/strands/models/test_openai_responses.py +++ b/tests/strands/models/test_openai_responses.py @@ -653,6 +653,26 @@ async def test_stream_context_overflow_exception(openai_client, model, messages) assert exc_info.value.__cause__ == mock_error +@pytest.mark.asyncio +async def test_stream_context_overflow_exception_api_error_type(openai_client, model, messages): + """Test that OpenAI context overflow errors are properly converted to ContextWindowOverflowException.""" + mock_error = openai.APIError( + message="This model's maximum context length is 4096 tokens.", + request=unittest.mock.MagicMock(), + body={"error": {"code": "context_length_exceeded"}}, + ) + mock_error.code = "context_length_exceeded" + + openai_client.responses.create.side_effect = mock_error + + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "maximum context length" in str(exc_info.value) + assert exc_info.value.__cause__ == mock_error + + @pytest.mark.asyncio async def test_stream_rate_limit_as_throttle(openai_client, 
model, messages): """Test that rate limit errors are converted to ModelThrottledException.""" diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py index bccf2d82b..6b0b3a95b 100644 --- a/tests_integ/models/test_model_openai.py +++ b/tests_integ/models/test_model_openai.py @@ -225,16 +225,15 @@ def _rate_limit_params(): return params -@pytest.mark.parametrize("model_class,model_id", _rate_limit_params()) -def test_rate_limit_throttling_integration_no_retries(model_class, model_id): +def test_rate_limit_throttling_integration_no_retries(): """Integration test for rate limit handling with retries disabled. This test verifies that when a request exceeds OpenAI's rate limits, the model properly raises a ModelThrottledException. We disable retries to avoid waiting for the exponential backoff during testing. """ - model = model_class( - model_id=model_id, + model = OpenAIModel( + model_id="gpt-4o", client_args={ "api_key": os.getenv("OPENAI_API_KEY"), }, From adfeb97cc8528249d75e51245d95bfcbfd71cd00 Mon Sep 17 00:00:00 2001 From: stephentreacy <77759878+stephentreacy@users.noreply.github.com> Date: Thu, 19 Mar 2026 14:51:35 +0000 Subject: [PATCH 189/279] fix(event-loop): ensure all cycle metrics include end time and duration (#1903) Co-authored-by: Stephen Treacy --- src/strands/event_loop/event_loop.py | 3 +- tests/strands/event_loop/test_event_loop.py | 48 +++++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 3b1e2d76a..2e8e4a660 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -560,8 +560,9 @@ async def _handle_tool_execution( if cycle_span: tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message) + agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) + if 
invocation_state["request_state"].get("stop_event_loop", False) or structured_output_context.stop_loop: - agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) yield EventLoopStopEvent( stop_reason, message, diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 0cabeaeee..cedca269b 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -1084,3 +1084,51 @@ async def test_invalid_tool_names_adds_tool_uses(agent, model, alist): ], "role": "user", } + + +@pytest.mark.asyncio +async def test_event_loop_metrics_recorded_before_recursion( + agent, + model, + tool, + agenerator, + alist, +): + model.stream.side_effect = [ + agenerator( + [ + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "t1", + "name": tool.tool_spec["name"], + }, + }, + }, + }, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + ), + agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "test text"}}}, + {"contentBlockStop": {}}, + ] + ), + ] + + with unittest.mock.patch.object(agent.event_loop_metrics, "end_cycle") as mock_end_cycle: + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={"request_state": {}}, + ) + events = await alist(stream) + + # Verify end_cycle was called once for tool cycle, once for text cycle + assert mock_end_cycle.call_count == 2 + + # Verify the event loop completed successfully + tru_stop_reason, _, _, _, _, _ = events[-1]["stop"] + assert tru_stop_reason == "end_turn" From ae283971070fd78feab60e79bec019c1cadd5045 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Thu, 19 Mar 2026 17:46:50 -0400 Subject: [PATCH 190/279] fix: pin mistralai upper bound (#1935) --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f0719d39e..59e24dac3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ 
-48,7 +48,7 @@ anthropic = ["anthropic>=0.21.0,<1.0.0"] gemini = ["google-genai>=1.32.0,<2.0.0"] litellm = ["litellm>=1.75.9,<2.0.0", "openai>=1.68.0,<3.0.0"] llamaapi = ["llama-api-client>=0.1.0,<1.0.0"] -mistral = ["mistralai>=1.8.2"] +mistral = ["mistralai>=1.8.2,<2.0.0"] ollama = ["ollama>=0.4.8,<1.0.0"] openai = ["openai>=1.68.0,<3.0.0"] writer = ["writer-sdk>=2.2.0,<3.0.0"] @@ -73,8 +73,8 @@ a2a = [ ] bidi = [ - "aws_sdk_bedrock_runtime; python_version>='3.12'", - "smithy-aws-core>=0.0.1; python_version>='3.12'", + "aws_sdk_bedrock_runtime>=0.4.0,<1.0.0; python_version>='3.12'", + "smithy-aws-core>=0.4.0,<1.0.0; python_version>='3.12'", ] bidi-io = [ "prompt_toolkit>=3.0.0,<4.0.0", From 38c1ab6325423b3c8b2d3e6a2896deceea084dd4 Mon Sep 17 00:00:00 2001 From: atian8179 Date: Fri, 20 Mar 2026 21:41:48 +0800 Subject: [PATCH 191/279] fix: override end_turn stop reason when streaming response contains toolUse blocks (#1827) Co-authored-by: atian8179 Co-authored-by: Nicholas Clegg --- src/strands/event_loop/streaming.py | 21 +++++- src/strands/models/bedrock.py | 31 +------- tests/strands/event_loop/test_streaming.py | 85 +++++++++++++++++++++- tests/strands/models/test_bedrock.py | 47 ------------ 4 files changed, 104 insertions(+), 80 deletions(-) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index b7d85ca30..ee45420fe 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -324,16 +324,31 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: return state -def handle_message_stop(event: MessageStopEvent) -> StopReason: +def handle_message_stop(event: MessageStopEvent, content: list[dict[str, Any]]) -> StopReason: """Handles the end of a message by returning the stop reason. + Some models return "end_turn" even when tool calls are present, which prevents the event loop from processing + those tool calls. 
This function overrides to "tool_use" so tool execution proceeds correctly. + Args: event: Stop event. + content: The message content blocks accumulated during streaming. Returns: The reason for stopping the stream. """ - return event["stopReason"] + stop_reason = event["stopReason"] + + if stop_reason == "end_turn" and any("toolUse" in item for item in content): + logger.warning( + "original_stop_reason=<%s>, new_stop_reason=<%s> | " + "overriding stop reason due to toolUse blocks in response", + "end_turn", + "tool_use", + ) + stop_reason = "tool_use" + + return stop_reason def handle_redact_content(event: RedactContentEvent, state: dict[str, Any]) -> None: @@ -427,7 +442,7 @@ async def process_stream( elif "contentBlockStop" in chunk: state = handle_content_block_stop(state) elif "messageStop" in chunk: - stop_reason = handle_message_stop(chunk["messageStop"]) + stop_reason = handle_message_stop(chunk["messageStop"], state["message"].get("content", [])) elif "metadata" in chunk: time_to_first_byte_ms = ( int(1000 * (first_byte_time - start_time)) if (start_time and first_byte_time) else None diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index bab4031ed..5de34a6c2 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -823,8 +823,6 @@ def _stream( logger.debug("got response from model") if streaming: response = self.client.converse_stream(**request) - # Track tool use events to fix stopReason for streaming responses - has_tool_use = False for chunk in response["stream"]: if ( "metadata" in chunk @@ -836,24 +834,7 @@ def _stream( for event in self._generate_redaction_events(): callback(event) - # Track if we see tool use events - if "contentBlockStart" in chunk and chunk["contentBlockStart"].get("start", {}).get("toolUse"): - has_tool_use = True - - # Fix stopReason for streaming responses that contain tool use - if ( - has_tool_use - and "messageStop" in chunk - and (message_stop := 
chunk["messageStop"]).get("stopReason") == "end_turn" - ): - # Create corrected chunk with tool_use stopReason - modified_chunk = chunk.copy() - modified_chunk["messageStop"] = message_stop.copy() - modified_chunk["messageStop"]["stopReason"] = "tool_use" - logger.warning("Override stop reason from end_turn to tool_use") - callback(modified_chunk) - else: - callback(chunk) + callback(chunk) else: response = self.client.converse(**request) @@ -992,17 +973,9 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera yield {"contentBlockStop": {}} # Yield messageStop event - # Fix stopReason for models that return end_turn when they should return tool_use on non-streaming side - current_stop_reason = response["stopReason"] - if current_stop_reason == "end_turn": - message_content = response["output"]["message"]["content"] - if any("toolUse" in content for content in message_content): - current_stop_reason = "tool_use" - logger.warning("Override stop reason from end_turn to tool_use") - yield { "messageStop": { - "stopReason": current_stop_reason, + "stopReason": response["stopReason"], "additionalModelResponseFields": response.get("additionalModelResponseFields"), } } diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 6d376450a..bfaf796d2 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -530,12 +530,30 @@ def test_handle_content_block_stop(state, exp_updated_state): def test_handle_message_stop(): event: MessageStopEvent = {"stopReason": "end_turn"} - tru_reason = strands.event_loop.streaming.handle_message_stop(event) + tru_reason = strands.event_loop.streaming.handle_message_stop(event, []) exp_reason = "end_turn" assert tru_reason == exp_reason +def test_handle_message_stop_overrides_end_turn_when_tool_use_present(): + event: MessageStopEvent = {"stopReason": "end_turn"} + content = [{"toolUse": {"toolUseId": "t1", "name": 
"myTool", "input": {}}}] + + tru_reason = strands.event_loop.streaming.handle_message_stop(event, content) + + assert tru_reason == "tool_use" + + +def test_handle_message_stop_keeps_tool_use_unchanged(): + event: MessageStopEvent = {"stopReason": "tool_use"} + content = [{"toolUse": {"toolUseId": "t1", "name": "myTool", "input": {}}}] + + tru_reason = strands.event_loop.streaming.handle_message_stop(event, content) + + assert tru_reason == "tool_use" + + def test_extract_usage_metrics(): event = { "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, @@ -1334,3 +1352,68 @@ async def test_stream_messages_normalizes_messages(agenerator, alist): {"content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}], "role": "assistant"}, {"content": [{"toolUse": {"name": "INVALID_TOOL_NAME"}}], "role": "assistant"}, ] + + +@pytest.mark.asyncio +async def test_process_stream_overrides_end_turn_when_tool_use_present(agenerator, alist): + response = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"contentBlockIndex": 0, "start": {"toolUse": {"toolUseId": "t1", "name": "myTool"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"key": "val"}'}}, "contentBlockIndex": 0}}, + {"contentBlockStop": {"contentBlockIndex": 0}}, + {"messageStop": {"stopReason": "end_turn"}}, + { + "metadata": { + "usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, + "metrics": {"latencyMs": 100}, + } + }, + ] + + stream = strands.event_loop.streaming.process_stream(agenerator(response)) + last_event = cast(ModelStopReason, (await alist(stream))[-1]) + + assert last_event["stop"][0] == "tool_use" + + +@pytest.mark.asyncio +async def test_process_stream_keeps_end_turn_when_no_tool_use(agenerator, alist): + response = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "Hello!"}, "contentBlockIndex": 0}}, + {"contentBlockStop": {"contentBlockIndex": 0}}, + {"messageStop": {"stopReason": "end_turn"}}, + { + 
"metadata": { + "usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, + "metrics": {"latencyMs": 100}, + } + }, + ] + + stream = strands.event_loop.streaming.process_stream(agenerator(response)) + last_event = cast(ModelStopReason, (await alist(stream))[-1]) + + assert last_event["stop"][0] == "end_turn" + + +@pytest.mark.asyncio +async def test_process_stream_keeps_tool_use_stop_reason_unchanged(agenerator, alist): + response = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"contentBlockIndex": 0, "start": {"toolUse": {"toolUseId": "t1", "name": "myTool"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": "{}"}}, "contentBlockIndex": 0}}, + {"contentBlockStop": {"contentBlockIndex": 0}}, + {"messageStop": {"stopReason": "tool_use"}}, + { + "metadata": { + "usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, + "metrics": {"latencyMs": 100}, + } + }, + ] + + stream = strands.event_loop.streaming.process_stream(agenerator(response)) + last_event = cast(ModelStopReason, (await alist(stream))[-1]) + + assert last_event["stop"][0] == "tool_use" diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 89c4df70d..5f81efd24 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1565,53 +1565,6 @@ async def test_stream_logging(bedrock_client, model, messages, caplog, alist): assert "finished streaming response from model" in log_text -@pytest.mark.asyncio -async def test_stream_stop_reason_override_streaming(bedrock_client, model, messages, alist): - """Test that stopReason is overridden from end_turn to tool_use in streaming mode when tool use is detected.""" - bedrock_client.converse_stream.return_value = { - "stream": [ - {"messageStart": {"role": "assistant"}}, - {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test_tool"}}}}, - {"contentBlockDelta": {"delta": {"test": {"input": '{"param": "value"}'}}}}, 
- {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "end_turn"}}, - ] - } - - response = model.stream(messages) - events = await alist(response) - - # Find the messageStop event - message_stop_event = next(event for event in events if "messageStop" in event) - - # Verify stopReason was overridden to tool_use - assert message_stop_event["messageStop"]["stopReason"] == "tool_use" - - -@pytest.mark.asyncio -async def test_stream_stop_reason_override_non_streaming(bedrock_client, alist, messages): - """Test that stopReason is overridden from end_turn to tool_use in non-streaming mode when tool use is detected.""" - bedrock_client.converse.return_value = { - "output": { - "message": { - "role": "assistant", - "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"param": "value"}}}], - } - }, - "stopReason": "end_turn", - } - - model = BedrockModel(model_id="test-model", streaming=False) - response = model.stream(messages) - events = await alist(response) - - # Find the messageStop event - message_stop_event = next(event for event in events if "messageStop" in event) - - # Verify stopReason was overridden to tool_use - assert message_stop_event["messageStop"]["stopReason"] == "tool_use" - - def test_format_request_cleans_tool_result_content_blocks(model, model_id): messages = [ { From 80fdd94b258e6b1335d0c3b33b7959667c6ba4ad Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Fri, 20 Mar 2026 11:38:18 -0400 Subject: [PATCH 192/279] fix: summarization conversation manager sometimes returns empty response (#1947) --- src/strands/tools/_tool_helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/strands/tools/_tool_helpers.py b/src/strands/tools/_tool_helpers.py index d023caeec..3b62337d3 100644 --- a/src/strands/tools/_tool_helpers.py +++ b/src/strands/tools/_tool_helpers.py @@ -6,14 +6,14 @@ # https://github.com/strands-agents/sdk-python/issues/998 @tool(name="noop", description="This is a fake tool that MUST be 
completely ignored.") -def noop_tool() -> None: +def noop_tool() -> str: """No-op tool to satisfy tool spec requirement when tool messages are present. Some model providers (e.g., Bedrock) will return an error response if tool uses and tool results are present in messages without any tool specs configured. Consequently, if the summarization agent has no registered tools, summarization will fail. As a workaround, we register the no-op tool. """ - pass + return "You MUST NOT use this tool. Respond DIRECTLY to the user." def generate_missing_tool_result_content(tool_use_ids: list[str]) -> list[ContentBlock]: From fd8168a531c140a0082a3c6412a577fe81db21f0 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Fri, 20 Mar 2026 11:51:59 -0400 Subject: [PATCH 193/279] fix: remove agent from swarm test to get more consistency out of it (#1946) --- tests_integ/test_multiagent_swarm.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py index a244bf753..8ccfa5c89 100644 --- a/tests_integ/test_multiagent_swarm.py +++ b/tests_integ/test_multiagent_swarm.py @@ -88,11 +88,11 @@ def __init__(self): self.should_exit = True def register_hooks(self, registry): - registry.add_callback(BeforeNodeCallEvent, self.exit_before_analyst) + registry.add_callback(BeforeNodeCallEvent, self.exit_before_writer) - def exit_before_analyst(self, event): - if event.node_id == "analyst" and self.should_exit: - raise SystemExit("Controlled exit before analyst") + def exit_before_writer(self, event): + if event.node_id == "writer" and self.should_exit: + raise SystemExit("Controlled exit before writer") return ExitHook() @@ -365,32 +365,30 @@ def test_swarm_resume_from_executing_state(tmpdir, exit_hook, verify_hook): # First execution - exit before second node session_manager = FileSessionManager(session_id=session_id, storage_dir=tmpdir) researcher = Agent(name="researcher", 
system_prompt="you are a researcher.") - analyst = Agent(name="analyst", system_prompt="you are an analyst.") writer = Agent(name="writer", system_prompt="you are a writer.") - swarm = Swarm([researcher, analyst, writer], session_manager=session_manager, hooks=[exit_hook]) + swarm = Swarm([researcher, writer], session_manager=session_manager, hooks=[exit_hook]) try: - swarm("write AI trends and calculate growth in 100 words") + swarm("write AI trends in 100 words") except SystemExit as e: - assert "Controlled exit before analyst" in str(e) + assert "Controlled exit before writer" in str(e) # Verify state was persisted with EXECUTING status and next node persisted_state = session_manager.read_multi_agent(session_id, swarm.id) assert persisted_state["status"] == "executing" assert len(persisted_state["node_history"]) == 1 assert persisted_state["node_history"][0] == "researcher" - assert persisted_state["next_nodes_to_execute"] == ["analyst"] + assert persisted_state["next_nodes_to_execute"] == ["writer"] exit_hook.should_exit = False researcher2 = Agent(name="researcher", system_prompt="you are a researcher.") - analyst2 = Agent(name="analyst", system_prompt="you are an analyst.") writer2 = Agent(name="writer", system_prompt="you are a writer.") - new_swarm = Swarm([researcher2, analyst2, writer2], session_manager=session_manager, hooks=[verify_hook]) - result = new_swarm("write AI trends and calculate growth in 100 words") + new_swarm = Swarm([researcher2, writer2], session_manager=session_manager, hooks=[verify_hook]) + result = new_swarm("write AI trends in 100 words") - # Verify swarm behavior - should resume from analyst, not restart + # Verify swarm behavior - should resume from writer, not restart assert result.status.value == "completed" - assert verify_hook.first_node == "analyst" + assert verify_hook.first_node == "writer" node_ids = [n.node_id for n in result.node_history] - assert "analyst" in node_ids + assert "writer" in node_ids From 
0a723bcbc3be162cc3cbd0094f905d7839c4032e Mon Sep 17 00:00:00 2001 From: Uday Mehta Date: Tue, 24 Mar 2026 22:44:05 +0530 Subject: [PATCH 194/279] fix: CRITICAL: Hard pin `litellm<=1.82.6` to mitigate supply chain attack (#1961) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 59e24dac3..e1ab0d7d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ dependencies = [ [project.optional-dependencies] anthropic = ["anthropic>=0.21.0,<1.0.0"] gemini = ["google-genai>=1.32.0,<2.0.0"] -litellm = ["litellm>=1.75.9,<2.0.0", "openai>=1.68.0,<3.0.0"] +litellm = ["litellm>=1.75.9,<=1.82.6", "openai>=1.68.0,<3.0.0"] llamaapi = ["llama-api-client>=0.1.0,<1.0.0"] mistral = ["mistralai>=1.8.2,<2.0.0"] ollama = ["ollama>=0.4.8,<1.0.0"] From a1101498d933731087f7a39a41ce583a5d4ba640 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Wed, 25 Mar 2026 10:44:31 -0400 Subject: [PATCH 195/279] chore: remove Cohere from required integ test providers (#1967) Co-authored-by: Mackenzie Zastrow --- tests_integ/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests_integ/conftest.py b/tests_integ/conftest.py index dbe25d685..347b22a43 100644 --- a/tests_integ/conftest.py +++ b/tests_integ/conftest.py @@ -202,7 +202,6 @@ def _load_api_keys_from_secrets_manager(): required_providers = { "ANTHROPIC_API_KEY", - "COHERE_API_KEY", "MISTRAL_API_KEY", "OPENAI_API_KEY", "WRITER_API_KEY", From 6a35add1f77bc83708907791ae551bc40170c86d Mon Sep 17 00:00:00 2001 From: notowen333 <51685858+notowen333@users.noreply.github.com> Date: Thu, 26 Mar 2026 16:46:47 -0400 Subject: [PATCH 196/279] feat: add AgentAsTool (#1932) Co-authored-by: Owen Kaplan --- src/strands/agent/_agent_as_tool.py | 296 ++++++++ src/strands/agent/agent.py | 36 + src/strands/tools/executors/_executor.py | 6 + src/strands/types/_events.py | 26 + tests/strands/agent/test_agent.py | 50 ++ 
tests/strands/agent/test_agent_as_tool.py | 676 ++++++++++++++++++ .../strands/tools/executors/test_executor.py | 51 ++ tests/strands/types/test__events.py | 37 + tests_integ/test_agent_as_tool.py | 36 + 9 files changed, 1214 insertions(+) create mode 100644 src/strands/agent/_agent_as_tool.py create mode 100644 tests/strands/agent/test_agent_as_tool.py create mode 100644 tests_integ/test_agent_as_tool.py diff --git a/src/strands/agent/_agent_as_tool.py b/src/strands/agent/_agent_as_tool.py new file mode 100644 index 000000000..11b536789 --- /dev/null +++ b/src/strands/agent/_agent_as_tool.py @@ -0,0 +1,296 @@ +"""Agent-as-tool adapter. + +This module provides the _AgentAsTool class that wraps an Agent as a tool +so it can be passed to another agent's tool list. +""" + +from __future__ import annotations + +import copy +import logging +import threading +from typing import TYPE_CHECKING, Any + +from typing_extensions import override + +from ..agent.state import AgentState +from ..types._events import AgentAsToolStreamEvent, ToolInterruptEvent, ToolResultEvent +from ..types.content import Messages +from ..types.interrupt import InterruptResponseContent +from ..types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse + +if TYPE_CHECKING: + from .agent import Agent + +logger = logging.getLogger(__name__) + + +class _AgentAsTool(AgentTool): + """Adapter that exposes an Agent as a tool for use by other agents. + + The tool accepts a single ``input`` string parameter, invokes the wrapped + agent, and returns the text response. 
+ + Example: + ```python + from strands import Agent + + researcher = Agent(name="researcher", description="Finds information") + + # Use via convenience method (default: fresh conversation each call) + tool = researcher.as_tool() + + # Preserve context across invocations + tool = researcher.as_tool(preserve_context=True) + + writer = Agent(name="writer", tools=[tool]) + writer("Write about AI agents") + ``` + """ + + def __init__( + self, + agent: Agent, + *, + name: str, + description: str | None = None, + preserve_context: bool = False, + ) -> None: + r"""Initialize the agent-as-tool adapter. + + Args: + agent: The agent to wrap as a tool. + name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``. + description: Tool description. Defaults to the agent's description, or a + generic description if the agent has no description set. + preserve_context: Whether to preserve the agent's conversation history across + invocations. When False, the agent's messages and state are reset to the + values they had at construction time before each call, ensuring every + invocation starts from the same baseline regardless of any external + interactions with the agent. Defaults to False. + """ + super().__init__() + self._agent = agent + self._tool_name = name + self._description = ( + description or agent.description or f"Use the {name} agent as a tool by providing a natural language input" + ) + self._preserve_context = preserve_context + + # When preserve_context=False, we snapshot the agent's initial state so we can + # restore it before each invocation. This mirrors GraphNode.reset_executor_state(). + self._initial_messages: Messages = [] + self._initial_state: AgentState = AgentState() + # Serialize access so _reset_agent_state + stream_async are atomic. + # threading.Lock (not asyncio.Lock) because run_async() may create + # separate event loops in different threads. 
+ self._lock = threading.Lock() + + if not preserve_context: + if getattr(agent, "_session_manager", None) is not None: + raise ValueError( + "preserve_context=False cannot be used with an agent that has a session manager. " + "The session manager persists conversation history externally, which conflicts with " + "resetting the agent's state between invocations." + ) + self._initial_messages = copy.deepcopy(agent.messages) + self._initial_state = AgentState(agent.state.get()) + + @property + def agent(self) -> Agent: + """The wrapped agent instance.""" + return self._agent + + @property + def tool_name(self) -> str: + """Get the tool name.""" + return self._tool_name + + @property + def tool_spec(self) -> ToolSpec: + """Get the tool specification.""" + return { + "name": self._tool_name, + "description": self._description, + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "The input to send to the agent tool.", + }, + }, + "required": ["input"], + } + }, + } + + @property + def tool_type(self) -> str: + """Get the tool type.""" + return "agent" + + @override + async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator: + """Invoke the wrapped agent via streaming and yield events. + + Intermediate agent events are wrapped in AgentAsToolStreamEvent so the caller + can distinguish sub-agent progress from regular tool events. The final + AgentResult is yielded as a ToolResultEvent. + + When the sub-agent encounters a hook interrupt (e.g. from BeforeToolCallEvent), + the interrupts are propagated to the parent agent via ToolInterruptEvent. On + resume, interrupt responses are forwarded to the sub-agent automatically. + + Args: + tool_use: The tool use request containing the input parameter. + invocation_state: Context for the tool invocation. + **kwargs: Additional keyword arguments. 
+ + Yields: + AgentAsToolStreamEvent for intermediate events, ToolInterruptEvent if the + sub-agent is interrupted, or ToolResultEvent with the final response. + """ + tool_input = tool_use["input"] + if isinstance(tool_input, dict): + prompt = tool_input.get("input", "") + elif isinstance(tool_input, str): + prompt = tool_input + else: + logger.warning("tool_name=<%s> | unexpected input type: %s", self._tool_name, type(tool_input)) + prompt = str(tool_input) + + tool_use_id = tool_use["toolUseId"] + + # Serialize access to the underlying agent. _reset_agent_state() mutates + # the agent before stream_async acquires its own lock, so a concurrent + # call would corrupt an in-flight invocation. + if not self._lock.acquire(blocking=False): + logger.warning( + "tool_name=<%s>, tool_use_id=<%s> | agent is already processing a request", + self._tool_name, + tool_use_id, + ) + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Agent '{self._tool_name}' is already processing a request"}], + } + ) + return + + try: + # Determine if we are resuming the sub-agent from an interrupt. + if self._is_sub_agent_interrupted(): + prompt = self._build_interrupt_responses() + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | resuming sub-agent from interrupt", + self._tool_name, + tool_use_id, + ) + elif not self._preserve_context: + self._reset_agent_state(tool_use_id) + + logger.debug("tool_name=<%s>, tool_use_id=<%s> | invoking agent", self._tool_name, tool_use_id) + + result = None + async for event in self._agent.stream_async(prompt): + if "result" in event: + result = event["result"] + else: + yield AgentAsToolStreamEvent(tool_use, event, self) + + if result is None: + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": "Agent did not produce a result"}], + } + ) + return + + # Propagate sub-agent interrupts to the parent agent. 
+ if result.stop_reason == "interrupt" and result.interrupts: + yield ToolInterruptEvent(tool_use, list(result.interrupts)) + return + + if result.structured_output: + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"json": result.structured_output.model_dump()}], + } + ) + else: + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"text": str(result)}], + } + ) + + except Exception as e: + logger.warning( + "tool_name=<%s>, tool_use_id=<%s> | agent invocation failed: %s", + self._tool_name, + tool_use_id, + e, + ) + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Agent error: {e}"}], + } + ) + finally: + self._lock.release() + + def _reset_agent_state(self, tool_use_id: str) -> None: + """Reset the wrapped agent to its initial state. + + Restores messages and state to the values captured at construction time. + This mirrors the pattern used by ``GraphNode.reset_executor_state()``. + + Args: + tool_use_id: Tool use ID for logging context. + """ + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | resetting agent to initial state", + self._tool_name, + tool_use_id, + ) + self._agent.messages = copy.deepcopy(self._initial_messages) + self._agent.state = AgentState(self._initial_state.get()) + + def _is_sub_agent_interrupted(self) -> bool: + """Check whether the wrapped agent is in an activated interrupt state.""" + return self._agent._interrupt_state.activated + + def _build_interrupt_responses(self) -> list[InterruptResponseContent]: + """Build interrupt response payloads from the sub-agent's interrupt state. + + The parent agent's ``_interrupt_state.resume()`` sets ``.response`` on the shared + ``Interrupt`` objects (registered by the executor), so we re-package them in the + format expected by ``Agent.stream_async``. + + Returns: + List of interrupt response content blocks for resuming the sub-agent. 
+ """ + return [ + {"interruptResponse": {"interruptId": interrupt.id, "response": interrupt.response}} + for interrupt in self._agent._interrupt_state.interrupts.values() + if interrupt.response is not None + ] + + @override + def get_display_properties(self) -> dict[str, str]: + """Get properties for UI display.""" + properties = super().get_display_properties() + properties["Agent"] = getattr(self._agent, "name", "unknown") + return properties diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index f378a886a..6adecce31 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -61,7 +61,9 @@ from ..types.agent import AgentInput, ConcurrentInvocationMode from ..types.content import ContentBlock, Message, Messages, SystemContentBlock from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException +from ..types.tools import AgentTool from ..types.traces import AttributeValue +from ._agent_as_tool import _AgentAsTool from .agent_result import AgentResult from .base import AgentBase from .conversation_manager import ( @@ -612,6 +614,40 @@ async def structured_output_async(self, output_model: type[T], prompt: AgentInpu finally: await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self, invocation_state={})) + def as_tool( + self, + *, + name: str | None = None, + description: str | None = None, + preserve_context: bool = False, + ) -> AgentTool: + r"""Convert this agent into a tool for use by another agent. + + Args: + name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``. + Defaults to the agent's name. + description: Tool description. Defaults to the agent's description, or a + generic description if the agent has no description set. + preserve_context: Whether to preserve the agent's conversation history across + invocations. 
When False, the agent's messages and state are reset to the + values they had at construction time before each call, ensuring every + invocation starts from the same baseline regardless of any external + interactions with the agent. Defaults to False. + + Returns: + A tool wrapping this agent. + + Example: + ```python + researcher = Agent(name="researcher", description="Finds information") + writer = Agent(name="writer", tools=[researcher.as_tool()]) + writer("Write about AI agents") + ``` + """ + if not name: + name = self.name + return _AgentAsTool(self, name=name, description=description, preserve_context=preserve_context) + def cleanup(self) -> None: """Clean up resources used by the agent. diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 0da6b5715..5825b3cdb 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -226,6 +226,12 @@ async def _stream( # ToolStreamEvent and the last event is just the result. if isinstance(event, ToolInterruptEvent): + # Register any interrupts not already in the agent's state. + # For normal hooks this is a no-op (already registered by _Interruptible.interrupt()). + # For sub-agent interrupts propagated via _AgentAsTool, this is where they get + # registered so that _interrupt_state.resume() can locate them by ID. 
+ for interrupt in event.interrupts: + agent._interrupt_state.interrupts.setdefault(interrupt.id, interrupt) yield event return diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 5b0ae78f6..1d5a5de79 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from ..agent import AgentResult + from ..agent._agent_as_tool import _AgentAsTool from ..multiagent.base import MultiAgentResult, NodeResult @@ -323,6 +324,31 @@ def tool_use_id(self) -> str: return cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use"))["toolUseId"] +class AgentAsToolStreamEvent(ToolStreamEvent): + """Event emitted when an agent-as-tool yields intermediate events during execution. + + Extends ToolStreamEvent with a reference to the originating _AgentAsTool so callers + can distinguish sub-agent stream events from regular tool stream events and access + the wrapped agent, tool name, description, etc. + """ + + def __init__(self, tool_use: ToolUse, tool_stream_data: Any, agent_as_tool: "_AgentAsTool") -> None: + """Initialize with tool streaming data and agent-tool reference. + + Args: + tool_use: The tool invocation producing the stream. + tool_stream_data: The yielded event from the sub-agent execution. + agent_as_tool: The _AgentAsTool instance that produced this event. 
+ """ + super().__init__(tool_use, tool_stream_data) + self._agent_as_tool = agent_as_tool + + @property + def agent_as_tool(self) -> "_AgentAsTool": + """The _AgentAsTool instance that produced this event.""" + return self._agent_as_tool + + class ToolCancelEvent(TypedEvent): """Event emitted when a user cancels a tool call from their BeforeToolCallEvent hook.""" diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 967a0dafb..2ce9ff245 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -16,6 +16,7 @@ import strands from strands import Agent, Plugin, ToolContext from strands.agent import AgentResult +from strands.agent._agent_as_tool import _AgentAsTool from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.state import AgentState @@ -2699,3 +2700,52 @@ def hook_callback(event: BeforeModelCallEvent): agent("test") assert len(hook_called) == 1 + + +def test_as_tool_returns_agent_tool(): + """Test that as_tool returns an _AgentAsTool wrapping the agent.""" + agent = Agent(name="researcher", description="Finds information") + tool = agent.as_tool() + + assert isinstance(tool, _AgentAsTool) + assert tool.agent is agent + + +def test_as_tool_defaults_name_from_agent(): + """Test that as_tool defaults the tool name to the agent's name.""" + agent = Agent(name="researcher") + tool = agent.as_tool() + + assert tool.tool_name == "researcher" + + +def test_as_tool_defaults_description_from_agent(): + """Test that as_tool defaults the description to the agent's description.""" + agent = Agent(name="researcher", description="Finds information") + tool = agent.as_tool() + + assert tool.tool_spec["description"] == "Finds information" + + +def test_as_tool_custom_name(): + """Test that as_tool accepts a custom name.""" + agent = 
Agent(name="researcher") + tool = agent.as_tool(name="custom_name") + + assert tool.tool_name == "custom_name" + + +def test_as_tool_custom_description(): + """Test that as_tool accepts a custom description.""" + agent = Agent(name="researcher", description="Original") + tool = agent.as_tool(description="Custom description") + + assert tool.tool_spec["description"] == "Custom description" + + +def test_as_tool_defaults_description_when_agent_has_none(): + """Test that as_tool generates a default description when agent has none.""" + agent = Agent(name="researcher") + tool = agent.as_tool() + + assert tool.tool_spec["description"] == "Use the researcher agent as a tool by providing a natural language input" diff --git a/tests/strands/agent/test_agent_as_tool.py b/tests/strands/agent/test_agent_as_tool.py new file mode 100644 index 000000000..f5848b315 --- /dev/null +++ b/tests/strands/agent/test_agent_as_tool.py @@ -0,0 +1,676 @@ +"""Tests for _AgentAsTool - the agent-as-tool adapter.""" + +from unittest.mock import MagicMock + +import pytest + +from strands.agent._agent_as_tool import _AgentAsTool +from strands.agent.agent_result import AgentResult +from strands.interrupt import Interrupt, _InterruptState +from strands.telemetry.metrics import EventLoopMetrics +from strands.types._events import AgentAsToolStreamEvent, ToolInterruptEvent, ToolResultEvent, ToolStreamEvent + + +async def _mock_stream_async(result, intermediate_events=None): + """Helper that yields intermediate events then the final result event.""" + for event in intermediate_events or []: + yield event + yield {"result": result} + + +@pytest.fixture +def mock_agent(): + agent = MagicMock() + agent.name = "test_agent" + agent.description = "A test agent" + agent._interrupt_state = _InterruptState() + return agent + + +@pytest.fixture +def fake_agent(): + """A real Agent instance for tests that need Agent-specific features.""" + from strands.agent.agent import Agent + + return Agent(name="fake_agent", 
callback_handler=None) + + +@pytest.fixture +def tool(mock_agent): + return _AgentAsTool(mock_agent, name="test_agent", description="A test agent", preserve_context=True) + + +@pytest.fixture +def tool_use(): + return { + "toolUseId": "tool-123", + "name": "test_agent", + "input": {"input": "hello"}, + } + + +@pytest.fixture +def agent_result(): + return AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "response text"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + + +# --- init --- + + +def test_init(mock_agent): + tool = _AgentAsTool(mock_agent, name="my_tool", description="custom desc", preserve_context=True) + assert tool.tool_name == "my_tool" + assert tool._description == "custom desc" + assert tool.agent is mock_agent + + +def test_init_description_defaults_to_agent_description(fake_agent): + fake_agent.description = "Agent that researches topics" + tool = _AgentAsTool(fake_agent, name="researcher", preserve_context=True) + assert tool._description == "Agent that researches topics" + + +def test_init_description_defaults_to_generic_when_agent_has_none(fake_agent): + tool = _AgentAsTool(fake_agent, name="researcher", preserve_context=True) + assert tool._description == "Use the researcher agent as a tool by providing a natural language input" + + +def test_init_description_explicit_overrides_agent_description(fake_agent): + fake_agent.description = "Agent that researches topics" + tool = _AgentAsTool(fake_agent, name="researcher", description="custom", preserve_context=True) + assert tool._description == "custom" + + +def test_init_preserve_context_defaults_false(fake_agent): + tool = _AgentAsTool(fake_agent, name="t", description="d") + assert tool._preserve_context is False + + +def test_init_preserve_context_true(mock_agent): + tool = _AgentAsTool(mock_agent, name="t", description="d", preserve_context=True) + assert tool._preserve_context is True + + +# --- properties --- + + +def test_tool_properties(tool): 
+ assert tool.tool_name == "test_agent" + assert tool.tool_type == "agent" + + spec = tool.tool_spec + assert spec["name"] == "test_agent" + assert spec["description"] == "A test agent" + + schema = spec["inputSchema"]["json"] + assert schema["type"] == "object" + assert "input" in schema["properties"] + assert schema["properties"]["input"]["type"] == "string" + assert schema["required"] == ["input"] + + props = tool.get_display_properties() + assert props["Agent"] == "test_agent" + assert props["Type"] == "agent" + + +# --- stream --- + + +@pytest.mark.asyncio +async def test_stream_success(tool, mock_agent, tool_use, agent_result): + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0]["tool_result"]["status"] == "success" + assert result_events[0]["tool_result"]["content"][0]["text"] == "response text\n" + + +@pytest.mark.asyncio +async def test_stream_passes_input_to_agent(tool, mock_agent, tool_use, agent_result): + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("hello") + + +@pytest.mark.asyncio +async def test_stream_empty_input(tool, mock_agent, agent_result): + empty_tool_use = { + "toolUseId": "tool-123", + "name": "test_agent", + "input": {}, + } + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(empty_tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("") + + +@pytest.mark.asyncio +async def test_stream_string_input(tool, mock_agent, agent_result): + tool_use = { + "toolUseId": "tool-123", + "name": "test_agent", + "input": "direct string", + } + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in 
tool.stream(tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("direct string") + + +@pytest.mark.asyncio +async def test_stream_error(tool, mock_agent, tool_use): + mock_agent.stream_async.side_effect = RuntimeError("boom") + + events = [event async for event in tool.stream(tool_use, {})] + + assert len(events) == 1 + assert events[0]["tool_result"]["status"] == "error" + assert "boom" in events[0]["tool_result"]["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_stream_propagates_tool_use_id(tool, mock_agent, tool_use, agent_result): + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert result_events[0]["tool_result"]["toolUseId"] == "tool-123" + + +@pytest.mark.asyncio +async def test_stream_forwards_intermediate_events(tool, mock_agent, tool_use, agent_result): + intermediate = [{"data": "partial"}, {"data": "more"}] + mock_agent.stream_async.return_value = _mock_stream_async(agent_result, intermediate) + + events = [event async for event in tool.stream(tool_use, {})] + + stream_events = [e for e in events if isinstance(e, AgentAsToolStreamEvent)] + assert len(stream_events) == 2 + assert stream_events[0]["tool_stream_event"]["data"]["data"] == "partial" + assert stream_events[1]["tool_stream_event"]["data"]["data"] == "more" + assert stream_events[0].agent_as_tool is tool + assert stream_events[0].tool_use_id == "tool-123" + + +@pytest.mark.asyncio +async def test_stream_events_not_double_wrapped_by_executor(tool, mock_agent, tool_use, agent_result): + """AgentAsToolStreamEvent is a ToolStreamEvent subclass, so the executor should pass it through directly.""" + intermediate = [{"data": "chunk"}] + mock_agent.stream_async.return_value = _mock_stream_async(agent_result, intermediate) + + events = [event async for event in tool.stream(tool_use, {})] + + 
stream_events = [e for e in events if isinstance(e, AgentAsToolStreamEvent)] + assert len(stream_events) == 1 + + event = stream_events[0] + # It's a ToolStreamEvent (so the executor yields it directly) + assert isinstance(event, ToolStreamEvent) + # But it's specifically an AgentAsToolStreamEvent (not re-wrapped) + assert type(event) is AgentAsToolStreamEvent + # And it references the originating _AgentAsTool + assert event.agent_as_tool is tool + + +@pytest.mark.asyncio +async def test_stream_no_result_yields_error(tool, mock_agent, tool_use): + async def _empty_stream(): + return + yield # noqa: RET504 - make it an async generator + + mock_agent.stream_async.return_value = _empty_stream() + + events = [event async for event in tool.stream(tool_use, {})] + + assert len(events) == 1 + assert events[0]["tool_result"]["status"] == "error" + assert "did not produce a result" in events[0]["tool_result"]["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_stream_structured_output(tool, mock_agent, tool_use): + from pydantic import BaseModel + + class MyOutput(BaseModel): + answer: str + + structured = MyOutput(answer="42") + result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ignored"}]}, + metrics=EventLoopMetrics(), + state={}, + structured_output=structured, + ) + mock_agent.stream_async.return_value = _mock_stream_async(result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert result_events[0]["tool_result"]["status"] == "success" + assert result_events[0]["tool_result"]["content"][0]["json"] == {"answer": "42"} + + +# --- preserve_context --- + + +@pytest.mark.asyncio +async def test_stream_resets_to_initial_state_when_preserve_context_false(fake_agent): + fake_agent.messages = [{"role": "user", "content": [{"text": "initial"}]}] + fake_agent.state.set("counter", 0) + + tool = _AgentAsTool(fake_agent, 
name="fake_agent", description="desc", preserve_context=False) + + # Mutate agent state as if a previous invocation happened + fake_agent.messages.append({"role": "assistant", "content": [{"text": "reply"}]}) + fake_agent.state.set("counter", 5) + + # Mock stream_async so we don't need a real model + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert fake_agent.messages == [{"role": "user", "content": [{"text": "initial"}]}] + assert fake_agent.state.get("counter") == 0 + + +@pytest.mark.asyncio +async def test_stream_resets_on_every_invocation(fake_agent): + """Each call should reset to the same initial snapshot, not to the previous call's state.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "seed"}]}] + fake_agent.state.set("count", 1) + + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-1", + "name": "fake_agent", + "input": {"input": "first"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + fake_agent.messages.append({"role": "assistant", "content": [{"text": "added"}]}) + fake_agent.state.set("count", 99) + + tool_use["toolUseId"] = "tool-2" + async for _ in tool.stream(tool_use, {}): + pass + + assert fake_agent.messages == [{"role": "user", "content": [{"text": "seed"}]}] + assert fake_agent.state.get("count") == 1 + + +@pytest.mark.asyncio +async def 
test_stream_initial_snapshot_is_deep_copy(fake_agent): + """Mutating the agent's messages after construction should not affect the snapshot.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "original"}]}] + + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + fake_agent.messages[0]["content"][0]["text"] = "mutated" + fake_agent.messages.append({"role": "assistant", "content": [{"text": "extra"}]}) + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert fake_agent.messages == [{"role": "user", "content": [{"text": "original"}]}] + + +@pytest.mark.asyncio +async def test_stream_resets_empty_initial_state_when_preserve_context_false(fake_agent): + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + fake_agent.state.set("key", "value") + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert fake_agent.messages == [] + assert fake_agent.state.get() == {} + + +@pytest.mark.asyncio +async def test_stream_resets_context_by_default(fake_agent): + """Default preserve_context=False means each invocation starts fresh.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + fake_agent.state.set("key", "value") + tool 
= _AgentAsTool(fake_agent, name="fake_agent", description="desc") + + # Mutate after construction + fake_agent.messages.append({"role": "assistant", "content": [{"text": "extra"}]}) + fake_agent.state.set("key", "changed") + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + # Should reset to construction-time snapshot + assert fake_agent.messages == [{"role": "user", "content": [{"text": "old"}]}] + assert fake_agent.state.get("key") == "value" + + +@pytest.mark.asyncio +async def test_stream_preserves_context_when_explicitly_true(fake_agent): + fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + fake_agent.state.set("key", "value") + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert len(fake_agent.messages) >= 1 + assert fake_agent.state.get("key") == "value" + + +def test_preserve_context_false_rejects_session_manager(fake_agent): + """preserve_context=False should raise ValueError when agent has a session manager.""" + fake_agent._session_manager = MagicMock() + + with pytest.raises(ValueError, match="cannot be used with an agent that has a session manager"): + _AgentAsTool(fake_agent, name="t", description="d", preserve_context=False) + + +# --- interrupt propagation --- + + 
+@pytest.fixture +def interrupt_result(): + interrupt = Interrupt(id="interrupt-1", name="approval", reason="need approval") + return AgentResult( + stop_reason="interrupt", + message={"role": "assistant", "content": [{"text": "pending"}]}, + metrics=EventLoopMetrics(), + state={}, + interrupts=[interrupt], + ) + + +@pytest.mark.asyncio +async def test_stream_interrupt_yields_tool_interrupt_event(tool, mock_agent, tool_use, interrupt_result): + """When the sub-agent returns an interrupt result, _AgentAsTool should yield ToolInterruptEvent.""" + mock_agent.stream_async.return_value = _mock_stream_async(interrupt_result) + + events = [event async for event in tool.stream(tool_use, {})] + + assert len(events) == 1 + assert isinstance(events[0], ToolInterruptEvent) + assert events[0].interrupts == interrupt_result.interrupts + assert events[0].tool_use_id == "tool-123" + + +@pytest.mark.asyncio +async def test_stream_interrupt_no_tool_result_appended(tool, mock_agent, tool_use, interrupt_result): + """ToolInterruptEvent should not produce a ToolResultEvent.""" + mock_agent.stream_async.return_value = _mock_stream_async(interrupt_result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert result_events == [] + + +@pytest.mark.asyncio +async def test_stream_interrupt_forwards_intermediate_events(tool, mock_agent, tool_use, interrupt_result): + """Intermediate events should still be yielded before the interrupt.""" + intermediate = [{"data": "partial"}] + mock_agent.stream_async.return_value = _mock_stream_async(interrupt_result, intermediate) + + events = [event async for event in tool.stream(tool_use, {})] + + stream_events = [e for e in events if isinstance(e, AgentAsToolStreamEvent)] + interrupt_events = [e for e in events if isinstance(e, ToolInterruptEvent)] + assert len(stream_events) == 1 + assert len(interrupt_events) == 1 + + +@pytest.mark.asyncio +async def 
test_stream_interrupt_resume_forwards_responses(fake_agent): + """On resume, _AgentAsTool should forward interrupt responses to the sub-agent.""" + interrupt = Interrupt(id="interrupt-1", name="approval", reason="need approval", response="APPROVE") + + # Put the sub-agent in an activated interrupt state with the response already set + fake_agent._interrupt_state.interrupts["interrupt-1"] = interrupt + fake_agent._interrupt_state.activate() + + normal_result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "approved"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + fake_agent.stream_async = MagicMock(return_value=_mock_stream_async(normal_result)) + + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + tool_use = {"toolUseId": "tool-123", "name": "fake_agent", "input": {"input": "do something"}} + + events = [event async for event in tool.stream(tool_use, {})] + + # Should have called stream_async with interrupt responses, not the original prompt + call_args = fake_agent.stream_async.call_args + agent_input = call_args[0][0] + assert isinstance(agent_input, list) + assert len(agent_input) == 1 + assert agent_input[0]["interruptResponse"]["interruptId"] == "interrupt-1" + assert agent_input[0]["interruptResponse"]["response"] == "APPROVE" + + # Should produce a normal result + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0]["tool_result"]["status"] == "success" + + +@pytest.mark.asyncio +async def test_stream_interrupt_resume_skips_state_reset(fake_agent): + """When resuming from interrupt with preserve_context=False, state reset should be skipped.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "initial"}]}] + fake_agent.state.set("key", "value") + + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + # Simulate the sub-agent being 
in interrupt state after a previous invocation + interrupt = Interrupt(id="interrupt-1", name="approval", reason="need approval", response="APPROVE") + fake_agent._interrupt_state.interrupts["interrupt-1"] = interrupt + fake_agent._interrupt_state.activate() + + # Mutate messages to simulate sub-agent progress before interrupt + fake_agent.messages.append({"role": "assistant", "content": [{"text": "working on it"}]}) + + normal_result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "done"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + fake_agent.stream_async = MagicMock(return_value=_mock_stream_async(normal_result)) + + tool_use = {"toolUseId": "tool-123", "name": "fake_agent", "input": {"input": "do something"}} + async for _ in tool.stream(tool_use, {}): + pass + + # Messages should NOT have been reset — the sub-agent needs its conversation history intact + assert len(fake_agent.messages) == 2 + + +@pytest.mark.asyncio +async def test_is_sub_agent_interrupted_false_by_default(tool): + """_is_sub_agent_interrupted returns False when no interrupts are active.""" + assert tool._is_sub_agent_interrupted() is False + + +@pytest.mark.asyncio +async def test_is_sub_agent_interrupted_true_when_activated(fake_agent): + """_is_sub_agent_interrupted returns True when the sub-agent's interrupt state is activated.""" + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + assert tool._is_sub_agent_interrupted() is False + + fake_agent._interrupt_state.activate() + assert tool._is_sub_agent_interrupted() is True + + +@pytest.mark.asyncio +async def test_build_interrupt_responses(fake_agent): + """_build_interrupt_responses packages sub-agent interrupts into response content blocks.""" + tool = _AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + + interrupt_a = Interrupt(id="id-a", name="a", reason="r", response="yes") + interrupt_b = 
Interrupt(id="id-b", name="b", reason="r", response=None) + fake_agent._interrupt_state.interrupts = {"id-a": interrupt_a, "id-b": interrupt_b} + + responses = tool._build_interrupt_responses() + + # Only interrupt_a has a response + assert len(responses) == 1 + assert responses[0] == {"interruptResponse": {"interruptId": "id-a", "response": "yes"}} + + +# --- concurrency --- + + +@pytest.mark.asyncio +async def test_stream_rejects_concurrent_call(tool, mock_agent, tool_use, agent_result): + """A second concurrent call should get an error ToolResultEvent.""" + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + # Simulate the lock already being held by another invocation + tool._lock.acquire() + try: + events = [event async for event in tool.stream(tool_use, {})] + + assert len(events) == 1 + assert isinstance(events[0], ToolResultEvent) + assert events[0]["tool_result"]["status"] == "error" + assert "already processing" in events[0]["tool_result"]["content"][0]["text"] + mock_agent.stream_async.assert_not_called() + finally: + tool._lock.release() + + +@pytest.mark.asyncio +async def test_stream_releases_lock_after_completion(tool, mock_agent, tool_use, agent_result): + """Lock should be released after stream completes, allowing subsequent calls.""" + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(tool_use, {}): + pass + + assert not tool._lock.locked() + + # A second call should succeed + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0]["tool_result"]["status"] == "success" + + +@pytest.mark.asyncio +async def test_stream_releases_lock_after_error(tool, mock_agent, tool_use): + """Lock should be released even when the agent raises an exception.""" + 
mock_agent.stream_async.side_effect = RuntimeError("boom") + + async for _ in tool.stream(tool_use, {}): + pass + + assert not tool._lock.locked() diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 4a5479503..297aa66f3 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -464,6 +464,57 @@ async def test_executor_stream_tool_interrupt_resume(executor, agent, tool_resul assert tru_results == exp_results +@pytest.mark.asyncio +async def test_executor_stream_tool_interrupt_registers_on_agent( + executor, agent, tool_results, invocation_state, alist +): + """ToolInterruptEvent from a tool should register interrupts in the agent's _interrupt_state.""" + # Create a tool that yields a ToolInterruptEvent with an interrupt NOT pre-registered on the agent + # (simulates _AgentAsTool propagating sub-agent interrupts). + foreign_interrupt = Interrupt(id="sub-agent-interrupt-1", name="approval", reason="need approval") + + @strands.tool(name="agent_tool") + def agent_tool_func(): + return "unused" + + async def mock_stream(_tool_use, _invocation_state, **_kwargs): + yield ToolInterruptEvent(_tool_use, [foreign_interrupt]) + + agent_tool_func.stream = mock_stream + agent.tool_registry.register_tool(agent_tool_func) + + tool_use: ToolUse = {"name": "agent_tool", "toolUseId": "test_tool_id", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + events = await alist(stream) + + # Should yield the interrupt event + assert len(events) == 1 + assert isinstance(events[0], ToolInterruptEvent) + + # The interrupt should now be registered on the agent's _interrupt_state + assert "sub-agent-interrupt-1" in agent._interrupt_state.interrupts + assert agent._interrupt_state.interrupts["sub-agent-interrupt-1"] is foreign_interrupt + + +@pytest.mark.asyncio +async def test_executor_stream_tool_interrupt_does_not_overwrite_existing( + 
executor, agent, tool_results, invocation_state, alist +): + """setdefault should not overwrite interrupts already in the agent's state (normal hook case).""" + tool_use = {"name": "interrupt_tool", "toolUseId": "test_tool_id", "input": {}} + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + # The interrupt_tool hook registered the interrupt via _Interruptible.interrupt(). + # The executor's setdefault should have been a no-op for this pre-registered interrupt. + registered = agent._interrupt_state.interrupts + assert len(registered) == 1 + interrupt = next(iter(registered.values())) + assert interrupt.name == "test_name" + assert interrupt.reason == "test reason" + + @pytest.mark.asyncio async def test_executor_stream_updates_invocation_state_with_agent( executor, agent, tool_results, invocation_state, weather_tool, alist diff --git a/tests/strands/types/test__events.py b/tests/strands/types/test__events.py index 6163faeb6..48465e1f6 100644 --- a/tests/strands/types/test__events.py +++ b/tests/strands/types/test__events.py @@ -6,6 +6,7 @@ from strands.telemetry import EventLoopMetrics from strands.types._events import ( + AgentAsToolStreamEvent, AgentResultEvent, CitationStreamEvent, EventLoopStopEvent, @@ -465,3 +466,39 @@ def test_event_inheritance(self): assert hasattr(event, "is_callback_event") assert hasattr(event, "as_dict") assert hasattr(event, "prepare") + + +class TestAgentAsToolStreamEvent: + """Tests for AgentAsToolStreamEvent.""" + + def test_initialization(self): + """Test AgentAsToolStreamEvent initialization with agent-tool reference.""" + tool_use: ToolUse = { + "toolUseId": "agent_tool_123", + "name": "researcher", + "input": {"input": "hello"}, + } + agent_event = {"data": "partial response"} + mock_agent_as_tool = MagicMock() + mock_agent_as_tool.tool_name = "researcher" + + event = AgentAsToolStreamEvent(tool_use, agent_event, mock_agent_as_tool) + + assert 
event["tool_stream_event"]["tool_use"] == tool_use + assert event["tool_stream_event"]["data"] == agent_event + assert event.agent_as_tool is mock_agent_as_tool + assert event.tool_use_id == "agent_tool_123" + + def test_is_tool_stream_event_subclass(self): + """Test that AgentAsToolStreamEvent is a ToolStreamEvent subclass.""" + tool_use: ToolUse = { + "toolUseId": "id_123", + "name": "tool", + "input": {}, + } + mock_agent_as_tool = MagicMock() + event = AgentAsToolStreamEvent(tool_use, {}, mock_agent_as_tool) + + assert isinstance(event, ToolStreamEvent) + assert isinstance(event, TypedEvent) + assert type(event) is AgentAsToolStreamEvent diff --git a/tests_integ/test_agent_as_tool.py b/tests_integ/test_agent_as_tool.py new file mode 100644 index 000000000..a808fcd23 --- /dev/null +++ b/tests_integ/test_agent_as_tool.py @@ -0,0 +1,36 @@ +import pytest + +from strands import Agent, tool + + +@tool +def get_tiger_height() -> int: + """Returns the height of a tiger in centimeters.""" + return 100 + + +@pytest.mark.asyncio +async def test_stream_async_with_agent_tool(): + inner_agent = Agent( + name="myAgentTool", + description="An agent tool knowledgeable about tigers", + tools=[get_tiger_height], + ) + agent_tool = inner_agent.as_tool() + agent = Agent( + name="myOtherAgent", + tools=[agent_tool], + ) + + result = await agent.invoke_async( + prompt="Invoke the myAgentTool and ask about the height of tigers.", + ) + + # Outer agent completed and called the agent tool + assert result.stop_reason == "end_turn" + assert "myAgentTool" in result.metrics.tool_metrics + assert result.metrics.tool_metrics["myAgentTool"].success_count >= 1 + + # Inner agent called get_tiger_height + assert "get_tiger_height" in inner_agent.event_loop_metrics.tool_metrics + assert inner_agent.event_loop_metrics.tool_metrics["get_tiger_height"].success_count >= 1 From 521c4d7d2337837e22183cbca9a08bd81a65a404 Mon Sep 17 00:00:00 2001 From: Agent of mkmeral Date: Fri, 27 Mar 2026 13:38:39 -0400 
Subject: [PATCH 197/279] feat: auto-wrap Agent instances passed in tools list (#1997) Co-authored-by: agent-of-mkmeral <217235299+strands-agent@users.noreply.github.com> Co-authored-by: agent-of-mkmeral --- src/strands/agent/agent.py | 3 +- src/strands/tools/registry.py | 8 ++ tests/strands/agent/test_agent_as_tool.py | 46 ++++++++++++ tests/strands/tools/test_registry.py | 89 +++++++++++++++++++++++ 4 files changed, 145 insertions(+), 1 deletion(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 6adecce31..20ae9b309 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -153,7 +153,8 @@ def __init__( - Imported Python modules (e.g., from strands_tools import current_time) - Dictionaries with name/path keys (e.g., {"name": "tool_name", "path": "/path/to/tool.py"}) - ToolProvider instances for managed tool collections - - Functions decorated with `@strands.tool` decorator. + - Functions decorated with `@strands.tool` decorator + - Agent instances (auto-wrapped via `agent.as_tool()` with defaults) If provided, only these tools will be available. If None, all tools will be available. system_prompt: System prompt to guide model behavior. diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index a5e4132bb..9a0f0f722 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -19,6 +19,7 @@ from typing_extensions import TypedDict from .._async import run_async +from ..agent.base import AgentBase from ..tools.decorator import DecoratedFunctionTool from ..types.tools import AgentTool, ToolSpec from . import ToolProvider @@ -62,6 +63,7 @@ def process_tools(self, tools: list[Any]) -> list[str]: 3. A module for a module based tool 4. Instances of AgentTool (@tool decorated functions) 5. Dictionaries with name/path keys (deprecated) + 6. 
Agent instances with an ``as_tool()`` method (auto-wrapped) Returns: @@ -140,6 +142,12 @@ async def get_tools() -> Sequence[AgentTool]: for provider_tool in provider_tools: self.register_tool(provider_tool) tool_names.append(provider_tool.tool_name) + # Agent instances - auto-wrap with .as_tool() for convenience + elif isinstance(tool, AgentBase) and hasattr(tool, "as_tool") and callable(tool.as_tool): + wrapped_tool = tool.as_tool() + self.register_tool(wrapped_tool) + tool_names.append(wrapped_tool.tool_name) + else: logger.warning("tool=<%s> | unrecognized tool specification", tool) diff --git a/tests/strands/agent/test_agent_as_tool.py b/tests/strands/agent/test_agent_as_tool.py index f5848b315..5a8399830 100644 --- a/tests/strands/agent/test_agent_as_tool.py +++ b/tests/strands/agent/test_agent_as_tool.py @@ -674,3 +674,49 @@ async def test_stream_releases_lock_after_error(tool, mock_agent, tool_use): pass assert not tool._lock.locked() + + +# --- Agent-as-tool sugar (passing agents directly in tools list) --- + + +def test_agent_passed_directly_in_tools_list(): + """Test that an Agent can be passed directly in another Agent's tools list.""" + from strands.agent.agent import Agent + + sub_agent = Agent(name="research_agent", description="Does research", callback_handler=None) + + # This should work without calling .as_tool() explicitly + parent_agent = Agent(name="orchestrator", tools=[sub_agent], callback_handler=None) + + assert "research_agent" in parent_agent.tool_names + + +def test_multiple_agents_passed_directly_in_tools_list(): + """Test that multiple Agents can be passed directly in another Agent's tools list.""" + from strands.agent.agent import Agent + + agent_a = Agent(name="agent_a", callback_handler=None) + agent_b = Agent(name="agent_b", callback_handler=None) + + parent = Agent(name="parent", tools=[agent_a, agent_b], callback_handler=None) + + assert "agent_a" in parent.tool_names + assert "agent_b" in parent.tool_names + + +def 
test_agent_mixed_with_regular_tools_in_tools_list(): + """Test that Agents can be mixed with regular tools in the tools list.""" + from strands import tool as tool_decorator + from strands.agent.agent import Agent + + @tool_decorator + def my_tool(x: str) -> str: + """A regular tool.""" + return x + + sub_agent = Agent(name="helper_agent", callback_handler=None) + + parent = Agent(name="parent", tools=[my_tool, sub_agent], callback_handler=None) + + assert "my_tool" in parent.tool_names + assert "helper_agent" in parent.tool_names diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index 73141beb6..3723f381b 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -603,3 +603,92 @@ def test_tool_registry_replace_non_dynamic_with_dynamic(): assert registry.registry["my_tool"] == new_tool assert registry.dynamic_tools["my_tool"] == new_tool + + +# --- Agent-as-tool sugar --- + + +def test_process_tools_with_agent_instance(): + """Test that passing an Agent instance in tools list auto-wraps it with as_tool().""" + from strands.agent.agent import Agent + + sub_agent = Agent(name="research_agent", description="Finds information", callback_handler=None) + + registry = ToolRegistry() + tool_names = registry.process_tools([sub_agent]) + + assert "research_agent" in tool_names + assert "research_agent" in registry.registry + assert registry.registry["research_agent"].tool_type == "agent" + + +def test_process_tools_with_agent_instance_uses_agent_name(): + """Test that the auto-wrapped tool uses the agent's name.""" + from strands.agent.agent import Agent + + sub_agent = Agent(name="my_custom_agent", callback_handler=None) + + registry = ToolRegistry() + registry.process_tools([sub_agent]) + + assert "my_custom_agent" in registry.registry + spec = registry.registry["my_custom_agent"].tool_spec + assert spec["name"] == "my_custom_agent" + + +def 
test_process_tools_with_agent_instance_uses_agent_description(): + """Test that the auto-wrapped tool uses the agent's description.""" + from strands.agent.agent import Agent + + sub_agent = Agent(name="helper", description="A helpful assistant", callback_handler=None) + + registry = ToolRegistry() + registry.process_tools([sub_agent]) + + spec = registry.registry["helper"].tool_spec + assert spec["description"] == "A helpful assistant" + + +def test_process_tools_with_agent_in_nested_list(): + """Test that Agent instances in nested iterables are auto-wrapped.""" + from strands.agent.agent import Agent + + agent_a = Agent(name="agent_a", callback_handler=None) + agent_b = Agent(name="agent_b", callback_handler=None) + + registry = ToolRegistry() + tool_names = sorted(registry.process_tools([[agent_a, agent_b]])) + + assert tool_names == ["agent_a", "agent_b"] + + +def test_process_tools_with_mixed_agents_and_tools(): + """Test that Agent instances can be mixed with regular tools.""" + from strands.agent.agent import Agent + + def function() -> str: + return "done" + + regular_tool = tool(name="regular_tool")(function) + sub_agent = Agent(name="sub_agent", callback_handler=None) + + registry = ToolRegistry() + tool_names = sorted(registry.process_tools([regular_tool, sub_agent])) + + assert tool_names == ["regular_tool", "sub_agent"] + assert registry.registry["sub_agent"].tool_type == "agent" + + +def test_process_tools_with_multiple_agents(): + """Test that multiple Agent instances can be passed.""" + from strands.agent.agent import Agent + + agent_1 = Agent(name="researcher", description="Does research", callback_handler=None) + agent_2 = Agent(name="writer", description="Writes content", callback_handler=None) + agent_3 = Agent(name="reviewer", description="Reviews work", callback_handler=None) + + registry = ToolRegistry() + tool_names = sorted(registry.process_tools([agent_1, agent_2, agent_3])) + + assert tool_names == ["researcher", "reviewer", "writer"] + 
assert all(registry.registry[name].tool_type == "agent" for name in tool_names) From 194c69b42ea5b537707f20f3acf3348d539a91b5 Mon Sep 17 00:00:00 2001 From: Sanjeed <40694326+sanjeed5@users.noreply.github.com> Date: Sat, 28 Mar 2026 02:01:40 +0530 Subject: [PATCH 198/279] feat(telemetry): emit system prompt on chat spans per GenAI semconv (#1818) Co-authored-by: sanjeed5 Co-authored-by: Liz <91279165+lizradway@users.noreply.github.com> --- src/strands/event_loop/event_loop.py | 2 + src/strands/telemetry/tracer.py | 43 +++++++++++ tests/strands/event_loop/test_event_loop.py | 3 + tests/strands/telemetry/test_tracer.py | 80 +++++++++++++++++++-- 4 files changed, 123 insertions(+), 5 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 2e8e4a660..eb664e056 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -313,6 +313,8 @@ async def _handle_model_execution( parent_span=cycle_span, model_id=model_id, custom_trace_attributes=agent.trace_attributes, + system_prompt=agent.system_prompt, + system_prompt_content=agent._system_prompt_content, ) with trace_api.use_span(model_invoke_span, end_on_exit=True): await agent.hooks.invoke_callbacks_async( diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 0471a7fcc..c03d9d962 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -285,6 +285,8 @@ def start_model_invoke_span( parent_span: Span | None = None, model_id: str | None = None, custom_trace_attributes: Mapping[str, AttributeValue] | None = None, + system_prompt: str | None = None, + system_prompt_content: list | None = None, **kwargs: Any, ) -> Span: """Start a new span for a model invocation. @@ -294,6 +296,8 @@ def start_model_invoke_span( parent_span: Optional parent span to link this span to. model_id: Optional identifier for the model being invoked. 
custom_trace_attributes: Optional mapping of custom trace attributes to include in the span. + system_prompt: Optional system prompt string provided to the model. + system_prompt_content: Optional list of system prompt content blocks. **kwargs: Additional attributes to add to the span. Returns: @@ -311,6 +315,7 @@ def start_model_invoke_span( attributes.update({k: v for k, v in kwargs.items() if isinstance(v, (str, int, float, bool))}) span = self._start_span("chat", parent_span, attributes=attributes, span_kind=trace_api.SpanKind.INTERNAL) + self._add_system_prompt_event(span, system_prompt, system_prompt_content) self._add_event_messages(span, messages) return span @@ -813,6 +818,44 @@ def _get_common_attributes( ) return dict(common_attributes) + def _add_system_prompt_event( + self, + span: Span, + system_prompt: str | None = None, + system_prompt_content: list | None = None, + ) -> None: + """Emit system prompt as a span event per OTel GenAI semantic conventions. + + In legacy mode (v1.36), emits a ``gen_ai.system.message`` event. + In latest experimental mode, emits ``gen_ai.system_instructions`` on the + ``gen_ai.client.inference.operation.details`` event, since Strands passes + system instructions separately from chat history. + + Args: + span: The span to add the event to. + system_prompt: Optional system prompt string. + system_prompt_content: Optional list of system prompt content blocks. 
+ """ + if system_prompt is None and system_prompt_content is None: + return + + content_blocks = system_prompt_content if system_prompt_content else [{"text": system_prompt}] + + if self.use_latest_genai_conventions: + parts = self._map_content_blocks_to_otel_parts(content_blocks) + self._add_event( + span, + "gen_ai.client.inference.operation.details", + {"gen_ai.system_instructions": serialize(parts)}, + to_span_attributes=self.is_langfuse, + ) + else: + self._add_event( + span, + "gen_ai.system.message", + {"content": serialize(content_blocks)}, + ) + def _add_event_messages(self, span: Span, messages: Messages) -> None: """Adds messages as event to the provided span based on the current GenAI conventions. diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index cedca269b..3ffb89e7c 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -547,6 +547,9 @@ async def test_event_loop_cycle_creates_spans( mock_get_tracer.assert_called_once() mock_tracer.start_event_loop_cycle_span.assert_called_once() mock_tracer.start_model_invoke_span.assert_called_once() + call_kwargs = mock_tracer.start_model_invoke_span.call_args[1] + assert call_kwargs["system_prompt"] == agent.system_prompt + assert call_kwargs["system_prompt_content"] == agent._system_prompt_content mock_tracer.end_model_invoke_span.assert_called_once() mock_tracer.end_event_loop_cycle_span.assert_called_once() diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 410db0c0c..9176ce4ae 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -140,9 +140,14 @@ def test_start_model_invoke_span(mock_tracer): messages = [{"role": "user", "content": [{"text": "Hello"}]}] model_id = "test-model" custom_attrs = {"custom_key": "custom_value", "user_id": "12345"} + system_prompt = "You are a helpful assistant" span = 
tracer.start_model_invoke_span( - messages=messages, agent_name="TestAgent", model_id=model_id, custom_trace_attributes=custom_attrs + messages=messages, + agent_name="TestAgent", + model_id=model_id, + custom_trace_attributes=custom_attrs, + system_prompt=system_prompt, ) mock_tracer.start_span.assert_called_once() @@ -158,9 +163,14 @@ def test_start_model_invoke_span(mock_tracer): "agent_name": "TestAgent", } ) - mock_span.add_event.assert_called_with( - "gen_ai.user.message", attributes={"content": json.dumps(messages[0]["content"])} + + calls = mock_span.add_event.call_args_list + assert len(calls) == 2 + assert calls[0] == mock.call( + "gen_ai.system.message", + attributes={"content": serialize([{"text": system_prompt}])}, ) + assert calls[1] == mock.call("gen_ai.user.message", attributes={"content": json.dumps(messages[0]["content"])}) assert span is not None @@ -184,8 +194,11 @@ def test_start_model_invoke_span_latest_conventions(mock_tracer, monkeypatch): }, ] model_id = "test-model" + system_prompt = "You are a calculator assistant" - span = tracer.start_model_invoke_span(messages=messages, agent_name="TestAgent", model_id=model_id) + span = tracer.start_model_invoke_span( + messages=messages, agent_name="TestAgent", model_id=model_id, system_prompt=system_prompt + ) mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "chat" @@ -199,7 +212,16 @@ def test_start_model_invoke_span_latest_conventions(mock_tracer, monkeypatch): "agent_name": "TestAgent", } ) - mock_span.add_event.assert_called_with( + + calls = mock_span.add_event.call_args_list + assert len(calls) == 2 + assert calls[0] == mock.call( + "gen_ai.client.inference.operation.details", + attributes={ + "gen_ai.system_instructions": serialize([{"type": "text", "content": system_prompt}]), + }, + ) + assert calls[1] == mock.call( "gen_ai.client.inference.operation.details", attributes={ "gen_ai.input.messages": serialize( @@ -226,6 +248,54 @@ def 
test_start_model_invoke_span_latest_conventions(mock_tracer, monkeypatch): assert span is not None +def test_start_model_invoke_span_without_system_prompt(mock_tracer): + """Test that no system prompt event is emitted when system_prompt is None.""" + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + tracer = Tracer() + tracer.tracer = mock_tracer + + mock_span = mock.MagicMock() + mock_tracer.start_span.return_value = mock_span + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + span = tracer.start_model_invoke_span(messages=messages, model_id="test-model") + + assert mock_span.add_event.call_count == 1 + mock_span.add_event.assert_called_once_with( + "gen_ai.user.message", attributes={"content": json.dumps(messages[0]["content"])} + ) + assert span is not None + + +def test_start_model_invoke_span_with_system_prompt_content(mock_tracer): + """Test that system_prompt_content takes priority over system_prompt string.""" + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + tracer = Tracer() + tracer.tracer = mock_tracer + + mock_span = mock.MagicMock() + mock_tracer.start_span.return_value = mock_span + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt_content = [{"text": "You are helpful"}, {"text": "Be concise"}] + + span = tracer.start_model_invoke_span( + messages=messages, + model_id="test-model", + system_prompt="ignored string", + system_prompt_content=system_prompt_content, + ) + + calls = mock_span.add_event.call_args_list + assert len(calls) == 2 + assert calls[0] == mock.call( + "gen_ai.system.message", + attributes={"content": serialize(system_prompt_content)}, + ) + assert span is not None + + def test_end_model_invoke_span(mock_span): """Test ending a model invoke span.""" tracer = Tracer() From e2b60364677e0fca967d3247aa3cb5890bc67e78 Mon Sep 17 00:00:00 2001 From: Christian-kam Date: Mon, 30 Mar 2026 20:58:54 +0200 
Subject: [PATCH 199/279] feat(mcp): add support for MCP elicitation -32042 error handling (#1745) Co-authored-by: Christian Kamwangala --- src/strands/tools/mcp/mcp_client.py | 29 +++- tests/strands/tools/mcp/test_mcp_client.py | 149 +++++++++++++++++++++ 2 files changed, 177 insertions(+), 1 deletion(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 51a627c7c..1fd2990ec 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -10,6 +10,7 @@ import asyncio import base64 import contextvars +import json import logging import threading import uuid @@ -24,8 +25,10 @@ import anyio from mcp import ClientSession, ListToolsResult from mcp.client.session import ElicitationFnT +from mcp.shared.exceptions import McpError from mcp.types import ( BlobResourceContents, + ElicitationRequiredErrorData, GetPromptResult, ListPromptsResult, ListResourcesResult, @@ -668,7 +671,31 @@ async def call_tool_async( return self._handle_tool_execution_error(tool_use_id, e) def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> MCPToolResult: - """Create error ToolResult with consistent logging.""" + """Create error ToolResult with consistent logging and elicitation callback support. + + Args: + tool_use_id: Unique identifier for this tool use. + exception: The exception that occurred during tool execution. + + Returns: + MCPToolResult: Error result containing either the elicitation data or the + original exception message. 
+ """ + if isinstance(exception, McpError) and exception.error.code == -32042: + try: + error_data = ElicitationRequiredErrorData.model_validate(exception.error.data) + elicitations = [e.model_dump(exclude_none=True) for e in error_data.elicitations] + + return MCPToolResult( + status="error", + toolUseId=tool_use_id, + content=[ + {"text": (f"MCP Elicitation required: [{str(exception)}] with data {json.dumps(elicitations)}")} + ], + ) + except Exception: + logger.debug("Failed to parse ElicitationRequiredErrorData from -32042 error", exc_info=True) + return MCPToolResult( status="error", toolUseId=tool_use_id, diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index 5eedd1e33..057c41a95 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -928,3 +928,152 @@ async def test_handle_error_message_with_percent_in_message(): # This should not raise TypeError and should not raise the exception (since it's non-fatal) await client._handle_error_message(error_with_percent) + + +def test_call_tool_sync_elicitation_error(mock_transport, mock_session): + """Test that call_tool_sync correctly handles elicitation required errors.""" + from mcp.shared.exceptions import McpError + from mcp.types import ElicitationRequiredErrorData, ElicitRequestURLParams + + elicitation_data = ElicitationRequiredErrorData( + elicitations=[ + ElicitRequestURLParams( + url="https://example.com/auth", message="Please authorize the application", elicitationId="elicit-123" + ) + ] + ) + + error = McpError(error=MagicMock(code=-32042, data=elicitation_data.model_dump())) + mock_session.call_tool.side_effect = error + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) + + assert result["status"] == "error" + assert result["toolUseId"] == "test-123" + assert len(result["content"]) == 
1 + assert "MCP Elicitation required" in result["content"][0]["text"] + assert "https://example.com/auth" in result["content"][0]["text"] + assert "Please authorize the application" in result["content"][0]["text"] + assert "elicit-123" in result["content"][0]["text"] + + +def test_call_tool_sync_elicitation_error_multiple_urls(mock_transport, mock_session): + """Test that call_tool_sync correctly handles elicitation errors with multiple elicitations.""" + from mcp.shared.exceptions import McpError + from mcp.types import ElicitationRequiredErrorData, ElicitRequestURLParams + + elicitation_data = ElicitationRequiredErrorData( + elicitations=[ + ElicitRequestURLParams( + url="https://example.com/auth1", message="First authorization", elicitationId="elicit-1" + ), + ElicitRequestURLParams( + url="https://example.com/auth2", message="Second authorization", elicitationId="elicit-2" + ), + ] + ) + + error = McpError(error=MagicMock(code=-32042, data=elicitation_data.model_dump())) + mock_session.call_tool.side_effect = error + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) + + assert result["status"] == "error" + assert result["toolUseId"] == "test-123" + assert len(result["content"]) == 1 + assert "MCP Elicitation required" in result["content"][0]["text"] + assert "https://example.com/auth1" in result["content"][0]["text"] + assert "https://example.com/auth2" in result["content"][0]["text"] + assert "First authorization" in result["content"][0]["text"] + assert "Second authorization" in result["content"][0]["text"] + assert "elicit-1" in result["content"][0]["text"] + assert "elicit-2" in result["content"][0]["text"] + + +def test_call_tool_sync_elicitation_error_no_urls(mock_transport, mock_session): + """Test that -32042 error with empty URL still returns generic elicitation result.""" + from mcp.shared.exceptions import McpError + from mcp.types 
import ElicitationRequiredErrorData, ElicitRequestURLParams + + elicitation_data = ElicitationRequiredErrorData( + elicitations=[ElicitRequestURLParams(url="", message="No URL provided", elicitationId="elicit-1")] + ) + error = McpError(error=MagicMock(code=-32042, data=elicitation_data.model_dump())) + mock_session.call_tool.side_effect = error + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={}) + assert result["status"] == "error" + assert "MCP Elicitation required" in result["content"][0]["text"] + assert "elicit-1" in result["content"][0]["text"] + assert "No URL provided" in result["content"][0]["text"] + + +def test_call_tool_sync_other_mcp_error_code(mock_transport, mock_session): + """Test that non-32042 McpError falls through to generic error.""" + from mcp.shared.exceptions import McpError + + error = McpError(error=MagicMock(code=-32600, message="Invalid request")) + mock_session.call_tool.side_effect = error + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={}) + assert result["status"] == "error" + assert "Tool execution failed" in result["content"][0]["text"] + + +def test_call_tool_sync_elicitation_error_malformed_data(mock_transport, mock_session): + """Test that -32042 with unparseable data falls through to generic error.""" + from mcp.shared.exceptions import McpError + + error = McpError(error=MagicMock(code=-32042, data={"garbage": True})) + mock_session.call_tool.side_effect = error + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={}) + assert result["status"] == "error" + assert "Tool execution failed" in result["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_call_tool_async_elicitation_error(mock_transport, 
mock_session): + """Test that call_tool_async correctly handles elicitation required errors.""" + from mcp.shared.exceptions import McpError + from mcp.types import ElicitationRequiredErrorData, ElicitRequestURLParams + + elicitation_data = ElicitationRequiredErrorData( + elicitations=[ + ElicitRequestURLParams( + url="https://example.com/auth", message="Please authorize the application", elicitationId="elicit-123" + ) + ] + ) + + error = McpError(error=MagicMock(code=-32042, data=elicitation_data.model_dump())) + + with MCPClient(mock_transport["transport_callable"]) as client: + with ( + patch("asyncio.run_coroutine_threadsafe") as mock_run_coroutine_threadsafe, + patch("asyncio.wrap_future") as mock_wrap_future, + ): + mock_future = MagicMock() + mock_run_coroutine_threadsafe.return_value = mock_future + + async def mock_awaitable(): + raise error + + mock_wrap_future.return_value = mock_awaitable() + + result = await client.call_tool_async( + tool_use_id="test-123", name="test_tool", arguments={"param": "value"} + ) + + assert result["status"] == "error" + assert result["toolUseId"] == "test-123" + assert len(result["content"]) == 1 + assert "MCP Elicitation required" in result["content"][0]["text"] + assert "https://example.com/auth" in result["content"][0]["text"] + assert "Please authorize the application" in result["content"][0]["text"] + assert "elicit-123" in result["content"][0]["text"] From 424224d6bef0789302644876faf0b745bdbc5876 Mon Sep 17 00:00:00 2001 From: Liz <91279165+lizradway@users.noreply.github.com> Date: Mon, 30 Mar 2026 16:48:43 -0400 Subject: [PATCH 200/279] fix: ollama input/output token count (#2008) --- src/strands/models/ollama.py | 4 ++-- tests/strands/models/test_ollama.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 68aba59d4..97cb7948a 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -273,8 +273,8 @@ def 
format_chunk(self, event: dict[str, Any]) -> StreamEvent: return { "metadata": { "usage": { - "inputTokens": event["data"].eval_count, - "outputTokens": event["data"].prompt_eval_count, + "inputTokens": event["data"].prompt_eval_count, + "outputTokens": event["data"].eval_count, "totalTokens": event["data"].eval_count + event["data"].prompt_eval_count, }, "metrics": { diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py index d17894028..0d4fbb9e0 100644 --- a/tests/strands/models/test_ollama.py +++ b/tests/strands/models/test_ollama.py @@ -394,8 +394,8 @@ def test_format_chunk_metadata(model): exp_chunk = { "metadata": { "usage": { - "inputTokens": 100, - "outputTokens": 50, + "inputTokens": 50, + "outputTokens": 100, "totalTokens": 150, }, "metrics": { @@ -438,7 +438,7 @@ async def test_stream(ollama_client, model, agenerator, alist, captured_warnings {"messageStop": {"stopReason": "end_turn"}}, { "metadata": { - "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}, + "usage": {"inputTokens": 5, "outputTokens": 10, "totalTokens": 15}, "metrics": {"latencyMs": 1.0}, } }, @@ -510,7 +510,7 @@ async def test_stream_with_tool_calls(ollama_client, model, agenerator, alist): {"messageStop": {"stopReason": "tool_use"}}, { "metadata": { - "usage": {"inputTokens": 15, "outputTokens": 8, "totalTokens": 23}, + "usage": {"inputTokens": 8, "outputTokens": 15, "totalTokens": 23}, "metrics": {"latencyMs": 2.0}, } }, From ae19308d0e4e3114f81a2b3776c8afb43d5646ba Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 30 Mar 2026 18:48:27 -0400 Subject: [PATCH 201/279] feat: add stateful model support for server-side conversation management (#2004) --- src/strands/agent/agent.py | 23 +++++- src/strands/event_loop/event_loop.py | 1 + src/strands/event_loop/streaming.py | 3 + src/strands/models/model.py | 47 +++++++++++- src/strands/models/openai_responses.py | 49 ++++++++---- src/strands/multiagent/graph.py | 9 +++ 
src/strands/multiagent/swarm.py | 5 ++ .../session/repository_session_manager.py | 26 +++++-- src/strands/types/session.py | 3 + tests/strands/agent/test_agent.py | 9 ++- tests/strands/agent/test_agent_model_state.py | 69 +++++++++++++++++ tests/strands/event_loop/test_event_loop.py | 2 + tests/strands/event_loop/test_streaming.py | 5 ++ .../test_streaming_structured_output.py | 2 + tests/strands/models/test_model.py | 40 ++++++++++ tests/strands/models/test_openai_responses.py | 51 +++++++++++++ tests/strands/multiagent/test_graph.py | 2 + tests/strands/multiagent/test_swarm.py | 1 + .../test_repository_session_manager.py | 56 ++++++++++++++ tests/strands/types/test_session.py | 15 +++- tests_integ/models/test_model_mantle.py | 74 +++++++++++++++++++ tests_integ/models/test_model_openai.py | 21 ++++++ tests_integ/test_session.py | 35 +++++++++ 23 files changed, 521 insertions(+), 27 deletions(-) create mode 100644 tests/strands/agent/test_agent_model_state.py create mode 100644 tests_integ/models/test_model_mantle.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 20ae9b309..3a23133de 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -45,7 +45,7 @@ from ..hooks.registry import TEvent from ..interrupt import _InterruptState from ..models.bedrock import BedrockModel -from ..models.model import Model +from ..models.model import Model, _ModelPlugin from ..plugins import Plugin from ..plugins.registry import _PluginRegistry from ..session.session_manager import SessionManager @@ -68,6 +68,7 @@ from .base import AgentBase from .conversation_manager import ( ConversationManager, + NullConversationManager, SlidingWindowConversationManager, ) from .state import AgentState @@ -229,7 +230,19 @@ def __init__( else: self.callback_handler = callback_handler - self.conversation_manager = conversation_manager if conversation_manager else SlidingWindowConversationManager() + if self.model.stateful and conversation_manager is not 
None: + raise ValueError( + "conversation_manager cannot be used with a stateful model. " + "The model manages conversation state server-side." + ) + + self.conversation_manager: ConversationManager + if self.model.stateful: + self.conversation_manager = NullConversationManager() + elif conversation_manager: + self.conversation_manager = conversation_manager + else: + self.conversation_manager = SlidingWindowConversationManager() # Process trace attributes to ensure they're of compatible types self.trace_attributes: dict[str, AttributeValue] = {} @@ -282,6 +295,9 @@ def __init__( self._interrupt_state = _InterruptState() + # Runtime state for model providers (e.g., server-side response ids) + self._model_state: dict[str, Any] = {} + # Initialize lock for guarding concurrent invocations # Using threading.Lock instead of asyncio.Lock because run_async() creates # separate event loops in different threads, so asyncio.Lock wouldn't work @@ -327,6 +343,9 @@ def __init__( for hook in hooks: self.hooks.add_hook(hook) + # Register built-in plugins + self._plugin_registry.add_and_init(_ModelPlugin()) + if plugins: for plugin in plugins: self._plugin_registry.add_and_init(plugin) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index eb664e056..374cfe129 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -338,6 +338,7 @@ async def _handle_model_execution( system_prompt_content=agent._system_prompt_content, tool_choice=structured_output_context.tool_choice, invocation_state=invocation_state, + model_state=agent._model_state, cancel_signal=agent._cancel_signal, ): yield event diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index ee45420fe..0a1161135 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -463,6 +463,7 @@ async def stream_messages( tool_choice: Any | None = None, system_prompt_content: 
list[SystemContentBlock] | None = None, invocation_state: dict[str, Any] | None = None, + model_state: dict[str, Any] | None = None, cancel_signal: threading.Event | None = None, **kwargs: Any, ) -> AsyncGenerator[TypedEvent, None]: @@ -477,6 +478,7 @@ async def stream_messages( system_prompt_content: The authoritative system prompt content blocks that always contains the system prompt data. invocation_state: Caller-provided state/context that was passed to the agent when it was invoked. + model_state: Runtime state for model providers (e.g., server-side response ids). cancel_signal: Optional threading.Event to check for cancellation during streaming. **kwargs: Additional keyword arguments for future extensibility. @@ -495,6 +497,7 @@ async def stream_messages( tool_choice=tool_choice, system_prompt_content=system_prompt_content, invocation_state=invocation_state, + model_state=model_state, ) async for event in process_stream(chunks, start_time, cancel_signal): diff --git a/src/strands/models/model.py b/src/strands/models/model.py index 9d83a72eb..f084d24d5 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -4,14 +4,19 @@ import logging from collections.abc import AsyncGenerator, AsyncIterable from dataclasses import dataclass -from typing import Any, Literal, TypeVar +from typing import TYPE_CHECKING, Any, Literal, TypeVar from pydantic import BaseModel +from ..hooks.events import AfterInvocationEvent +from ..plugins.plugin import Plugin from ..types.content import Messages, SystemContentBlock from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec +if TYPE_CHECKING: + from ..agent.agent import Agent + logger = logging.getLogger(__name__) T = TypeVar("T", bound=BaseModel) @@ -37,6 +42,15 @@ class Model(abc.ABC): standardized way to configure and process requests for different AI model providers. """ + @property + def stateful(self) -> bool: + """Whether the model manages conversation state server-side. 
+ + Returns: + False by default. Model providers that support server-side state should override this. + """ + return False + @abc.abstractmethod # pragma: no cover def update_config(self, **model_config: Any) -> None: @@ -115,3 +129,34 @@ def stream( ModelThrottledException: When the model service is throttling requests from the client. """ pass + + +class _ModelPlugin(Plugin): + """Plugin that manages model-related lifecycle hooks.""" + + @property + def name(self) -> str: + """A stable string identifier for this plugin.""" + return "strands:model" + + @staticmethod + def _on_after_invocation(event: AfterInvocationEvent) -> None: + """Handle post-invocation model management tasks. + + Performs the following: + - Clears messages when the model is managing conversation state server-side. + """ + if event.agent.model.stateful: + event.agent.messages.clear() + logger.debug( + "response_id=<%s> | cleared messages for server-managed conversation", + event.agent._model_state.get("response_id"), + ) + + def init_agent(self, agent: "Agent") -> None: + """Register model lifecycle hooks with the agent. + + Args: + agent: The agent instance to register hooks with. + """ + agent.add_hook(self._on_after_invocation, AfterInvocationEvent) diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index bc2dcfd0e..01974c11d 100644 --- a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -1,18 +1,8 @@ """OpenAI model provider using the Responses API. -The Responses API is OpenAI's newer API that differs from the Chat Completions API in several key ways: +Note: Built-in tools (web search, code interpreter, file search) are not yet supported. -1. The Responses API can maintain conversation state server-side through "previous_response_id", - while Chat Completions is stateless and requires sending full conversation history each time. - Note: This implementation currently only implements the stateless approach. - -2. 
Responses API uses "input" (list of items) instead of "messages", and system - prompts are passed as "instructions" rather than a system role message. - -3. Responses API supports built-in tools (web search, code interpreter, file search) - Note: These are not yet implemented in this provider. - -- Docs: https://platform.openai.com/docs/api-reference/responses +Docs: https://platform.openai.com/docs/api-reference/responses """ import base64 @@ -132,10 +122,14 @@ class OpenAIResponsesConfig(TypedDict, total=False): params: Model parameters (e.g., max_output_tokens, temperature, etc.). For a complete list of supported parameters, see https://platform.openai.com/docs/api-reference/responses/create. + stateful: Whether to enable server-side conversation state management. + When True, the server stores conversation history and the client does not need to + send the full message history with each request. Defaults to False. """ model_id: str params: dict[str, Any] | None + stateful: bool def __init__( self, client_args: dict[str, Any] | None = None, **model_config: Unpack[OpenAIResponsesConfig] @@ -153,6 +147,15 @@ def __init__( logger.debug("config=<%s> | initializing", self.config) + @property + @override + def stateful(self) -> bool: + """Whether server-side conversation storage is enabled. + + Derived from the ``stateful`` configuration option. + """ + return bool(self.config.get("stateful")) + @override def update_config(self, **model_config: Unpack[OpenAIResponsesConfig]) -> None: # type: ignore[override] """Update the OpenAI Responses API model configuration with the provided arguments. @@ -180,6 +183,7 @@ async def stream( system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, + model_state: dict[str, Any] | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the OpenAI Responses API model. @@ -189,6 +193,7 @@ async def stream( tool_specs: List of tool specifications to make available to the model. 
system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. + model_state: Runtime state for model providers (e.g., server-side response ids). **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -199,7 +204,7 @@ async def stream( ModelThrottledException: If the request is throttled by OpenAI (rate limits). """ logger.debug("formatting request for OpenAI Responses API") - request = self._format_request(messages, tool_specs, system_prompt, tool_choice) + request = self._format_request(messages, tool_specs, system_prompt, tool_choice, model_state) logger.debug("formatted request=<%s>", request) logger.debug("invoking OpenAI Responses API model") @@ -219,7 +224,14 @@ async def stream( async for event in response: if hasattr(event, "type"): - if event.type == "response.reasoning_text.delta": + if event.type == "response.created": + # Capture response id for server-side conversation chaining + if hasattr(event, "response"): + response_id = getattr(event.response, "id", None) + if model_state is not None and response_id: + model_state["response_id"] = response_id + + elif event.type == "response.reasoning_text.delta": # Reasoning content streaming (for o1/o3 reasoning models) chunks, data_type = self._stream_switch_content("reasoning_content", data_type) for chunk in chunks: @@ -383,6 +395,7 @@ def _format_request( tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None, tool_choice: ToolChoice | None = None, + model_state: dict[str, Any] | None = None, ) -> dict[str, Any]: """Format an OpenAI Responses API compatible response streaming request. @@ -391,6 +404,7 @@ def _format_request( tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. + model_state: Runtime state for model providers (e.g., server-side response ids). 
Returns: An OpenAI Responses API compatible response streaming request. @@ -400,13 +414,18 @@ def _format_request( format. """ input_items = self._format_request_messages(messages) - request = { + request: dict[str, Any] = { "model": self.config["model_id"], "input": input_items, "stream": True, **cast(dict[str, Any], self.config.get("params", {})), + "store": self.stateful, } + response_id = model_state.get("response_id") if model_state else None + if response_id and self.stateful: + request["previous_response_id"] = response_id + if system_prompt: request["instructions"] = system_prompt diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 04d158108..8da8314ea 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -170,6 +170,7 @@ class GraphNode: execution_time: int = 0 _initial_messages: Messages = field(default_factory=list, init=False) _initial_state: AgentState = field(default_factory=AgentState, init=False) + _initial_model_state: dict[str, Any] = field(default_factory=dict, init=False) def __post_init__(self) -> None: """Capture initial executor state after initialization.""" @@ -180,6 +181,9 @@ def __post_init__(self) -> None: if hasattr(self.executor, "state") and hasattr(self.executor.state, "get"): self._initial_state = AgentState(self.executor.state.get()) + if hasattr(self.executor, "_model_state"): + self._initial_model_state = copy.deepcopy(self.executor._model_state) + def reset_executor_state(self) -> None: """Reset GraphNode executor state to initial state when graph was created. 
@@ -192,6 +196,9 @@ def reset_executor_state(self) -> None: if hasattr(self.executor, "state"): self.executor.state = AgentState(self._initial_state.get()) + if hasattr(self.executor, "_model_state"): + self.executor._model_state = copy.deepcopy(self._initial_model_state) + # Reset execution status self.execution_status = Status.PENDING self.result = None @@ -639,6 +646,7 @@ def _activate_interrupt( "interrupt_state": node.executor._interrupt_state.to_dict(), "state": node.executor.state.get(), "messages": node.executor.messages, + "model_state": node.executor._model_state, } ) @@ -1074,6 +1082,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: node.executor.messages = node_context["messages"] node.executor.state = AgentState(node_context["state"]) node.executor._interrupt_state = _InterruptState.from_dict(node_context["interrupt_state"]) + node.executor._model_state = node_context.get("model_state", {}) return node_responses diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index ed447eb07..f5731a371 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -69,12 +69,14 @@ class SwarmNode: swarm: Optional["Swarm"] = None _initial_messages: Messages = field(default_factory=list, init=False) _initial_state: AgentState = field(default_factory=AgentState, init=False) + _initial_model_state: dict[str, Any] = field(default_factory=dict, init=False) def __post_init__(self) -> None: """Capture initial executor state after initialization.""" # Deep copy the initial messages and state to preserve them self._initial_messages = copy.deepcopy(self.executor.messages) self._initial_state = AgentState(self.executor.state.get()) + self._initial_model_state = copy.deepcopy(self.executor._model_state) def __hash__(self) -> int: """Return hash for SwarmNode based on node_id.""" @@ -104,10 +106,12 @@ def reset_executor_state(self) -> None: self.executor.messages = context["messages"] self.executor.state = 
AgentState(context["state"]) self.executor._interrupt_state = _InterruptState.from_dict(context["interrupt_state"]) + self.executor._model_state = context.get("model_state", {}) return self.executor.messages = copy.deepcopy(self._initial_messages) self.executor.state = AgentState(self._initial_state.get()) + self.executor._model_state = copy.deepcopy(self._initial_model_state) @dataclass @@ -697,6 +701,7 @@ def _activate_interrupt(self, node: SwarmNode, interrupts: list[Interrupt]) -> M "interrupt_state": node.executor._interrupt_state.to_dict(), "state": node.executor.state.get(), "messages": node.executor.messages, + "model_state": node.executor._model_state, } self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts}) diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index 7e538c08b..c1032a85e 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -114,6 +114,7 @@ def sync_agent(self, agent: "Agent", **kwargs: Any) -> None: current_state_version = agent.state._get_version() current_interrupt_state_version = agent._interrupt_state._get_version() current_conversation_manager_state = agent.conversation_manager.get_state() + current_model_state = agent._model_state # Check if we have a previous state to compare against last_synced = self._last_synced_internal_state.get(agent.agent_id) @@ -126,7 +127,9 @@ def sync_agent(self, agent: "Agent", **kwargs: Any) -> None: conversation_manager_state_changed = True else: state_changed = current_state_version != last_synced.get("state_version") - internal_state_changed = current_interrupt_state_version != last_synced.get("interrupt_state_version") + internal_state_changed = current_interrupt_state_version != last_synced.get( + "interrupt_state_version" + ) or current_model_state != last_synced.get("model_state") conversation_manager_state_changed = 
current_conversation_manager_state != last_synced.get( "conversation_manager_state" ) @@ -160,6 +163,7 @@ def sync_agent(self, agent: "Agent", **kwargs: Any) -> None: "state_version": current_state_version, "interrupt_state_version": current_interrupt_state_version, "conversation_manager_state": copy.deepcopy(current_conversation_manager_state), + "model_state": copy.deepcopy(current_model_state), } def initialize(self, agent: "Agent", **kwargs: Any) -> None: @@ -220,11 +224,21 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: if len(session_messages) > 0: self._latest_agent_message[agent.agent_id] = session_messages[-1] - # Restore the agents messages array including the optional prepend messages - agent.messages = prepend_messages + [session_message.to_message() for session_message in session_messages] - - # Fix broken session histories: https://github.com/strands-agents/sdk-python/issues/859 - agent.messages = self._fix_broken_tool_use(agent.messages) + # Skip restoring messages when conversation is managed server-side + if agent.model.stateful: + logger.debug( + "agent_id=<%s> | session_id=<%s> | skipping message restore for server-managed conversation", + agent.agent_id, + self.session_id, + ) + else: + # Restore the agents messages array including the optional prepend messages + agent.messages = prepend_messages + [ + session_message.to_message() for session_message in session_messages + ] + + # Fix broken session histories: https://github.com/strands-agents/sdk-python/issues/859 + agent.messages = self._fix_broken_tool_use(agent.messages) self._is_new_session = False diff --git a/src/strands/types/session.py b/src/strands/types/session.py index 29453f4b7..294c518d7 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -134,6 +134,7 @@ def from_agent(cls, agent: "Agent") -> "SessionAgent": state=agent.state.get(), _internal_state={ "interrupt_state": agent._interrupt_state.to_dict(), + "model_state": agent._model_state, 
}, ) @@ -175,6 +176,8 @@ def initialize_internal_state(self, agent: "Agent") -> None: """Initialize internal state of agent.""" if "interrupt_state" in self._internal_state: agent._interrupt_state = _InterruptState.from_dict(self._internal_state["interrupt_state"]) + if "model_state" in self._internal_state: + agent._model_state = self._internal_state["model_state"] def initialize_bidi_internal_state(self, agent: "BidiAgent") -> None: """Initialize internal state of BidiAgent. diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 2ce9ff245..5a3cce11c 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -58,6 +58,7 @@ async def stream(*args, **kwargs): mock = unittest.mock.Mock(spec=getattr(request, "param", None)) mock.configure_mock(mock_stream=unittest.mock.MagicMock()) mock.stream.side_effect = stream + mock.stateful = False return mock @@ -358,6 +359,7 @@ def test_agent__call__( tool_choice=None, system_prompt_content=[{"text": system_prompt}], invocation_state=unittest.mock.ANY, + model_state=unittest.mock.ANY, ), unittest.mock.call( [ @@ -397,6 +399,7 @@ def test_agent__call__( tool_choice=None, system_prompt_content=[{"text": system_prompt}], invocation_state=unittest.mock.ANY, + model_state=unittest.mock.ANY, ), ], ) @@ -519,6 +522,7 @@ def test_agent__call__retry_with_reduced_context(mock_model, agent, tool, agener tool_choice=None, system_prompt_content=unittest.mock.ANY, invocation_state=unittest.mock.ANY, + model_state=unittest.mock.ANY, ) conversation_manager_spy.reduce_context.assert_called_once() @@ -667,6 +671,7 @@ def test_agent__call__retry_with_overwritten_tool(mock_model, agent, tool, agene tool_choice=None, system_prompt_content=unittest.mock.ANY, invocation_state=unittest.mock.ANY, + model_state=unittest.mock.ANY, ) assert conversation_manager_spy.reduce_context.call_count == 2 @@ -1307,7 +1312,9 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, 
mock_model @pytest.mark.asyncio @unittest.mock.patch("strands.agent.agent.get_tracer") -async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_tracer, mock_event_loop_cycle, alist): +async def test_agent_stream_async_creates_and_ends_span_on_success( + mock_get_tracer, mock_event_loop_cycle, mock_model, alist +): """Test that stream_async creates and ends a span when the call succeeds.""" # Setup mock tracer and span mock_tracer = unittest.mock.MagicMock() diff --git a/tests/strands/agent/test_agent_model_state.py b/tests/strands/agent/test_agent_model_state.py new file mode 100644 index 000000000..7e751d334 --- /dev/null +++ b/tests/strands/agent/test_agent_model_state.py @@ -0,0 +1,69 @@ +"""Tests for agent model state with server-side conversation management.""" + +import unittest.mock + +import pytest + +from strands.agent.agent import Agent +from strands.agent.conversation_manager import NullConversationManager, SlidingWindowConversationManager + + +@pytest.fixture +def mock_model(): + """Create a mock model that writes response_id to model_state.""" + model = unittest.mock.MagicMock() + model.config = {"model_id": "test-model"} + model.get_config.return_value = {"model_id": "test-model"} + + call_count = 0 + + async def mock_stream(messages, tool_specs=None, system_prompt=None, **kwargs): + nonlocal call_count + call_count += 1 + resp_id = "resp_abc123" if call_count == 1 else "resp_def456" + + model_state = kwargs.get("model_state") + if model_state is not None: + model_state["response_id"] = resp_id + + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": "Hello"}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + yield { + "metadata": { + "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}, + "metrics": {"latencyMs": 100}, + } + } + + model.stream = unittest.mock.MagicMock(side_effect=mock_stream) 
+ model.stateful = True + return model + + +def test_agent_model_state(mock_model): + """Verify model_state is populated, messages are cleared, and model_state is passed on subsequent calls.""" + agent = Agent(model=mock_model, callback_handler=None) + assert isinstance(agent.conversation_manager, NullConversationManager) + + agent("Turn 1") + assert agent._model_state.get("response_id") == "resp_abc123" + assert len(agent.messages) == 0 + + agent("Turn 2") + assert agent._model_state.get("response_id") == "resp_def456" + assert len(agent.messages) == 0 + + second_call_kwargs = mock_model.stream.call_args_list[1][1] + assert second_call_kwargs.get("model_state") is agent._model_state + + +def test_agent_model_state_raises_with_conversation_manager(): + """Passing a conversation_manager with a stateful model raises ValueError.""" + model = unittest.mock.MagicMock() + model.stateful = True + + with pytest.raises(ValueError, match="conversation_manager cannot be used with a stateful model"): + Agent(model=model, conversation_manager=SlidingWindowConversationManager()) diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 3ffb89e7c..2f57f5560 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -152,6 +152,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock.tool_executor = tool_executor mock._interrupt_state = _InterruptState() mock._cancel_signal = threading.Event() + mock._model_state = {} mock.trace_attributes = {} mock.retry_strategy = ModelRetryStrategy() @@ -391,6 +392,7 @@ async def test_event_loop_cycle_tool_result( tool_choice=None, system_prompt_content=unittest.mock.ANY, invocation_state=unittest.mock.ANY, + model_state=unittest.mock.ANY, ) diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index bfaf796d2..93f8d95f8 100644 --- 
a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -1174,6 +1174,7 @@ async def test_stream_messages(agenerator, alist): tool_choice=None, system_prompt_content=[{"text": "test prompt"}], invocation_state=None, + model_state=None, ) @@ -1208,6 +1209,7 @@ async def test_stream_messages_with_system_prompt_content(agenerator, alist): tool_choice=None, system_prompt_content=system_prompt_content, invocation_state=None, + model_state=None, ) @@ -1242,6 +1244,7 @@ async def test_stream_messages_single_text_block_backwards_compatibility(agenera tool_choice=None, system_prompt_content=system_prompt_content, invocation_state=None, + model_state=None, ) @@ -1274,6 +1277,7 @@ async def test_stream_messages_empty_system_prompt_content(agenerator, alist): tool_choice=None, system_prompt_content=[], invocation_state=None, + model_state=None, ) @@ -1306,6 +1310,7 @@ async def test_stream_messages_none_system_prompt_content(agenerator, alist): tool_choice=None, system_prompt_content=None, invocation_state=None, + model_state=None, ) # Ensure that we're getting typed events coming out of process_stream diff --git a/tests/strands/event_loop/test_streaming_structured_output.py b/tests/strands/event_loop/test_streaming_structured_output.py index 4c4082c00..3c7358237 100644 --- a/tests/strands/event_loop/test_streaming_structured_output.py +++ b/tests/strands/event_loop/test_streaming_structured_output.py @@ -67,6 +67,7 @@ async def test_stream_messages_with_tool_choice(agenerator, alist): tool_choice=tool_choice, system_prompt_content=[{"text": "test prompt"}], invocation_state=None, + model_state=None, ) # Verify we get the expected events @@ -133,6 +134,7 @@ async def test_stream_messages_with_forced_structured_output(agenerator, alist): tool_choice=tool_choice, system_prompt_content=[{"text": "Extract user information"}], invocation_state=None, + model_state=None, ) assert len(tru_events) > 0 diff --git 
a/tests/strands/models/test_model.py b/tests/strands/models/test_model.py index b8249f504..458e98645 100644 --- a/tests/strands/models/test_model.py +++ b/tests/strands/models/test_model.py @@ -1,7 +1,11 @@ +from unittest.mock import MagicMock + import pytest from pydantic import BaseModel +from strands.hooks.events import AfterInvocationEvent from strands.models import Model as SAModel +from strands.models.model import _ModelPlugin class Person(BaseModel): @@ -67,6 +71,11 @@ def tool_specs(): ] +@pytest.fixture +def model_plugin(): + return _ModelPlugin() + + @pytest.fixture def system_prompt(): return "s1" @@ -173,3 +182,34 @@ async def stream(self, messages, tool_specs=None, system_prompt=None, *, tool_ch response = model.stream(messages, tool_specs, system_prompt) events = await alist(response) assert events[1]["contentBlockDelta"]["delta"]["text"] == "No tool choice" + + +def test_stateful_false(model): + """Model.stateful defaults to False.""" + assert not model.stateful + + +def test_model_plugin_clears_messages_when_stateful(model_plugin): + """Messages are cleared when model is stateful.""" + agent = MagicMock() + agent.model.stateful = True + agent._model_state = {"response_id": "resp_123"} + agent.messages = [{"role": "user", "content": [{"text": "hello"}]}] + + event = AfterInvocationEvent(agent=agent, invocation_state={}) + model_plugin._on_after_invocation(event) + + assert agent.messages == [] + + +def test_model_plugin_preserves_messages_when_not_stateful(model_plugin): + """Messages are preserved when model is not stateful.""" + agent = MagicMock() + agent.model.stateful = False + agent._model_state = {} + agent.messages = [{"role": "user", "content": [{"text": "hello"}]}] + + event = AfterInvocationEvent(agent=agent, invocation_state={}) + model_plugin._on_after_invocation(event) + + assert len(agent.messages) == 1 diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py index 545f128bf..6a16eef89 
100644 --- a/tests/strands/models/test_openai_responses.py +++ b/tests/strands/models/test_openai_responses.py @@ -328,6 +328,7 @@ def test_format_request(model, messages, tool_specs, system_prompt): } ], "stream": True, + "store": False, "instructions": system_prompt, "tools": [ { @@ -487,6 +488,7 @@ async def test_stream(openai_client, model_id, model, agenerator, alist): "model": model_id, "input": [{"role": "user", "content": [{"type": "input_text", "text": "test"}]}], "stream": True, + "store": False, "max_output_tokens": 100, } openai_client.responses.create.assert_called_once_with(**expected_request) @@ -955,3 +957,52 @@ def mock_valid_version(package_name: str) -> str: # Reload with valid version to restore module state with unittest.mock.patch("importlib.metadata.version", mock_valid_version): importlib.reload(openai_responses_module) + + +@pytest.mark.parametrize("stateful", [True, False]) +def test_stateful(model_id, stateful): + """Model.stateful reflects the stateful config option.""" + model = OpenAIResponsesModel(model_id=model_id, stateful=stateful) + assert model.stateful is stateful + + +@pytest.mark.asyncio +async def test_stream_stateful(openai_client, model_id, agenerator, alist): + """When stateful is enabled, model writes response_id to model_state from response.created.""" + model = OpenAIResponsesModel(model_id=model_id, stateful=True) + mock_events = [ + unittest.mock.Mock( + type="response.created", + response=unittest.mock.Mock(id="resp_abc123"), + ), + unittest.mock.Mock(type="response.output_text.delta", delta="Hi"), + unittest.mock.Mock( + type="response.completed", + response=unittest.mock.Mock( + id="resp_abc123", + usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15), + ), + ), + ] + + openai_client.responses.create = unittest.mock.AsyncMock(return_value=agenerator(mock_events)) + + model_state = {"response_id": "resp_previous"} + events = await alist( + model.stream( + [{"role": "user", "content": [{"text": 
"Hello"}]}], + model_state=model_state, + ) + ) + + call_kwargs = openai_client.responses.create.call_args[1] + assert call_kwargs["previous_response_id"] == "resp_previous" + + assert model_state["response_id"] == "resp_abc123" + + metadata_events = [e for e in events if "metadata" in e] + assert len(metadata_events) == 1 + assert metadata_events[0]["metadata"] == { + "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}, + "metrics": {"latencyMs": 0}, + } diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index e978701cd..a6085627c 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -26,6 +26,7 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen agent.state = AgentState() agent.messages = [] agent._interrupt_state = _InterruptState() + agent._model_state = {} if metrics is None: metrics = Mock( @@ -2270,6 +2271,7 @@ def test_graph_interrupt_on_agent(agenerator): }, "messages": [], "state": {}, + "model_state": {}, } responses = [ diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 43acd6400..cb0414b42 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -25,6 +25,7 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen agent.messages = [] agent.state = AgentState() # Add state attribute agent._interrupt_state = _InterruptState() # Add interrupt state + agent._model_state = {} # Add model state agent.tool_registry = Mock() agent.tool_registry.registry = {} agent.tool_registry.process_tools = Mock() diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index 9b2d84a51..1d5048113 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -5,6 +5,7 @@ import pytest 
from strands.agent.agent import Agent +from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager from strands.agent.state import AgentState @@ -245,6 +246,32 @@ def test_initialize_multi_agent_existing(existing_session_manager, mock_multi_ag mock_multi_agent.deserialize_state.assert_called_once_with(existing_state) +def test_initialize_skips_message_restore_for_server_managed_conversation(existing_session_manager): + """Test that messages are not restored when model manages conversation server-side.""" + session_agent = SessionAgent( + agent_id="existing-agent", + state={}, + conversation_manager_state=NullConversationManager().get_state(), + _internal_state={ + "interrupt_state": {"interrupts": {}, "context": {}, "activated": False}, + "model_state": {"response_id": "resp_abc123"}, + }, + ) + existing_session_manager.session_repository.create_agent("test-session", session_agent) + + message = SessionMessage.from_message({"role": "user", "content": [{"text": "Hello"}]}, 0) + existing_session_manager.session_repository.create_message("test-session", "existing-agent", message) + + mock_model = Mock() + mock_model.stateful = True + agent = Agent(agent_id="existing-agent", model=mock_model) + existing_session_manager.initialize(agent) + + assert agent.messages == [] + assert agent._model_state == {"response_id": "resp_abc123"} + assert existing_session_manager.session_repository.list_messages("test-session", "existing-agent") == [message] + + def test_fix_broken_tool_use_adds_missing_tool_results(existing_session_manager): """Test that _fix_broken_tool_use adds missing toolResult messages.""" conversation_manager = SlidingWindowConversationManager() @@ -733,6 +760,35 @@ def tracking_update_agent(session_id, 
session_agent): assert len(update_agent_calls) == 1 +def test_sync_agent_calls_update_when_model_state_changed(mock_repository): + """Test that sync_agent() calls update_agent() when model state changed.""" + session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + + # Create and initialize agent + agent = Agent(agent_id="test-agent", session_manager=session_manager) + + # Track update_agent calls + update_agent_calls = [] + original_update_agent = mock_repository.update_agent + + def tracking_update_agent(session_id, session_agent): + update_agent_calls.append((session_id, session_agent)) + return original_update_agent(session_id, session_agent) + + mock_repository.update_agent = tracking_update_agent + + # First sync to establish baseline + session_manager.sync_agent(agent) + update_agent_calls.clear() + + # Modify model state + agent._model_state["response_id"] = "resp_abc123" + + # Sync should call update_agent because model state changed + session_manager.sync_agent(agent) + assert len(update_agent_calls) == 1 + + def test_sync_agent_tracks_version_after_successful_sync(mock_repository): """Test that sync_agent() tracks version after successful sync.""" session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) diff --git a/tests/strands/types/test_session.py b/tests/strands/types/test_session.py index 3e5360742..b456f2404 100644 --- a/tests/strands/types/test_session.py +++ b/tests/strands/types/test_session.py @@ -102,13 +102,17 @@ def test_session_agent_from_agent(): agent.conversation_manager = unittest.mock.Mock(get_state=lambda: {"test": "conversation"}) agent.state = AgentState({"test": "state"}) agent._interrupt_state = _InterruptState(interrupts={}, context={}, activated=False) + agent._model_state = {} tru_session_agent = SessionAgent.from_agent(agent) exp_session_agent = SessionAgent( agent_id="a1", conversation_manager_state={"test": "conversation"}, 
state={"test": "state"}, - _internal_state={"interrupt_state": {"interrupts": {}, "context": {}, "activated": False}}, + _internal_state={ + "interrupt_state": {"interrupts": {}, "context": {}, "activated": False}, + "model_state": {}, + }, created_at=unittest.mock.ANY, updated_at=unittest.mock.ANY, ) @@ -121,7 +125,10 @@ def test_session_agent_initialize_internal_state(): agent_id="a1", conversation_manager_state={}, state={}, - _internal_state={"interrupt_state": {"interrupts": {}, "context": {"test": "init"}, "activated": False}}, + _internal_state={ + "interrupt_state": {"interrupts": {}, "context": {"test": "init"}, "activated": False}, + "model_state": {"response_id": "resp_abc"}, + }, ) session_agent.initialize_internal_state(agent) @@ -129,3 +136,7 @@ def test_session_agent_initialize_internal_state(): tru_interrupt_state = agent._interrupt_state exp_interrupt_state = _InterruptState(interrupts={}, context={"test": "init"}, activated=False) assert tru_interrupt_state == exp_interrupt_state + + tru_model_state = agent._model_state + exp_model_state = {"response_id": "resp_abc"} + assert tru_model_state == exp_model_state diff --git a/tests_integ/models/test_model_mantle.py b/tests_integ/models/test_model_mantle.py new file mode 100644 index 000000000..55c445676 --- /dev/null +++ b/tests_integ/models/test_model_mantle.py @@ -0,0 +1,74 @@ +"""Integration tests for OpenAI Responses API on Bedrock Mantle with AWS credentials.""" + +import httpx +import pytest +from botocore.auth import SigV4Auth +from botocore.awsrequest import AWSRequest +from botocore.session import Session as BotocoreSession + +from strands import Agent +from strands.models.openai_responses import OpenAIResponsesModel + + +class _SigV4Auth(httpx.Auth): + """httpx Auth handler that signs requests with AWS SigV4.""" + + def __init__(self, region: str): + session = BotocoreSession() + self.credentials = session.get_credentials().get_frozen_credentials() + self.signer = 
SigV4Auth(self.credentials, "bedrock", region) + + def auth_flow(self, request: httpx.Request): + aws_request = AWSRequest( + method=request.method, + url=str(request.url), + headers=dict(request.headers), + data=request.content, + ) + self.signer.add_auth(aws_request) + for key, value in aws_request.headers.items(): + request.headers[key] = value + yield request + + +class _NonClosingAsyncClient(httpx.AsyncClient): + """AsyncClient that survives the OpenAI SDK's context manager lifecycle.""" + + async def aclose(self) -> None: + pass + + +@pytest.fixture +def client_args(): + region = "us-east-1" + return { + "api_key": "unused", + "base_url": f"https://bedrock-mantle.{region}.api.aws/v1", + "http_client": _NonClosingAsyncClient(auth=_SigV4Auth(region)), + } + + +@pytest.fixture +def model(client_args): + return OpenAIResponsesModel(model_id="openai.gpt-oss-120b", client_args=client_args) + + +@pytest.fixture +def stateful_model(client_args): + return OpenAIResponsesModel(model_id="openai.gpt-oss-120b", stateful=True, client_args=client_args) + + +def test_agent_invoke(model): + agent = Agent(model=model, system_prompt="Reply in one short sentence.", callback_handler=None) + result = agent("What is 2+2?") + assert "4" in str(result) + + +def test_responses_server_side_conversation(stateful_model): + agent = Agent(model=stateful_model, system_prompt="Reply in one short sentence.", callback_handler=None) + + agent("My name is Alice.") + assert len(agent.messages) == 0 + + result = agent("What is my name?") + assert "alice" in str(result).lower() diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py index 6b0b3a95b..042e9e21c 100644 --- a/tests_integ/models/test_model_openai.py +++ b/tests_integ/models/test_model_openai.py @@ -287,3 +287,24 @@ def test_system_prompt_backward_compatibility_integration(model): # The response should contain our specific system prompt instruction assert "BACKWARD_COMPAT_TEST" in 
result.message["content"][0]["text"] + + +@pytest.mark.skipif(not _openai_responses_available, reason="OpenAI Responses API not available") +def test_responses_server_side_conversation(): + """Integration test for server-side conversation state management. + + Verifies that when stateful=True, the model tracks conversation across turns + via previous_response_id and the agent clears messages between invocations. + """ + model = OpenAIResponsesModel( + model_id="gpt-4o-mini", + stateful=True, + client_args={"api_key": os.getenv("OPENAI_API_KEY")}, + ) + agent = Agent(model=model, system_prompt="Reply in one short sentence.") + + agent("My name is Alice.") + assert len(agent.messages) == 0 + + result = agent("What is my name?") + assert "alice" in result.message["content"][0]["text"].lower() diff --git a/tests_integ/test_session.py b/tests_integ/test_session.py index 53d128da6..0d4fe9fe1 100644 --- a/tests_integ/test_session.py +++ b/tests_integ/test_session.py @@ -1,5 +1,6 @@ """Integration tests for session management.""" +import os import tempfile from uuid import uuid4 @@ -9,8 +10,10 @@ from strands import Agent from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager +from strands.models.openai_responses import OpenAIResponsesModel from strands.session.file_session_manager import FileSessionManager from strands.session.s3_session_manager import S3SessionManager +from tests_integ.models.providers import openai as openai_provider # yellow_img imported from conftest @@ -147,3 +150,35 @@ def test_agent_with_s3_session_with_image(yellow_img, bucket_name): finally: session_manager.delete_session(test_session_id) assert session_manager.read_session(test_session_id) is None + + +@openai_provider.mark +def test_agent_with_file_session_server_side_conversation(temp_dir): + """Test that server-side conversation state survives session save/restore.""" + test_session_id = str(uuid4()) + session_manager = 
FileSessionManager(session_id=test_session_id, storage_dir=temp_dir) + try: + model = OpenAIResponsesModel( + model_id="gpt-4o-mini", + stateful=True, + client_args={"api_key": os.getenv("OPENAI_API_KEY")}, + ) + agent = Agent(model=model, system_prompt="Reply in one short sentence.", session_manager=session_manager) + + agent("My name is Alice.") + assert len(agent.messages) == 0 + + # Simulate process restart: create new session manager and agent + session_manager_2 = FileSessionManager(session_id=test_session_id, storage_dir=temp_dir) + model_2 = OpenAIResponsesModel( + model_id="gpt-4o-mini", + stateful=True, + client_args={"api_key": os.getenv("OPENAI_API_KEY")}, + ) + agent_2 = Agent(model=model_2, system_prompt="Reply in one short sentence.", session_manager=session_manager_2) + + assert len(agent_2.messages) == 0 + result = agent_2("What is my name?") + assert "alice" in result.message["content"][0]["text"].lower() + finally: + session_manager.delete_session(test_session_id) From de9b1498950545816f4384f1aa612e4b0d5d4a2d Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 31 Mar 2026 11:10:45 -0400 Subject: [PATCH 202/279] feat: add built-in tool support for OpenAI Responses API (#2011) --- src/strands/models/openai_responses.py | 63 ++++++-- tests/strands/models/test_openai_responses.py | 146 ++++++++++++++++++ tests_integ/models/test_model_openai.py | 92 +++++++++++ 3 files changed, 291 insertions(+), 10 deletions(-) diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index 01974c11d..19bcba80c 100644 --- a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -1,6 +1,19 @@ """OpenAI model provider using the Responses API. -Note: Built-in tools (web search, code interpreter, file search) are not yet supported. +Built-in tools (e.g. 
web_search, file_search, code_interpreter) can be passed via the +``params`` configuration and will be merged with any agent function tools in the request. + +All built-in tools produce text responses that stream correctly. Limitations on tool-specific +metadata: + +- web_search (supported): Full support including URL citations. +- file_search (partial): File citation annotations not emitted (no matching CitationLocation variant). +- code_interpreter (partial): Executed code and stdout/stderr not surfaced. +- mcp (partial): Approval flow and ``mcp_list_tools``/``mcp_call`` events not surfaced. +- shell (partial): Local (client-executed) mode not supported. +- tool_search (not supported): Requires ``defer_loading`` on function tools, which is not supported. +- image_generation (not supported): Requires image content block delta support in the event loop. +- computer_use_preview (not supported): Requires a developer-managed screenshot/action loop. Docs: https://platform.openai.com/docs/api-reference/responses """ @@ -40,6 +53,7 @@ import openai # noqa: E402 - must import after version check +from ..types.citations import WebLocationDict # noqa: E402 from ..types.content import ContentBlock, Messages, Role # noqa: E402 from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException # noqa: E402 from ..types.streaming import StreamEvent # noqa: E402 @@ -103,12 +117,7 @@ def responses(self) -> Any: class OpenAIResponsesModel(Model): - """OpenAI Responses API model provider implementation. - - Note: - This implementation currently only supports function tools (custom tools defined via tool_specs). - OpenAI's built-in system tools are not yet supported. 
- """ + """OpenAI Responses API model provider implementation.""" client: Client client_args: dict[str, Any] @@ -255,6 +264,22 @@ async def stream( {"chunk_type": "content_delta", "data_type": "text", "data": event.delta} ) + elif event.type == "response.output_text.annotation.added": + if hasattr(event, "annotation"): + if event.annotation.get("type") == "url_citation": + yield self._format_chunk( + { + "chunk_type": "content_delta", + "data_type": "citation", + "data": event.annotation, + } + ) + else: + logger.warning( + "annotation_type=<%s> | unsupported annotation type", + event.annotation.get("type"), + ) + elif event.type == "response.output_item.added": # Tool call started if ( @@ -431,7 +456,8 @@ def _format_request( # Add tools if provided if tool_specs: - request["tools"] = [ + # Merge with any built-in tools (e.g. web_search) already in the request from params + request.setdefault("tools", []).extend( { "type": "function", "name": tool_spec["name"], @@ -439,8 +465,7 @@ def _format_request( "parameters": tool_spec["inputSchema"]["json"], } for tool_spec in tool_specs - ] - # Add tool_choice if provided + ) request.update(self._format_request_tool_choice(tool_choice)) return request @@ -550,6 +575,11 @@ def _format_request_message_content(cls, content: ContentBlock, *, role: Role = text_type = "output_text" if role == "assistant" else "input_text" return {"type": text_type, "text": content["text"]} + if "citationsContent" in content: + text = "".join(c["text"] for c in content["citationsContent"].get("content", []) if "text" in c) + text_type = "output_text" if role == "assistant" else "input_text" + return {"type": text_type, "text": text} + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") @classmethod @@ -680,6 +710,19 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: if event["data_type"] == "reasoning_content": return {"contentBlockDelta": {"delta": {"reasoningContent": {"text": event["data"]}}}} + if 
event["data_type"] == "citation": + web_location: WebLocationDict = {"web": {"url": event["data"].get("url", "")}} + return { + "contentBlockDelta": { + "delta": { + "citation": { + "title": event["data"].get("title", ""), + "location": web_location, + } + } + } + } + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} case "content_stop": diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py index 6a16eef89..db4c4b1e1 100644 --- a/tests/strands/models/test_openai_responses.py +++ b/tests/strands/models/test_openai_responses.py @@ -394,6 +394,19 @@ def test_format_request(model, messages, tool_specs, system_prompt): {"chunk_type": "content_delta", "data_type": "reasoning_content", "data": "I'm thinking"}, {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "I'm thinking"}}}}, ), + # Content Delta - Citation + ( + { + "chunk_type": "content_delta", + "data_type": "citation", + "data": {"type": "url_citation", "title": "Example", "url": "https://example.com"}, + }, + { + "contentBlockDelta": { + "delta": {"citation": {"title": "Example", "location": {"web": {"url": "https://example.com"}}}} + } + }, + ), # Content Delta - Text ( {"chunk_type": "content_delta", "data_type": "text", "data": "hello"}, @@ -618,6 +631,74 @@ async def test_stream_reasoning_content(openai_client, model, agenerator, alist) assert len(content_stops) == 2 +@pytest.mark.asyncio +async def test_stream_citation_annotations(openai_client, model, agenerator, alist): + """Test that web search citation annotations are streamed as CitationsDelta events.""" + mock_text_event1 = unittest.mock.Mock(type="response.output_text.delta", delta="The answer is here. 
") + mock_text_event2 = unittest.mock.Mock(type="response.output_text.delta", delta="(example.com)") + mock_annotation_event = unittest.mock.Mock( + type="response.output_text.annotation.added", + annotation={ + "type": "url_citation", + "title": "Example Source", + "url": "https://example.com/article", + }, + ) + mock_complete_event = unittest.mock.Mock( + type="response.completed", + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)), + ) + + openai_client.responses.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_text_event1, mock_text_event2, mock_annotation_event, mock_complete_event]) + ) + + messages = [{"role": "user", "content": [{"text": "search something"}]}] + tru_events = await alist(model.stream(messages)) + + citation_deltas = [ + e for e in tru_events if "contentBlockDelta" in e and "citation" in e["contentBlockDelta"]["delta"] + ] + assert len(citation_deltas) == 1 + assert citation_deltas[0] == { + "contentBlockDelta": { + "delta": { + "citation": { + "title": "Example Source", + "location": {"web": {"url": "https://example.com/article"}}, + } + } + } + } + + +@pytest.mark.asyncio +async def test_stream_unsupported_annotation_type(openai_client, model, agenerator, alist, caplog): + """Test that unsupported annotation types log a warning and are not emitted.""" + mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Some text") + mock_annotation_event = unittest.mock.Mock( + type="response.output_text.annotation.added", + annotation={"type": "file_citation", "file_id": "file-123", "filename": "doc.pdf"}, + ) + mock_complete_event = unittest.mock.Mock( + type="response.completed", + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)), + ) + + openai_client.responses.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_text_event, mock_annotation_event, mock_complete_event]) + ) + + messages 
= [{"role": "user", "content": [{"text": "search files"}]}] + tru_events = await alist(model.stream(messages)) + + citation_deltas = [ + e for e in tru_events if "contentBlockDelta" in e and "citation" in e["contentBlockDelta"]["delta"] + ] + assert len(citation_deltas) == 0 + assert "annotation_type=<file_citation> | unsupported annotation type" in caplog.text + + @pytest.mark.asyncio async def test_structured_output(openai_client, model, test_output_model_cls, alist): messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] @@ -886,6 +967,71 @@ def test_format_request_with_tool_choice(model, messages, tool_specs): assert request["tool_choice"] == {"type": "function", "name": "test_tool"} +def test_format_request_merges_builtin_tools_with_function_tools(messages, tool_specs): + """Test that built-in tools from params are merged with function tools.""" + model = OpenAIResponsesModel( + model_id="gpt-4o", + params={"tools": [{"type": "web_search"}]}, + ) + request = model._format_request(messages, tool_specs) + + assert request["tools"] == [ + {"type": "web_search"}, + { + "type": "function", + "name": "test_tool", + "description": "A test tool", + "parameters": { + "type": "object", + "properties": {"input": {"type": "string"}}, + "required": ["input"], + }, + }, + ] + + +def test_format_request_builtin_tools_without_function_tools(messages): + """Test that built-in tools from params are preserved when no function tools are provided.""" + model = OpenAIResponsesModel( + model_id="gpt-4o", + params={"tools": [{"type": "web_search"}]}, + ) + request = model._format_request(messages) + + assert request["tools"] == [{"type": "web_search"}] + + +def test_format_request_messages_with_citations_content(): + """Test that citationsContent blocks are converted to text in the request.""" + messages = [ + {"role": "user", "content": [{"text": "search something"}]}, + { + "role": "assistant", + "content": [ + { + "citationsContent": { + "citations": [ + { + "title":
"Example", + "location": {"web": {"url": "https://example.com", "domain": "example.com"}}, + "sourceContent": [{"text": "cited text"}], + } + ], + "content": [{"text": "The answer with citations."}], + } + } + ], + }, + ] + formatted = OpenAIResponsesModel._format_request_messages(messages) + + assistant_msg = [m for m in formatted if m.get("role") == "assistant"][0] + assert assistant_msg == { + "role": "assistant", + "content": [{"type": "output_text", "text": "The answer with citations."}], + } + + def test_format_request_message_content_image_size_limit(): """Test that oversized images raise ValueError.""" oversized_data = b"x" * (_MAX_MEDIA_SIZE_BYTES + 1) diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py index 042e9e21c..5a2d21570 100644 --- a/tests_integ/models/test_model_openai.py +++ b/tests_integ/models/test_model_openai.py @@ -1,5 +1,8 @@ import os +import tempfile +import time +import openai as openai_sdk import pydantic import pytest @@ -80,6 +83,31 @@ def lower(_, value): return Color(name="yellow") +@pytest.fixture(scope="module") +def openai_vector_store(): + """Create a vector store with a test file for file_search tests.""" + client = openai_sdk.OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as f: + f.write("The secret code is ALPHA-7742.") + f.flush() + file_obj = client.files.create(file=open(f.name, "rb"), purpose="assistants") + + vector_store = client.vector_stores.create(name="test-builtin-tools") + try: + client.vector_stores.files.create(vector_store_id=vector_store.id, file_id=file_obj.id) + + for _ in range(30): + if client.vector_stores.retrieve(vector_store.id).file_counts.completed > 0: + break + time.sleep(1) + + yield vector_store.id + finally: + client.vector_stores.delete(vector_store.id) + client.files.delete(file_obj.id) + + @pytest.fixture(scope="module") def test_image_path(request): return request.config.rootpath / 
"tests_integ" / "test_image.png" @@ -308,3 +336,67 @@ def test_responses_server_side_conversation(): result = agent("What is my name?") assert "alice" in result.message["content"][0]["text"].lower() + + +@pytest.mark.skipif(not _openai_responses_available, reason="OpenAI Responses API not available") +def test_responses_builtin_tool_web_search(): + """Test that web_search produces text with citation content.""" + model = OpenAIResponsesModel( + model_id="gpt-4o", + params={"tools": [{"type": "web_search"}]}, + client_args={"api_key": os.getenv("OPENAI_API_KEY")}, + ) + agent = Agent(model=model, system_prompt="Answer concisely.", callback_handler=None) + + result = agent("Search https://strandsagents.com/ and tell me what Strands Agents is.") + content = result.message["content"][0] + + assert "citationsContent" in content + citations = content["citationsContent"]["citations"] + assert any("strandsagents.com" in c["location"]["web"]["url"] for c in citations) + + +@pytest.mark.skipif(not _openai_responses_available, reason="OpenAI Responses API not available") +def test_responses_builtin_tool_file_search(openai_vector_store): + """Test that file_search produces text output from uploaded files.""" + model = OpenAIResponsesModel( + model_id="gpt-4o", + params={"tools": [{"type": "file_search", "vector_store_ids": [openai_vector_store]}]}, + client_args={"api_key": os.getenv("OPENAI_API_KEY")}, + ) + agent = Agent(model=model, system_prompt="Answer based on the files.", callback_handler=None) + + result = agent("What is the secret code?") + text = result.message["content"][0]["text"] + assert "ALPHA-7742" in text + + +@pytest.mark.skipif(not _openai_responses_available, reason="OpenAI Responses API not available") +def test_responses_builtin_tool_code_interpreter(): + """Test that code_interpreter produces correct results via text output.""" + model = OpenAIResponsesModel( + model_id="gpt-4o", + params={"tools": [{"type": "code_interpreter", "container": {"type": 
"auto"}}]}, + client_args={"api_key": os.getenv("OPENAI_API_KEY")}, + ) + agent = Agent(model=model, system_prompt="Answer concisely.", callback_handler=None) + + # SHA-256 of "strands" requires actual computation + result = agent("Compute the SHA-256 hash of the string 'strands'. Return only the hex digest.") + text = result.message["content"][0]["text"] + assert "11e0e34bd35e12185cfacd5e5a256ab4292bfa3616d8d5b74e20eca36feed228" in text + + +@pytest.mark.skipif(not _openai_responses_available, reason="OpenAI Responses API not available") +def test_responses_builtin_tool_shell(): + """Test that the shell built-in tool executes commands in a hosted container.""" + model = OpenAIResponsesModel( + model_id="gpt-5.4-mini", + params={"tools": [{"type": "shell", "environment": {"type": "container_auto"}}]}, + client_args={"api_key": os.getenv("OPENAI_API_KEY")}, + ) + agent = Agent(model=model, system_prompt="Answer concisely.", callback_handler=None) + + result = agent("Use the shell to compute the md5sum of the string 'strands-test'. 
Return only the hash.") + text = result.message["content"][0]["text"] + assert "d82f373f079b00a1db7ef1eec7f15c68" in text From e267a64b8a664bbda5802005281551e5dd6af412 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 31 Mar 2026 12:50:05 -0400 Subject: [PATCH 203/279] fix: handle reasoning content in OpenAIResponsesModel request formatting (#2013) --- src/strands/models/openai_responses.py | 16 +++++-- tests/strands/models/test_openai_responses.py | 44 +++++++++++++++++-- tests_integ/models/test_model_mantle.py | 25 +++++++++++ 3 files changed, 79 insertions(+), 6 deletions(-) diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index 19bcba80c..30e4e2fa1 100644 --- a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -240,8 +240,13 @@ async def stream( if model_state is not None and response_id: model_state["response_id"] = response_id - elif event.type == "response.reasoning_text.delta": - # Reasoning content streaming (for o1/o3 reasoning models) + elif event.type in ( + "response.reasoning_text.delta", + "response.reasoning_summary_text.delta", + ): + # Reasoning content streaming: + # - reasoning_text: full chain-of-thought (gpt-oss models) + # - reasoning_summary_text: condensed summary (o-series models) chunks, data_type = self._stream_switch_content("reasoning_content", data_type) for chunk in chunks: yield chunk @@ -510,10 +515,15 @@ def _format_request_messages(cls, messages: Messages) -> list[dict[str, Any]]: role = message["role"] contents = message["content"] + if any("reasoningContent" in content for content in contents): + logger.warning( + "reasoningContent is not yet supported in multi-turn conversations with the Responses API" + ) + formatted_contents = [ cls._format_request_message_content(content, role=role) for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse"]) + if not any(block_type in content for block_type in 
["toolResult", "toolUse", "reasoningContent"]) ] formatted_tool_calls = [ diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py index db4c4b1e1..ef31cc1e6 100644 --- a/tests/strands/models/test_openai_responses.py +++ b/tests/strands/models/test_openai_responses.py @@ -596,9 +596,16 @@ async def test_stream_response_incomplete(openai_client, model, agenerator, alis @pytest.mark.asyncio -async def test_stream_reasoning_content(openai_client, model, agenerator, alist): - """Test that reasoning content (o1/o3 models) is streamed correctly.""" - mock_reasoning_event = unittest.mock.Mock(type="response.reasoning_text.delta", delta="Let me think...") +@pytest.mark.parametrize( + "event_type", + [ + "response.reasoning_text.delta", + "response.reasoning_summary_text.delta", + ], +) +async def test_stream_reasoning_content(openai_client, model, agenerator, alist, event_type): + """Test that reasoning content is streamed correctly for both full and summary reasoning events.""" + mock_reasoning_event = unittest.mock.Mock(type=event_type, delta="Let me think...") mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="The answer is 42") mock_complete_event = unittest.mock.Mock( type="response.completed", @@ -1152,3 +1159,34 @@ async def test_stream_stateful(openai_client, model_id, agenerator, alist): "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}, "metrics": {"latencyMs": 0}, } + + +def test_format_request_messages_excludes_reasoning_content(caplog): + """Test that reasoningContent blocks are filtered from messages with a warning.""" + messages = [ + { + "content": [{"text": "Hello"}], + "role": "user", + }, + { + "content": [ + {"reasoningContent": {"reasoningText": {"text": "Let me think..."}}}, + {"text": "The answer is 42"}, + ], + "role": "assistant", + }, + { + "content": [{"text": "Thanks"}], + "role": "user", + }, + ] + + with caplog.at_level("WARNING"): + result = 
OpenAIResponsesModel._format_request_messages(messages) + + assert result == [ + {"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}, + {"role": "assistant", "content": [{"type": "output_text", "text": "The answer is 42"}]}, + {"role": "user", "content": [{"type": "input_text", "text": "Thanks"}]}, + ] + assert "reasoningContent is not yet supported" in caplog.text diff --git a/tests_integ/models/test_model_mantle.py b/tests_integ/models/test_model_mantle.py index 55c445676..b0482acd1 100644 --- a/tests_integ/models/test_model_mantle.py +++ b/tests_integ/models/test_model_mantle.py @@ -72,3 +72,28 @@ def test_responses_server_side_conversation(stateful_model): result = agent("What is my name?") assert "alice" in str(result).lower() + + +def test_reasoning_content_multi_turn(client_args): + """Test that reasoning content from gpt-oss models doesn't break multi-turn conversations.""" + model = OpenAIResponsesModel( + model_id="openai.gpt-oss-120b", + client_args=client_args, + params={"reasoning": {"effort": "low"}}, + ) + agent = Agent(model=model, system_prompt="Reply in one short sentence.", callback_handler=None) + + result1 = agent("What is 2+2?") + assert "4" in str(result1) + + # Verify reasoning content was produced + has_reasoning = any( + "reasoningContent" in block + for msg in agent.messages + if msg["role"] == "assistant" + for block in msg["content"] + ) + assert has_reasoning + + # Second turn should not raise despite reasoningContent in message history + agent("What about 3+3?") From 7b4df8aef611ef49cd8e87557c8002b1ff2f6555 Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Tue, 31 Mar 2026 15:15:11 -0400 Subject: [PATCH 204/279] fix: fix type imcompatible (#2018) --- src/strands/telemetry/tracer.py | 4 +++- tests_integ/models/test_model_mantle.py | 5 +---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py 
index c03d9d962..85083722e 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -839,7 +839,9 @@ def _add_system_prompt_event( if system_prompt is None and system_prompt_content is None: return - content_blocks = system_prompt_content if system_prompt_content else [{"text": system_prompt}] + content_blocks: list[ContentBlock] = ( + system_prompt_content if system_prompt_content else [{"text": system_prompt or ""}] + ) if self.use_latest_genai_conventions: parts = self._map_content_blocks_to_otel_parts(content_blocks) diff --git a/tests_integ/models/test_model_mantle.py b/tests_integ/models/test_model_mantle.py index b0482acd1..1dc029344 100644 --- a/tests_integ/models/test_model_mantle.py +++ b/tests_integ/models/test_model_mantle.py @@ -88,10 +88,7 @@ def test_reasoning_content_multi_turn(client_args): # Verify reasoning content was produced has_reasoning = any( - "reasoningContent" in block - for msg in agent.messages - if msg["role"] == "assistant" - for block in msg["content"] + "reasoningContent" in block for msg in agent.messages if msg["role"] == "assistant" for block in msg["content"] ) assert has_reasoning From 53917945a137c5e61e0b9b62ed079c90280bf014 Mon Sep 17 00:00:00 2001 From: Liz <91279165+lizradway@users.noreply.github.com> Date: Tue, 31 Mar 2026 16:58:44 -0400 Subject: [PATCH 205/279] fix: isolate langfuse env vars (#2022) --- tests/conftest.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index f2a8909cb..1c0083e85 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,7 +26,9 @@ def moto_env(monkeypatch): monkeypatch.setenv("AWS_SECURITY_TOKEN", "test") monkeypatch.setenv("AWS_DEFAULT_REGION", "us-west-2") monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) + monkeypatch.delenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", raising=False) monkeypatch.delenv("OTEL_EXPORTER_OTLP_HEADERS", raising=False) + monkeypatch.delenv("LANGFUSE_BASE_URL", raising=False) 
@pytest.fixture From cda2a55a8eb0c5e681bfda9764fd47048a67f6b9 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Wed, 1 Apr 2026 12:27:03 -0400 Subject: [PATCH 206/279] fix: restore explicit span.end() to fix span end_time regression (#2032) Co-authored-by: Di-Is Co-authored-by: Mackenzie Zastrow --- src/strands/event_loop/event_loop.py | 132 ++++++++++-------- src/strands/telemetry/tracer.py | 34 ++--- tests/strands/event_loop/test_event_loop.py | 60 ++++++++ .../test_event_loop_structured_output.py | 85 +++++++++++ tests/strands/telemetry/test_tracer.py | 34 +++++ 5 files changed, 264 insertions(+), 81 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 374cfe129..b4af16058 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -139,25 +139,29 @@ async def event_loop_cycle( ) invocation_state["event_loop_cycle_span"] = cycle_span - with trace_api.use_span(cycle_span, end_on_exit=True): - # Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls. - if agent._interrupt_state.activated: - stop_reason: StopReason = "tool_use" - message = agent._interrupt_state.context["tool_use_message"] - # Skip model invocation if the latest message contains ToolUse - elif _has_tool_use_in_latest_message(agent.messages): - stop_reason = "tool_use" - message = agent.messages[-1] - else: - model_events = _handle_model_execution( - agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context - ) - async for model_event in model_events: - if not isinstance(model_event, ModelStopReason): - yield model_event + with trace_api.use_span(cycle_span, end_on_exit=False): + try: + # Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls. 
+ if agent._interrupt_state.activated: + stop_reason: StopReason = "tool_use" + message = agent._interrupt_state.context["tool_use_message"] + # Skip model invocation if the latest message contains ToolUse + elif _has_tool_use_in_latest_message(agent.messages): + stop_reason = "tool_use" + message = agent.messages[-1] + else: + model_events = _handle_model_execution( + agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context + ) + async for model_event in model_events: + if not isinstance(model_event, ModelStopReason): + yield model_event - stop_reason, message, *_ = model_event["stop"] - yield ModelMessageEvent(message=message) + stop_reason, message, *_ = model_event["stop"] + yield ModelMessageEvent(message=message) + except Exception as e: + tracer.end_span_with_error(cycle_span, str(e), e) + raise try: if stop_reason == "max_tokens": @@ -196,42 +200,45 @@ async def event_loop_cycle( # End the cycle and return results agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) - # Set attributes before span auto-closes + + # Force structured output tool call if LLM didn't use it automatically + if structured_output_context.is_enabled and stop_reason == "end_turn": + if structured_output_context.force_attempted: + raise StructuredOutputException( + "The model failed to invoke the structured output tool even after it was forced." 
+ ) + structured_output_context.set_forced_mode() + logger.debug("Forcing structured output tool") + await agent._append_messages( + {"role": "user", "content": [{"text": structured_output_context.structured_output_prompt}]} + ) + + tracer.end_event_loop_cycle_span(cycle_span, message) + events = recurse_event_loop( + agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context + ) + async for typed_event in events: + yield typed_event + return + tracer.end_event_loop_cycle_span(cycle_span, message) - except EventLoopException: - # Don't yield or log the exception - we already did it when we - # raised the exception and we don't need that duplication. + yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) + except ( + StructuredOutputException, + EventLoopException, + ContextWindowOverflowException, + MaxTokensReachedException, + ) as e: + # These exceptions should bubble up directly rather than get wrapped in an EventLoopException + tracer.end_span_with_error(cycle_span, str(e), e) raise - except (ContextWindowOverflowException, MaxTokensReachedException) as e: - # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException - raise e except Exception as e: + tracer.end_span_with_error(cycle_span, str(e), e) # Handle any other exceptions yield ForceStopEvent(reason=e) logger.exception("cycle failed") raise EventLoopException(e, invocation_state["request_state"]) from e - # Force structured output tool call if LLM didn't use it automatically - if structured_output_context.is_enabled and stop_reason == "end_turn": - if structured_output_context.force_attempted: - raise StructuredOutputException( - "The model failed to invoke the structured output tool even after it was forced." 
- ) - structured_output_context.set_forced_mode() - logger.debug("Forcing structured output tool") - await agent._append_messages( - {"role": "user", "content": [{"text": structured_output_context.structured_output_prompt}]} - ) - - events = recurse_event_loop( - agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context - ) - async for typed_event in events: - yield typed_event - return - - yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) - async def recurse_event_loop( agent: "Agent", @@ -316,20 +323,21 @@ async def _handle_model_execution( system_prompt=agent.system_prompt, system_prompt_content=agent._system_prompt_content, ) - with trace_api.use_span(model_invoke_span, end_on_exit=True): - await agent.hooks.invoke_callbacks_async( - BeforeModelCallEvent( - agent=agent, - invocation_state=invocation_state, + with trace_api.use_span(model_invoke_span, end_on_exit=False): + try: + await agent.hooks.invoke_callbacks_async( + BeforeModelCallEvent( + agent=agent, + invocation_state=invocation_state, + ) ) - ) - if structured_output_context.forced_mode: - tool_spec = structured_output_context.get_tool_spec() - tool_specs = [tool_spec] if tool_spec else [] - else: - tool_specs = agent.tool_registry.get_all_tool_specs() - try: + if structured_output_context.forced_mode: + tool_spec = structured_output_context.get_tool_spec() + tool_specs = [tool_spec] if tool_spec else [] + else: + tool_specs = agent.tool_registry.get_all_tool_specs() + async for event in stream_messages( agent.model, agent.system_prompt, @@ -363,17 +371,17 @@ async def _handle_model_execution( "stop_reason=<%s>, retry_requested= | hook requested model retry", stop_reason, ) + tracer.end_model_invoke_span(model_invoke_span, message, usage, metrics, stop_reason) continue # Retry the model call if stop_reason == "max_tokens": message = recover_message_on_max_tokens_reached(message) - # Set attributes 
before span auto-closes tracer.end_model_invoke_span(model_invoke_span, message, usage, metrics, stop_reason) break # Success! Break out of retry loop except Exception as e: - # Exception is automatically recorded by use_span with end_on_exit=True + tracer.end_span_with_error(model_invoke_span, str(e), e) after_model_call_event = AfterModelCallEvent( agent=agent, invocation_state=invocation_state, @@ -541,7 +549,7 @@ async def _handle_tool_execution( interrupts, structured_output=structured_output_result, ) - # Set attributes before span auto-closes (span is managed by use_span in event_loop_cycle) + # End the cycle span before yielding the recursive cycle. if cycle_span: tracer.end_event_loop_cycle_span(span=cycle_span, message=message) @@ -559,7 +567,7 @@ async def _handle_tool_execution( yield ToolResultMessageEvent(message=tool_result_message) - # Set attributes before span auto-closes (span is managed by use_span in event_loop_cycle) + # End the cycle span before yielding the recursive cycle. if cycle_span: tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 85083722e..19a163f5c 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -185,6 +185,7 @@ def _end_span( span: Span, attributes: dict[str, AttributeValue] | None = None, error: Exception | None = None, + error_message: str | None = None, ) -> None: """Generic helper method to end a span. 
@@ -192,8 +193,9 @@ def _end_span( span: The span to end attributes: Optional attributes to set before ending the span error: Optional exception if an error occurred + error_message: Optional error message to set in the span status """ - if not span: + if not span or not span.is_recording(): return try: @@ -206,7 +208,8 @@ def _end_span( # Handle error if present if error: - span.set_status(StatusCode.ERROR, str(error)) + status_description = error_message or str(error) or type(error).__name__ + span.set_status(StatusCode.ERROR, status_description) span.record_exception(error) else: span.set_status(StatusCode.OK) @@ -229,11 +232,11 @@ def end_span_with_error(self, span: Span, error_message: str, exception: Excepti error_message: Error message to set in the span status. exception: Optional exception to record in the span. """ - if not span: + if not span or not span.is_recording(): return error = exception or Exception(error_message) - self._end_span(span, error=error) + self._end_span(span, error=error, error_message=error_message) def _add_event( self, span: Span | None, event_name: str, event_attributes: Attributes, to_span_attributes: bool = False @@ -330,18 +333,15 @@ def end_model_invoke_span( ) -> None: """End a model invocation span with results and metrics. - Note: The span is automatically closed and exceptions recorded. This method just sets the necessary attributes. - Status in the span is automatically set to UNSET (OK) on success or ERROR on exception. - Args: - span: The span to set attributes on. + span: The span to end. message: The message response from the model. usage: Token usage information from the model call. metrics: Metrics from the model call. stop_reason: The reason the model stopped generating. 
""" - # Set end time attribute - span.set_attribute("gen_ai.event.end_time", datetime.now(timezone.utc).isoformat()) + if not span or not span.is_recording(): + return attributes: dict[str, AttributeValue] = { "gen_ai.usage.prompt_tokens": usage["inputTokens"], @@ -378,7 +378,7 @@ def end_model_invoke_span( event_attributes={"finish_reason": str(stop_reason), "message": serialize(message["content"])}, ) - span.set_attributes(attributes) + self._end_span(span, attributes) def start_tool_call_span( self, @@ -553,20 +553,14 @@ def end_event_loop_cycle_span( ) -> None: """End an event loop cycle span with results. - Note: The span is automatically closed and exceptions recorded. This method just sets the necessary attributes. - Status in the span is automatically set to UNSET (OK) on success or ERROR on exception. - Args: - span: The span to set attributes on. + span: The span to end. message: The message response from this cycle. tool_result_message: Optional tool result message if a tool was called. 
""" - if not span: + if not span or not span.is_recording(): return - # Set end time attribute - span.set_attribute("gen_ai.event.end_time", datetime.now(timezone.utc).isoformat()) - event_attributes: dict[str, AttributeValue] = {"message": serialize(message["content"])} if tool_result_message: @@ -591,6 +585,8 @@ def end_event_loop_cycle_span( else: self._add_event(span, "gen_ai.choice", event_attributes=event_attributes) + self._end_span(span) + def start_agent_span( self, messages: Messages, diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 2f57f5560..f91f7c2af 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -5,6 +5,9 @@ from unittest.mock import ANY, AsyncMock, MagicMock, call, patch import pytest +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter import strands import strands.telemetry @@ -19,6 +22,7 @@ ) from strands.interrupt import Interrupt, _InterruptState from strands.telemetry.metrics import EventLoopMetrics +from strands.telemetry.tracer import Tracer from strands.tools.executors import SequentialToolExecutor from strands.tools.registry import ToolRegistry from strands.types._events import EventLoopStopEvent @@ -583,6 +587,14 @@ async def test_event_loop_tracing_with_model_error( ) await alist(stream) + assert mock_tracer.end_span_with_error.call_count == 2 + mock_tracer.end_span_with_error.assert_has_calls( + [ + call(model_span, "Input too long", model.stream.side_effect), + call(cycle_span, "Input too long", model.stream.side_effect), + ] + ) + @pytest.mark.asyncio async def test_event_loop_cycle_max_tokens_exception( @@ -673,6 +685,53 @@ async def test_event_loop_tracing_with_tool_execution( assert mock_tracer.end_model_invoke_span.call_count == 2 +@pytest.mark.asyncio +async 
def test_event_loop_cycle_closes_cycle_span_before_recursive_cycle( + agent, + model, + tool_stream, + agenerator, + alist, +): + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + + tracer = Tracer() + tracer.tracer_provider = provider + tracer.tracer = provider.get_tracer(tracer.service_name) + + async def delayed_text_stream(): + yield {"contentBlockDelta": {"delta": {"text": "test text"}}} + await asyncio.sleep(0.05) + yield {"contentBlockStop": {}} + + agent.trace_span = None + agent._system_prompt_content = None + model.config = {"model_id": "test-model"} + model.stream.side_effect = [ + agenerator(tool_stream), + delayed_text_stream(), + ] + + with patch("strands.event_loop.event_loop.get_tracer", return_value=tracer): + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await alist(stream) + + provider.force_flush() + cycle_spans = sorted( + [span for span in exporter.get_finished_spans() if span.name == "execute_event_loop_cycle"], + key=lambda span: span.start_time, + ) + + assert len(cycle_spans) == 2 + assert cycle_spans[0].end_time <= cycle_spans[1].start_time + assert cycle_spans[0].end_time < cycle_spans[1].end_time + + @patch("strands.event_loop.event_loop.get_tracer") @pytest.mark.asyncio async def test_event_loop_tracing_with_throttling_exception( @@ -709,6 +768,7 @@ async def test_event_loop_tracing_with_throttling_exception( ) await alist(stream) + assert mock_tracer.end_span_with_error.call_count == 1 # Verify span was created for the successful retry assert mock_tracer.start_model_invoke_span.call_count == 2 assert mock_tracer.end_model_invoke_span.call_count == 1 diff --git a/tests/strands/event_loop/test_event_loop_structured_output.py b/tests/strands/event_loop/test_event_loop_structured_output.py index ad792f52c..2d1150712 100644 --- a/tests/strands/event_loop/test_event_loop_structured_output.py +++ 
b/tests/strands/event_loop/test_event_loop_structured_output.py @@ -4,16 +4,22 @@ from unittest.mock import AsyncMock, Mock, patch import pytest +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry.trace import StatusCode from pydantic import BaseModel from strands.event_loop.event_loop import event_loop_cycle, recurse_event_loop from strands.telemetry.metrics import EventLoopMetrics +from strands.telemetry.tracer import Tracer from strands.tools.registry import ToolRegistry from strands.tools.structured_output._structured_output_context import ( DEFAULT_STRUCTURED_OUTPUT_PROMPT, StructuredOutputContext, ) from strands.types._events import EventLoopStopEvent, StructuredOutputEvent +from strands.types.exceptions import EventLoopException, StructuredOutputException class UserModel(BaseModel): @@ -253,6 +259,85 @@ async def test_event_loop_forces_structured_output_with_custom_prompt(mock_agent assert args["content"][0]["text"] == custom_prompt +@patch("strands.event_loop.event_loop.get_tracer") +@pytest.mark.asyncio +async def test_event_loop_structured_output_failure_closes_cycle_span_with_error( + mock_get_tracer, + mock_agent, + structured_output_context, + agenerator, + alist, +): + mock_tracer = Mock() + cycle_span = Mock() + model_span = Mock() + mock_tracer.start_event_loop_cycle_span.return_value = cycle_span + mock_tracer.start_model_invoke_span.return_value = model_span + mock_get_tracer.return_value = mock_tracer + + structured_output_context.set_forced_mode() + mock_agent.model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "Still not structured"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ) + + expected_message = "The model failed to invoke the structured output tool even after it was forced." 
+ with pytest.raises(StructuredOutputException, match=expected_message): + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + await alist(stream) + + mock_tracer.end_model_invoke_span.assert_called_once() + mock_tracer.end_event_loop_cycle_span.assert_not_called() + mock_tracer.end_span_with_error.assert_called_once() + assert mock_tracer.end_span_with_error.call_args.args[0] == cycle_span + assert mock_tracer.end_span_with_error.call_args.args[1] == expected_message + assert isinstance(mock_tracer.end_span_with_error.call_args.args[2], StructuredOutputException) + + +@pytest.mark.asyncio +async def test_event_loop_forced_structured_output_append_failure_records_error_span( + mock_agent, structured_output_context, agenerator, alist +): + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + + tracer = Tracer() + tracer.tracer_provider = provider + tracer.tracer = provider.get_tracer(tracer.service_name) + + mock_agent.model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "Here is the user info"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ) + + mock_agent._append_messages = AsyncMock(side_effect=RuntimeError("append failed")) + + with patch("strands.event_loop.event_loop.get_tracer", return_value=tracer): + with pytest.raises(EventLoopException, match="append failed"): + stream = event_loop_cycle( + agent=mock_agent, + invocation_state={}, + structured_output_context=structured_output_context, + ) + await alist(stream) + + finished_cycle_spans = [span for span in exporter.get_finished_spans() if span.name == "execute_event_loop_cycle"] + + assert len(finished_cycle_spans) == 1 + assert finished_cycle_spans[0].status.status_code == StatusCode.ERROR + + @pytest.mark.asyncio async def test_structured_output_tool_execution_extracts_result( 
mock_agent, structured_output_context, agenerator, alist diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 9176ce4ae..bcd42b610 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -128,6 +128,30 @@ def test_end_span_with_error_message(mock_span): mock_span.end.assert_called_once() +def test_end_span_with_empty_exception_message_uses_exception_name(mock_span): + """Test that empty exception messages fall back to the exception type name.""" + tracer = Tracer() + error = Exception() + + tracer.end_span_with_error(mock_span, "", error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + +def test_end_span_with_error_prefers_explicit_message(mock_span): + """Test that an explicit error message takes precedence over the exception text.""" + tracer = Tracer() + error = Exception() + + tracer.end_span_with_error(mock_span, "Explicit error message", error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Explicit error message") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + def test_start_model_invoke_span(mock_tracer): """Test starting a model invoke span.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): @@ -321,6 +345,8 @@ def test_end_model_invoke_span(mock_span): "gen_ai.choice", attributes={"message": json.dumps(message["content"]), "finish_reason": "end_turn"}, ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() def test_end_model_invoke_span_latest_conventions(mock_span, monkeypatch): @@ -360,6 +386,8 @@ def test_end_model_invoke_span_latest_conventions(mock_span, monkeypatch): ), }, ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + 
mock_span.end.assert_called_once() def test_start_tool_call_span(mock_tracer): @@ -760,6 +788,8 @@ def test_end_event_loop_cycle_span(mock_span): "tool.result": json.dumps(tool_result_message["content"]), }, ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() def test_end_event_loop_cycle_span_latest_conventions(mock_span, monkeypatch): @@ -795,6 +825,8 @@ def test_end_event_loop_cycle_span_latest_conventions(mock_span, monkeypatch): ) }, ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() def test_start_agent_span(mock_tracer): @@ -1028,6 +1060,8 @@ def test_end_model_invoke_span_with_cache_metrics(mock_span): "gen_ai.server.time_to_first_token": 5, } ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() def test_end_agent_span_with_cache_metrics(mock_span): From 635edbc60b72c5aa3df7e92bdfd8d1aec7b5223c Mon Sep 17 00:00:00 2001 From: Liz <91279165+lizradway@users.noreply.github.com> Date: Wed, 1 Apr 2026 12:41:03 -0400 Subject: [PATCH 207/279] feat(context): surface context window size from LLM response metadata Add latest_context_size to EventLoopMetrics and context_size to AgentResult, exposing the inputTokens count from the most recent LLM call as a measure of current context window usage. 
--- src/strands/agent/agent_result.py | 9 +++++ src/strands/telemetry/metrics.py | 13 +++++++ tests/strands/agent/test_agent_result.py | 14 +++++++ tests/strands/telemetry/test_metrics.py | 47 ++++++++++++++++++++++++ 4 files changed, 83 insertions(+) diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index 63b7a0d4a..f0a399f81 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -35,6 +35,15 @@ class AgentResult: interrupts: Sequence[Interrupt] | None = None structured_output: BaseModel | None = None + @property + def context_size(self) -> int | None: + """Most recent context size in tokens from the last LLM call. + + Returns: + The input token count from the most recent cycle, or None if no data is available. + """ + return self.metrics.latest_context_size + def __str__(self) -> str: """Return a string representation of the agent result. diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py index 163df803a..dae05965e 100644 --- a/src/strands/telemetry/metrics.py +++ b/src/strands/telemetry/metrics.py @@ -202,6 +202,19 @@ class EventLoopMetrics: accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0)) accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) + @property + def latest_context_size(self) -> int | None: + """Most recent context size from the last LLM call. + + This represents the current context size as reported by the model. + + Returns: + The input token count from the most recent cycle, or None if no data is available. 
+ """ + if self.agent_invocations and self.agent_invocations[-1].cycles: + return self.agent_invocations[-1].cycles[-1].usage.get("inputTokens") + return None + @property def _metrics_client(self) -> "MetricsClient": """Get the singleton MetricsClient instance.""" diff --git a/tests/strands/agent/test_agent_result.py b/tests/strands/agent/test_agent_result.py index a4478c3ca..64391f299 100644 --- a/tests/strands/agent/test_agent_result.py +++ b/tests/strands/agent/test_agent_result.py @@ -370,3 +370,17 @@ def test__str__empty_interrupts_returns_agent_message(mock_metrics, simple_messa # Empty list is falsy, should fall through to text content assert message_string == "Hello world!\n" + + +def test_context_size_delegates_to_metrics(mock_metrics, simple_message: Message): + """Test that context_size delegates to metrics.latest_context_size.""" + mock_metrics.latest_context_size = 12345 + result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={}) + assert result.context_size == 12345 + + +def test_context_size_none_when_no_data(mock_metrics, simple_message: Message): + """Test that context_size returns None when metrics has no data.""" + mock_metrics.latest_context_size = None + result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={}) + assert result.context_size is None diff --git a/tests/strands/telemetry/test_metrics.py b/tests/strands/telemetry/test_metrics.py index 800bcebc4..c38fa6a18 100644 --- a/tests/strands/telemetry/test_metrics.py +++ b/tests/strands/telemetry/test_metrics.py @@ -566,3 +566,50 @@ def test_reset_usage_metrics(usage, event_loop_metrics, mock_get_meter_provider) # Verify accumulated_usage is NOT cleared assert event_loop_metrics.accumulated_usage["inputTokens"] == 11 + + +def test_latest_context_size_no_invocations(event_loop_metrics): + assert event_loop_metrics.latest_context_size is None + + +def 
test_latest_context_size_invocation_with_no_cycles(event_loop_metrics): + event_loop_metrics.reset_usage_metrics() + assert event_loop_metrics.latest_context_size is None + + +def test_latest_context_size_returns_last_cycle(event_loop_metrics, mock_get_meter_provider): + event_loop_metrics.reset_usage_metrics() + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "c1"}) + event_loop_metrics.update_usage(Usage(inputTokens=100, outputTokens=50, totalTokens=150)) + + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "c2"}) + event_loop_metrics.update_usage(Usage(inputTokens=250, outputTokens=80, totalTokens=330)) + + assert event_loop_metrics.latest_context_size == 250 + + +def test_latest_context_size_returns_from_latest_invocation(event_loop_metrics, mock_get_meter_provider): + # First invocation + event_loop_metrics.reset_usage_metrics() + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "c1"}) + event_loop_metrics.update_usage(Usage(inputTokens=100, outputTokens=50, totalTokens=150)) + + # Second invocation + event_loop_metrics.reset_usage_metrics() + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "c2"}) + event_loop_metrics.update_usage(Usage(inputTokens=500, outputTokens=80, totalTokens=580)) + + assert event_loop_metrics.latest_context_size == 500 + + +def test_latest_context_size_missing_input_tokens_key(event_loop_metrics): + """Returns None when usage dict is missing inputTokens (e.g. 
provider bug).""" + event_loop_metrics.reset_usage_metrics() + invocation = event_loop_metrics.agent_invocations[-1] + invocation.cycles.append( + strands.telemetry.metrics.EventLoopCycleMetric( + event_loop_cycle_id="c1", + usage={"outputTokens": 50, "totalTokens": 50}, + ) + ) + assert event_loop_metrics.latest_context_size is None From 94fc8dd3ac6910e3608713b7cc40f7cfa65f31e8 Mon Sep 17 00:00:00 2001 From: BHUKYA VENKATESH Date: Fri, 3 Apr 2026 02:31:02 +0530 Subject: [PATCH 208/279] feat: add service_tier support to BedrockModel (#1799) Co-authored-by: BV-Venky --- src/strands/models/bedrock.py | 6 ++++++ tests/strands/models/test_bedrock.py | 14 ++++++++++++++ tests_integ/models/test_model_bedrock.py | 21 +++++++++++++++++++++ 3 files changed, 41 insertions(+) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 5de34a6c2..5b7a2f34e 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -93,6 +93,10 @@ class BedrockConfig(TypedDict, total=False): model_id: The Bedrock model ID (e.g., "us.anthropic.claude-sonnet-4-20250514-v1:0") include_tool_result_status: Flag to include status field in tool results. True includes status, False removes status, "auto" determines based on model_id. Defaults to "auto". + service_tier: Service tier for the request, controlling the trade-off between latency and cost. + Valid values: "default" (standard), "priority" (faster, premium), "flex" (cheaper, slower). + Please check https://docs.aws.amazon.com/bedrock/latest/userguide/service-tiers-inference.html for + supported service tiers, models, and regions stop_sequences: List of sequences that will stop generation when encountered streaming: Flag to enable/disable streaming. Defaults to True. 
temperature: Controls randomness in generation (higher = more random) @@ -117,6 +121,7 @@ class BedrockConfig(TypedDict, total=False): max_tokens: int | None model_id: str include_tool_result_status: Literal["auto"] | bool | None + service_tier: str | None stop_sequences: list[str] | None streaming: bool | None temperature: float | None @@ -245,6 +250,7 @@ def _format_request( "modelId": self.config["model_id"], "messages": self._format_bedrock_messages(messages), "system": system_blocks, + **({"serviceTier": {"type": self.config["service_tier"]}} if self.config.get("service_tier") else {}), **( { "toolConfig": { diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 5f81efd24..9c565d4f4 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -379,6 +379,20 @@ def test_format_request_guardrail_config_without_trace_or_stream_processing_mode assert tru_request == exp_request +def test_format_request_with_service_tier(model, messages, model_id): + model.update_config(service_tier="flex") + tru_request = model._format_request(messages) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "serviceTier": {"type": "flex"}, + "system": [], + } + + assert tru_request == exp_request + + def test_format_request_inference_config(model, messages, model_id, inference_config): model.update_config(**inference_config) tru_request = model._format_request(messages) diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index 0b3aa7b47..e4ef727ce 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -73,6 +73,27 @@ def test_non_streaming_agent(non_streaming_agent): assert len(str(result)) > 0 +def test_bedrock_service_tier_flex_invocation_succeeds(): + """Bedrock accepts serviceTier when model and region support Priority/Flex tiers. 
+ + Tier support is model- and region-specific. See: + https://docs.aws.amazon.com/bedrock/latest/userguide/service-tiers-inference.html + + CI runs integ tests with AWS_REGION=us-east-1; amazon.nova-pro-v1:0 is listed for + that region under Priority and Flex tiers. + """ + model = BedrockModel( + model_id="amazon.nova-pro-v1:0", + region_name="us-east-1", + service_tier="flex", + ) + agent = Agent(model=model, load_tools_from_directory=False) + result = agent("Reply with exactly the word: ok") + + assert result.stop_reason == "end_turn" + assert len(str(result).strip()) > 0 + + @pytest.mark.asyncio async def test_streaming_model_events(streaming_model, alist): """Test streaming model events.""" From 1682a0cd0e8c416baf2c541581d46e3c6f0d7161 Mon Sep 17 00:00:00 2001 From: Manan Patel <43514923+mananpatel320@users.noreply.github.com> Date: Tue, 7 Apr 2026 00:44:05 +0530 Subject: [PATCH 209/279] =?UTF-8?q?fix:=20forward=20=5Fmeta=20to=20MCP=20t?= =?UTF-8?q?ool=20calls=20and=20fix=20model=5Fdump=20alias=20seriali?= =?UTF-8?q?=E2=80=A6=20(#1918)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Manan Patel Co-authored-by: Nicholas Clegg --- src/strands/tools/mcp/mcp_client.py | 12 +++- src/strands/tools/mcp/mcp_instrumentation.py | 5 +- tests/strands/tools/mcp/test_mcp_client.py | 63 ++++++++++++++--- .../tools/mcp/test_mcp_instrumentation.py | 30 +++++++- tests_integ/mcp/echo_server.py | 8 +++ tests_integ/mcp/test_mcp_client.py | 68 +++++++++++++++++++ 6 files changed, 170 insertions(+), 16 deletions(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 1fd2990ec..7574f4b65 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -568,6 +568,7 @@ def _create_call_tool_coroutine( name: str, arguments: dict[str, Any] | None, read_timeout_seconds: timedelta | None, + meta: dict[str, Any] | None = None, ) -> Coroutine[Any, Any, 
MCPCallToolResult]: """Create the appropriate coroutine for calling a tool. @@ -578,6 +579,7 @@ def _create_call_tool_coroutine( name: Name of the tool to call. arguments: Optional arguments to pass to the tool. read_timeout_seconds: Optional timeout for the tool call. + meta: Optional metadata to pass to the tool call per MCP spec (_meta). Returns: A coroutine that will execute the tool call. @@ -598,7 +600,7 @@ async def _call_as_task() -> MCPCallToolResult: async def _call_tool_direct() -> MCPCallToolResult: return await cast(ClientSession, self._background_thread_session).call_tool( - name, arguments, read_timeout_seconds + name, arguments, read_timeout_seconds, meta=meta ) return _call_tool_direct() @@ -609,6 +611,7 @@ def call_tool_sync( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, + meta: dict[str, Any] | None = None, ) -> MCPToolResult: """Synchronously calls a tool on the MCP server. @@ -620,6 +623,7 @@ def call_tool_sync( name: Name of the tool to call arguments: Optional arguments to pass to the tool read_timeout_seconds: Optional timeout for the tool call + meta: Optional metadata to pass to the tool call per MCP spec (_meta) Returns: MCPToolResult: The result of the tool call @@ -629,7 +633,7 @@ def call_tool_sync( raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) try: - coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds) + coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta) call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(coro).result() return self._handle_tool_result(tool_use_id, call_tool_result) except Exception as e: @@ -642,6 +646,7 @@ async def call_tool_async( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, + meta: dict[str, Any] | None = None, ) -> MCPToolResult: """Asynchronously calls a tool on the MCP server. 
@@ -653,6 +658,7 @@ async def call_tool_async( name: Name of the tool to call arguments: Optional arguments to pass to the tool read_timeout_seconds: Optional timeout for the tool call + meta: Optional metadata to pass to the tool call per MCP spec (_meta) Returns: MCPToolResult: The result of the tool call @@ -662,7 +668,7 @@ async def call_tool_async( raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) try: - coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds) + coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta) future = self._invoke_on_background_thread(coro) call_tool_result: MCPCallToolResult = await asyncio.wrap_future(future) return self._handle_tool_result(tool_use_id, call_tool_result) diff --git a/src/strands/tools/mcp/mcp_instrumentation.py b/src/strands/tools/mcp/mcp_instrumentation.py index d1750daa3..5e64cc3d5 100644 --- a/src/strands/tools/mcp/mcp_instrumentation.py +++ b/src/strands/tools/mcp/mcp_instrumentation.py @@ -90,9 +90,10 @@ def patch_mcp_client(wrapped: Callable[..., Any], instance: Any, args: Any, kwar if hasattr(request.root, "params") and request.root.params: # Handle Pydantic models if hasattr(request.root.params, "model_dump") and hasattr(request.root.params, "model_validate"): - params_dict = request.root.params.model_dump() + params_dict = request.root.params.model_dump(by_alias=True) # Add _meta with tracing context - meta = params_dict.setdefault("_meta", {}) + meta = params_dict.get("_meta") if params_dict.get("_meta") is not None else {} + params_dict["_meta"] = meta propagate.get_global_textmap().inject(meta) # Recreate the Pydantic model with the updated data diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index 057c41a95..bf0e7ce8e 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -124,7 +124,7 @@ def 
test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_ with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None) + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None) assert result["status"] == expected_status assert result["toolUseId"] == "test-123" @@ -153,7 +153,7 @@ def test_call_tool_sync_with_structured_content(mock_transport, mock_session): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None) + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None) assert result["status"] == "success" assert result["toolUseId"] == "test-123" @@ -180,6 +180,51 @@ def test_call_tool_sync_exception(mock_transport, mock_session): assert "Test exception" in result["content"][0]["text"] +def test_call_tool_sync_forwards_meta(mock_transport, mock_session): + """Test that call_tool_sync forwards meta to ClientSession.call_tool.""" + mock_content = MCPTextContent(type="text", text="Test message") + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content]) + meta = {"com.example/request_id": "abc-123"} + + with MCPClient(mock_transport["transport_callable"]) as client: + result = client.call_tool_sync( + tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, meta=meta + ) + + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=meta) + assert result["status"] == "success" + + +@pytest.mark.asyncio +async def test_call_tool_async_forwards_meta(mock_transport, mock_session): + """Test that 
call_tool_async forwards meta to ClientSession.call_tool.""" + mock_content = MCPTextContent(type="text", text="Test message") + mock_result = MCPCallToolResult(isError=False, content=[mock_content]) + mock_session.call_tool.return_value = mock_result + meta = {"com.example/request_id": "abc-123"} + + with MCPClient(mock_transport["transport_callable"]) as client: + with ( + patch("asyncio.run_coroutine_threadsafe") as mock_run_coroutine_threadsafe, + patch("asyncio.wrap_future") as mock_wrap_future, + ): + mock_future = MagicMock() + mock_run_coroutine_threadsafe.return_value = mock_future + + async def mock_awaitable(): + return mock_result + + mock_wrap_future.return_value = mock_awaitable() + + result = await client.call_tool_async( + tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, meta=meta + ) + + mock_run_coroutine_threadsafe.assert_called_once() + + assert result["status"] == "success" + + @pytest.mark.asyncio @pytest.mark.parametrize("is_error,expected_status", [(False, "success"), (True, "error")]) async def test_call_tool_async_status(mock_transport, mock_session, is_error, expected_status): @@ -584,7 +629,7 @@ def test_call_tool_sync_embedded_nested_text(mock_transport, mock_session): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-text", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert result["content"][0]["text"] == "inner text" @@ -609,7 +654,7 @@ def test_call_tool_sync_embedded_nested_base64_textual_mime(mock_transport, mock with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-blob", name="get_file_contents", arguments={}) - 
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert result["content"][0]["text"] == '{"k":"v"}' @@ -635,7 +680,7 @@ def test_call_tool_sync_embedded_image_blob(mock_transport, mock_session): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-image", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert "image" in result["content"][0] @@ -660,7 +705,7 @@ def test_call_tool_sync_embedded_non_textual_blob_dropped(mock_transport, mock_s with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-binary", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 0 # Content should be dropped @@ -683,7 +728,7 @@ def test_call_tool_sync_embedded_multiple_textual_mimes(mock_transport, mock_ses with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-yaml", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert "key: value" in result["content"][0]["text"] @@ -710,7 +755,7 @@ def __init__(self): with 
MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-unknown", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 0 # Unknown resource type should be dropped @@ -762,7 +807,7 @@ def test_call_tool_sync_with_meta_and_structured_content(mock_transport, mock_se with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None) + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None) assert result["status"] == "success" assert result["toolUseId"] == "test-123" diff --git a/tests/strands/tools/mcp/test_mcp_instrumentation.py b/tests/strands/tools/mcp/test_mcp_instrumentation.py index 85d533403..9d44bba0c 100644 --- a/tests/strands/tools/mcp/test_mcp_instrumentation.py +++ b/tests/strands/tools/mcp/test_mcp_instrumentation.py @@ -328,7 +328,7 @@ class MockPydanticParams: def __init__(self, **data): self._data = data - def model_dump(self): + def model_dump(self, by_alias=False): return self._data.copy() @classmethod @@ -431,6 +431,32 @@ def test_patch_mcp_client_injects_context_pydantic_model(self): # Verify the params object is still a MockPydanticParams (or dict if fallback occurred) assert hasattr(mock_request.root.params, "model_dump") or isinstance(mock_request.root.params, dict) + def test_patch_mcp_client_preserves_existing_meta_pydantic(self): + """Test that instrumentation preserves existing _meta values in Pydantic models.""" + mock_request = MagicMock() + mock_request.root.method = "tools/call" + + # Pydantic model with existing _meta (returned via 
by_alias=True) + mock_params = MockPydanticParams(_meta={"com.example/request_id": "abc-123"}, name="echo") + mock_request.root.params = mock_params + + with patch("strands.tools.mcp.mcp_instrumentation.wrap_function_wrapper") as mock_wrap: + mcp_instrumentation() + patch_function = mock_wrap.call_args_list[0][0][2] + + mock_wrapped = MagicMock() + + with patch.object(propagate, "get_global_textmap") as mock_textmap: + mock_textmap_instance = MagicMock() + mock_textmap.return_value = mock_textmap_instance + + patch_function(mock_wrapped, None, [mock_request], {}) + + # Verify the reconstructed params use the key "_meta" (alias) not "meta" (Python name) + validated_params = mock_request.root.params.model_dump(by_alias=True) + assert "_meta" in validated_params + assert validated_params["_meta"]["com.example/request_id"] == "abc-123" + def test_patch_mcp_client_injects_context_dict_params(self): """Test that the client patch injects OpenTelemetry context into dict params.""" # Create a mock request with tools/call method and dict params @@ -507,7 +533,7 @@ class FailingMockPydanticParams: def __init__(self, **data): self._data = data - def model_dump(self): + def model_dump(self, by_alias=False): return self._data.copy() def model_validate(self, data): diff --git a/tests_integ/mcp/echo_server.py b/tests_integ/mcp/echo_server.py index 363c588ee..9c901e885 100644 --- a/tests_integ/mcp/echo_server.py +++ b/tests_integ/mcp/echo_server.py @@ -20,6 +20,7 @@ from typing import Literal from mcp.server import FastMCP +from mcp.server.fastmcp import Context from mcp.types import BlobResourceContents, CallToolResult, EmbeddedResource, TextContent, TextResourceContents from pydantic import BaseModel @@ -48,6 +49,13 @@ def start_echo_server(): def echo(to_echo: str) -> str: return to_echo + @mcp.tool(description="Echos back the _meta received in the request", structured_output=False) + def echo_meta(ctx: Context) -> str: + meta = ctx.request_context.meta + if meta is None: + 
return json.dumps(None) + return json.dumps(meta.model_dump(exclude_none=True)) + # FastMCP automatically constructs structured output schema from method signature @mcp.tool(description="Echos response back with structured content", structured_output=True) def echo_with_structured_content(to_echo: str) -> EchoResponse: diff --git a/tests_integ/mcp/test_mcp_client.py b/tests_integ/mcp/test_mcp_client.py index 130b35529..fe2b10df3 100644 --- a/tests_integ/mcp/test_mcp_client.py +++ b/tests_integ/mcp/test_mcp_client.py @@ -238,6 +238,74 @@ def test_mcp_client_without_structured_content(): assert result["content"] == [{"text": "SIMPLE_ECHO_TEST"}] +def test_call_tool_sync_with_meta(): + """Test that call_tool_sync forwards meta to the MCP server.""" + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with stdio_mcp_client: + result = stdio_mcp_client.call_tool_sync( + tool_use_id="test-meta-sync", + name="echo_meta", + arguments={}, + meta={"com.example/request_id": "abc-123"}, + ) + + assert result["status"] == "success" + received_meta = json.loads(result["content"][0]["text"]) + assert received_meta["com.example/request_id"] == "abc-123" + + +@pytest.mark.asyncio +async def test_call_tool_async_with_meta(): + """Test that call_tool_async forwards meta to the MCP server.""" + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with stdio_mcp_client: + result = await stdio_mcp_client.call_tool_async( + tool_use_id="test-meta-async", + name="echo_meta", + arguments={}, + meta={"com.example/request_id": "def-456"}, + ) + + assert result["status"] == "success" + received_meta = json.loads(result["content"][0]["text"]) + assert received_meta["com.example/request_id"] == "def-456" + + +def test_instrumentation_preserves_meta_on_tool_call(): + """Test that OTel instrumentation sets _meta that 
reaches the MCP server.""" + from unittest.mock import MagicMock, patch + + # Mock the propagator to always inject a known value, bypassing the need for + # an active span on the background thread where send_request runs + mock_textmap = MagicMock() + mock_textmap.inject = lambda carrier, **kwargs: carrier.update({"traceparent": "00-abc-def-01"}) + + with patch("opentelemetry.propagate.get_global_textmap", return_value=mock_textmap): + stdio_mcp_client = MCPClient( + lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/mcp/echo_server.py"])) + ) + + with stdio_mcp_client: + result = stdio_mcp_client.call_tool_sync( + tool_use_id="test-instrumentation", + name="echo_meta", + arguments={}, + ) + + assert result["status"] == "success" + received_meta = json.loads(result["content"][0]["text"]) + # OTel instrumentation should have injected _meta with tracing context + assert received_meta is not None + assert isinstance(received_meta, dict) + assert received_meta["traceparent"] == "00-abc-def-01" + + @pytest.mark.skipif( condition=os.environ.get("GITHUB_ACTIONS") == "true", reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue", From a19e73dbedba5ca486f7c8039185078747c57773 Mon Sep 17 00:00:00 2001 From: opieter-aws Date: Tue, 7 Apr 2026 02:54:02 -0400 Subject: [PATCH 210/279] fix(anthropic): avoid Pydantic warnings for message_stop events (#2044) Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> --- src/strands/models/anthropic.py | 10 ++++- tests/strands/models/test_anthropic.py | 56 ++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index b5f6fcf91..f0be79bdd 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -407,7 +407,15 @@ async def stream( logger.debug("got response from model") async for event in stream: if event.type in 
AnthropicModel.EVENT_TYPES: - yield self.format_chunk(event.model_dump()) + if event.type == "message_stop": + # Build dict directly to avoid Pydantic serialization warnings + # when the message contains ParsedTextBlock objects (issue #1746) + yield self.format_chunk({ + "type": "message_stop", + "message": {"stop_reason": event.message.stop_reason}, + }) + else: + yield self.format_chunk(event.model_dump()) usage = event.message.usage # type: ignore yield self.format_chunk({"type": "metadata", "usage": usage.model_dump()}) diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index c5aff8062..8f4581655 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -1,5 +1,6 @@ import logging import unittest.mock +import warnings import anthropic import pydantic @@ -811,6 +812,7 @@ async def test_structured_output(anthropic_client, model, test_output_model_cls, ), unittest.mock.Mock( type="message_stop", + message=unittest.mock.Mock(stop_reason="tool_use"), model_dump=unittest.mock.Mock( return_value={"type": "message_stop", "message": {"stop_reason": "tool_use"}} ), @@ -933,3 +935,57 @@ def test_format_request_filters_location_source_document(model, model_id, max_to ] assert tru_request["messages"] == exp_messages assert "Location sources are not supported by Anthropic" in caplog.text + + +@pytest.mark.asyncio +async def test_stream_message_stop_no_pydantic_warnings(anthropic_client, model, agenerator, alist): + """Verify no Pydantic serialization warnings are emitted for message_stop events. + + Regression test for https://github.com/strands-agents/sdk-python/issues/1746. 
+ """ + # Create a mock message_stop event where model_dump() would emit warnings + # The key is that the event has a .message attribute with .stop_reason + mock_message_stop = unittest.mock.Mock() + mock_message_stop.type = "message_stop" + mock_message_stop.message = unittest.mock.Mock() + mock_message_stop.message.stop_reason = "end_turn" + + # Make model_dump() emit a warning to simulate the problematic behavior + def model_dump_with_warning(): + warnings.warn( + "PydanticSerializationUnexpectedValue(Expected `ParsedTextBlock[TypeVar]`)", + UserWarning, + stacklevel=2, + ) + return {"type": mock_message_stop.type, "message": {"stop_reason": mock_message_stop.message.stop_reason}} + + mock_message_stop.model_dump = model_dump_with_warning + + mock_event_usage = unittest.mock.Mock( + message=unittest.mock.Mock( + usage=unittest.mock.Mock( + model_dump=lambda: {"input_tokens": 1, "output_tokens": 2}, + ) + ), + ) + + mock_context = unittest.mock.AsyncMock() + mock_context.__aenter__.return_value = agenerator([mock_message_stop, mock_event_usage]) + anthropic_client.messages.stream.return_value = mock_context + + messages = [{"role": "user", "content": [{"text": "hello"}]}] + + # Capture warnings during streaming + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always") + response = model.stream(messages, None, None) + events = await alist(response) + + # Verify no Pydantic serialization warnings were emitted + pydantic_warnings = [ + w for w in caught_warnings if "PydanticSerializationUnexpectedValue" in str(w.message) + ] + assert len(pydantic_warnings) == 0, f"Unexpected Pydantic warnings: {pydantic_warnings}" + + # Verify the message_stop event was still processed correctly + assert {"messageStop": {"stopReason": mock_message_stop.message.stop_reason}} in events From 287c5b6c8ead73d9aa18b28dd601d7f89a4a50b2 Mon Sep 17 00:00:00 2001 From: mattdai01 <32076552+mattdai01@users.noreply.github.com> Date: Tue, 7 Apr 2026 
07:20:57 -0700 Subject: [PATCH 211/279] fix: propagate tool exceptions to spans so StatusCode.ERROR is set correctly (#2046) Co-authored-by: Matthew Dai --- src/strands/tools/executors/_executor.py | 21 ++++-- tests/strands/telemetry/test_tracer.py | 14 ++++ .../strands/tools/executors/test_executor.py | 70 +++++++++++++++++++ 3 files changed, 98 insertions(+), 7 deletions(-) diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 5825b3cdb..2c602a560 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -171,9 +171,15 @@ async def _stream( } after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( - agent, None, tool_use, invocation_state, cancel_result, cancel_message=cancel_message + agent, + None, + tool_use, + invocation_state, + cancel_result, + exception=Exception(cancel_message), + cancel_message=cancel_message, ) - yield ToolResultEvent(after_event.result) + yield ToolResultEvent(after_event.result, exception=after_event.exception) tool_results.append(after_event.result) return @@ -202,15 +208,16 @@ async def _stream( "content": [{"text": f"Unknown tool: {tool_name}"}], } + unknown_tool_error = Exception(f"Unknown tool: {tool_name}") after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( - agent, selected_tool, tool_use, invocation_state, result + agent, selected_tool, tool_use, invocation_state, result, exception=unknown_tool_error ) # Check if retry requested for unknown tool error # Use getattr because BidiAfterToolCallEvent doesn't have retry attribute if getattr(after_event, "retry", False): logger.debug("tool_name=<%s> | retry requested, retrying tool call", tool_name) continue - yield ToolResultEvent(after_event.result) + yield ToolResultEvent(after_event.result, exception=after_event.exception) tool_results.append(after_event.result) return if structured_output_context.is_enabled: @@ -258,7 +265,7 @@ async def _stream( 
logger.debug("tool_name=<%s> | retry requested, retrying tool call", tool_name) continue - yield ToolResultEvent(after_event.result) + yield ToolResultEvent(after_event.result, exception=after_event.exception) tool_results.append(after_event.result) return @@ -277,7 +284,7 @@ async def _stream( if getattr(after_event, "retry", False): logger.debug("tool_name=<%s> | retry requested after exception, retrying tool call", tool_name) continue - yield ToolResultEvent(after_event.result) + yield ToolResultEvent(after_event.result, exception=after_event.exception) tool_results.append(after_event.result) return @@ -338,7 +345,7 @@ async def _stream_with_trace( agent.event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message) cycle_trace.add_child(tool_trace) - tracer.end_tool_call_span(tool_call_span, result) + tracer.end_tool_call_span(tool_call_span, result, error=result_event.exception) @abc.abstractmethod # pragma: no cover diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index bcd42b610..2d91b6216 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -707,6 +707,20 @@ def test_end_tool_call_span_latest_conventions(mock_span, monkeypatch): mock_span.end.assert_called_once() +def test_end_tool_call_span_with_error(mock_span): + """Test ending a tool call span with an explicit error sets StatusCode.ERROR.""" + tracer = Tracer() + error = ValueError("tool exploded") + tool_result = {"status": "error", "content": [{"text": "Error: tool exploded"}]} + + tracer.end_tool_call_span(mock_span, tool_result, error=error) + + mock_span.set_attributes.assert_called_once_with({"gen_ai.tool.status": "error"}) + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "tool exploded") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + def test_start_event_loop_cycle_span(mock_tracer): """Test starting an 
event loop cycle span.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 297aa66f3..34b37dab0 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -189,6 +189,7 @@ async def test_executor_stream_yields_unknown_tool(executor, agent, tool_results tool_use=tool_use, invocation_state=invocation_state, result=exp_results[0], + exception=unittest.mock.ANY, ) assert tru_hook_after_event == exp_hook_after_event @@ -216,6 +217,7 @@ async def test_executor_stream_with_trace( tracer.end_tool_call_span.assert_called_once_with( tracer.start_tool_call_span.return_value, {"content": [{"text": "sunny"}], "status": "success", "toolUseId": "1"}, + error=None, ) cycle_trace.add_child.assert_called_once() @@ -901,3 +903,71 @@ def retry_once_on_unknown(event): assert len(tru_events) == 1 assert tru_events[0].tool_result["status"] == "error" assert "Unknown tool" in tru_events[0].tool_result["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_executor_stream_with_trace_error( + executor, tracer, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist +): + """Test that _stream_with_trace passes the exception to end_tool_call_span when a tool fails.""" + tool_use: ToolUse = {"name": "exception_tool", "toolUseId": "1", "input": {}} + stream = executor._stream_with_trace(agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state) + + await alist(stream) + + tracer.end_tool_call_span.assert_called_once() + call_args = tracer.end_tool_call_span.call_args + assert call_args[0][1]["status"] == "error" + error_arg = call_args[1].get("error") + assert error_arg is not None + assert isinstance(error_arg, RuntimeError) + assert "Tool error" in str(error_arg) + + +@pytest.mark.asyncio +async def 
test_executor_stream_error_preserves_exception(executor, agent, tool_results, invocation_state, alist): + """Test that _stream yields a ToolResultEvent with the exception preserved.""" + tool_use: ToolUse = {"name": "exception_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + events = await alist(stream) + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert result_event.tool_result["status"] == "error" + assert result_event.exception is not None + assert isinstance(result_event.exception, RuntimeError) + assert "Tool error" in str(result_event.exception) + + +@pytest.mark.asyncio +async def test_executor_stream_unknown_tool_has_exception(executor, agent, tool_results, invocation_state, alist): + """Test that _stream yields a ToolResultEvent with exception for unknown tools.""" + tool_use: ToolUse = {"name": "nonexistent_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + events = await alist(stream) + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert result_event.tool_result["status"] == "error" + assert result_event.exception is not None + assert "Unknown tool" in str(result_event.exception) + + +@pytest.mark.asyncio +async def test_executor_stream_cancel_has_exception(executor, agent, tool_results, invocation_state, alist): + """Test that _stream yields a ToolResultEvent with exception for cancelled tools.""" + + def cancel_callback(event): + event.cancel_tool = True + return event + + agent.hooks.add_callback(BeforeToolCallEvent, cancel_callback) + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + events = await alist(stream) + result_event = events[-1] + assert isinstance(result_event, ToolResultEvent) + assert result_event.tool_result["status"] == 
"error" + assert result_event.exception is not None + assert "cancelled" in str(result_event.exception) From e7a217412bcc4bc52d70549930104fed15eb73e0 Mon Sep 17 00:00:00 2001 From: KKamJi Date: Wed, 8 Apr 2026 07:44:07 +0900 Subject: [PATCH 212/279] fix(docs): update 19 broken documentation links in README (#1906) Co-authored-by: Murat Kaan Meral --- README.md | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index fdb309f99..173adc006 100644 --- a/README.md +++ b/README.md @@ -169,21 +169,21 @@ response = agent("Tell me about Agentic AI") ``` Built-in providers: - - [Amazon Bedrock](https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/) - - [Anthropic](https://strandsagents.com/latest/user-guide/concepts/model-providers/anthropic/) - - [Gemini](https://strandsagents.com/latest/user-guide/concepts/model-providers/gemini/) - - [Cohere](https://strandsagents.com/latest/user-guide/concepts/model-providers/cohere/) - - [LiteLLM](https://strandsagents.com/latest/user-guide/concepts/model-providers/litellm/) - - [llama.cpp](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamacpp/) - - [LlamaAPI](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamaapi/) - - [MistralAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/mistral/) - - [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/) - - [OpenAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/) - - [OpenAI Responses API](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/) - - [SageMaker](https://strandsagents.com/latest/user-guide/concepts/model-providers/sagemaker/) - - [Writer](https://strandsagents.com/latest/user-guide/concepts/model-providers/writer/) - -Custom providers can be implemented using [Custom 
Providers](https://strandsagents.com/latest/user-guide/concepts/model-providers/custom_model_provider/) + - [Amazon Bedrock](https://strandsagents.com/docs/user-guide/concepts/model-providers/amazon-bedrock/) + - [Anthropic](https://strandsagents.com/docs/user-guide/concepts/model-providers/anthropic/) + - [Gemini](https://strandsagents.com/docs/user-guide/concepts/model-providers/gemini/) + - [Cohere](https://strandsagents.com/docs/user-guide/concepts/model-providers/cohere/) + - [LiteLLM](https://strandsagents.com/docs/user-guide/concepts/model-providers/litellm/) + - [llama.cpp](https://strandsagents.com/docs/user-guide/concepts/model-providers/llamacpp/) + - [LlamaAPI](https://strandsagents.com/docs/user-guide/concepts/model-providers/llamaapi/) + - [MistralAI](https://strandsagents.com/docs/user-guide/concepts/model-providers/mistral/) + - [Ollama](https://strandsagents.com/docs/user-guide/concepts/model-providers/ollama/) + - [OpenAI](https://strandsagents.com/docs/user-guide/concepts/model-providers/openai/) + - [OpenAI Responses API](https://strandsagents.com/docs/user-guide/concepts/model-providers/openai/) + - [SageMaker](https://strandsagents.com/docs/user-guide/concepts/model-providers/sagemaker/) + - [Writer](https://strandsagents.com/docs/user-guide/concepts/model-providers/writer/) + +Custom providers can be implemented using [Custom Providers](https://strandsagents.com/docs/user-guide/concepts/model-providers/custom_model_provider/) ### Example tools @@ -202,7 +202,7 @@ It's also available on GitHub via [strands-agents/tools](https://github.com/stra > **⚠️ Experimental Feature**: Bidirectional streaming is currently in experimental status. APIs may change in future releases as we refine the feature based on user feedback and evolving model capabilities. -Build real-time voice and audio conversations with persistent streaming connections. 
Unlike traditional request-response patterns, bidirectional streaming maintains long-running conversations where users can interrupt, provide continuous input, and receive real-time audio responses. Get started with your first BidiAgent by following the [Quickstart](https://strandsagents.com/latest/documentation/docs/user-guide/concepts/experimental/bidirectional-streaming/quickstart) guide. +Build real-time voice and audio conversations with persistent streaming connections. Unlike traditional request-response patterns, bidirectional streaming maintains long-running conversations where users can interrupt, provide continuous input, and receive real-time audio responses. Get started with your first BidiAgent by following the [Quickstart](https://strandsagents.com/docs/user-guide/concepts/bidirectional-streaming/quickstart/) guide. **Supported Model Providers:** - Amazon Nova Sonic (v1, v2) @@ -301,11 +301,11 @@ await agent.run( For detailed guidance & examples, explore our documentation: - [User Guide](https://strandsagents.com/) -- [Quick Start Guide](https://strandsagents.com/latest/user-guide/quickstart/) -- [Agent Loop](https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/) -- [Examples](https://strandsagents.com/latest/examples/) -- [API Reference](https://strandsagents.com/latest/api-reference/agent/) -- [Production & Deployment Guide](https://strandsagents.com/latest/user-guide/deploy/operating-agents-in-production/) +- [Quick Start Guide](https://strandsagents.com/docs/user-guide/quickstart/) +- [Agent Loop](https://strandsagents.com/docs/user-guide/concepts/agents/agent-loop/) +- [Examples](https://strandsagents.com/docs/examples/) +- [API Reference](https://strandsagents.com/docs/api/python/strands.agent.agent/) +- [Production & Deployment Guide](https://strandsagents.com/docs/user-guide/deploy/operating-agents-in-production/) ## Contributing ❤️ From 65b06d993da89802411aa2adc9bee509669b5ecd Mon Sep 17 00:00:00 2001 From: Jack Yuan 
<94985218+JackYPCOnline@users.noreply.github.com> Date: Wed, 8 Apr 2026 10:54:45 -0400 Subject: [PATCH 213/279] fix: enforce that the first message is a user message in the sliding window conversation manager (#2087) --- .../sliding_window_conversation_manager.py | 35 +++++++-- src/strands/models/anthropic.py | 10 ++- tests/strands/agent/test_agent.py | 25 ++++-- .../agent/test_conversation_manager.py | 78 ++++++++++++++++--- tests/strands/models/test_anthropic.py | 4 +- 5 files changed, 118 insertions(+), 34 deletions(-) diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index b97de0b06..94446380b 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -167,9 +167,9 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A **kwargs: Additional keyword arguments for future extensibility. Raises: - ContextWindowOverflowException: If the context cannot be reduced further. - Such as when the conversation is already minimal or when tool result messages cannot be properly - converted. + ContextWindowOverflowException: If the context cannot be reduced further and a context overflow + error was provided (e is not None). When called during routine window management (e is None), + logs a warning and returns without modification. """ messages = agent.messages @@ -188,24 +188,43 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size - # Find the next valid trim_index + # Find the next valid trim point that: + # 1. Starts with a user message (required by most model providers) + # 2. 
Does not start with an orphaned toolResult + # 3. Does not start with a toolUse unless its toolResult immediately follows while trim_index < len(messages): + # Must start with a user message + if messages[trim_index]["role"] != "user": + trim_index += 1 + continue + if ( # Oldest message cannot be a toolResult because it needs a toolUse preceding it any("toolResult" in content for content in messages[trim_index]["content"]) or ( # Oldest message can be a toolUse only if a toolResult immediately follows it. + # Note: toolUse content normally appears only in assistant messages, but this + # check is kept as a defensive safeguard for non-standard message formats. any("toolUse" in content for content in messages[trim_index]["content"]) - and trim_index + 1 < len(messages) - and not any("toolResult" in content for content in messages[trim_index + 1]["content"]) + and not ( + trim_index + 1 < len(messages) + and any("toolResult" in content for content in messages[trim_index + 1]["content"]) + ) ) ): trim_index += 1 else: break else: - # If we didn't find a valid trim_index, then we throw - raise ContextWindowOverflowException("Unable to trim conversation context!") from e + # If we didn't find a valid trim_index + if e is not None: + raise ContextWindowOverflowException("Unable to trim conversation context!") from e + logger.warning( + "window_size=<%s>, message_count=<%s> | unable to trim conversation context, no valid trim point found", + self.window_size, + len(messages), + ) + return # trim_index represents the number of messages being removed from the agents messages array self.removed_message_count += trim_index diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index f0be79bdd..6195f9ccd 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -410,10 +410,12 @@ async def stream( if event.type == "message_stop": # Build dict directly to avoid Pydantic serialization warnings # when the message contains 
ParsedTextBlock objects (issue #1746) - yield self.format_chunk({ - "type": "message_stop", - "message": {"stop_reason": event.message.stop_reason}, - }) + yield self.format_chunk( + { + "type": "message_stop", + "message": {"stop_reason": event.message.stop_reason}, + } + ) else: yield self.format_chunk(event.model_dump()) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 5a3cce11c..0057c50a3 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1615,10 +1615,15 @@ def test_agent_restored_from_session_management_with_correct_index(): def test_agent_with_session_and_conversation_manager(): - mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello!"}]}]) + mock_model = MockedModelProvider( + [ + {"role": "assistant", "content": [{"text": "first"}]}, + {"role": "assistant", "content": [{"text": "second"}]}, + ] + ) mock_session_repository = MockedSessionRepository() session_manager = RepositorySessionManager(session_id="123", session_repository=mock_session_repository) - conversation_manager = SlidingWindowConversationManager(window_size=1) + conversation_manager = SlidingWindowConversationManager(window_size=2) # Create an agent with a mocked model and session repository agent = Agent( session_manager=session_manager, @@ -1633,14 +1638,20 @@ def test_agent_with_session_and_conversation_manager(): agent("Hello!") - # After invoking, assert that the messages were persisted + # After first invocation: [user, assistant] — fits in window, no trimming assert len(mock_session_repository.list_messages("123", agent.agent_id)) == 2 - # Assert conversation manager reduced the messages - assert len(agent.messages) == 1 + assert len(agent.messages) == 2 + + agent("Second question") + + # After second invocation: [user, assistant, user, assistant] exceeds window_size=2 + # Conversation manager trims to 2 messages starting with a user message + assert len(agent.messages) == 2 + 
assert agent.messages[0]["role"] == "user" # Initialize another agent using the same session session_manager_2 = RepositorySessionManager(session_id="123", session_repository=mock_session_repository) - conversation_manager_2 = SlidingWindowConversationManager(window_size=1) + conversation_manager_2 = SlidingWindowConversationManager(window_size=2) agent_2 = Agent( session_manager=session_manager_2, conversation_manager=conversation_manager_2, @@ -1648,7 +1659,7 @@ def test_agent_with_session_and_conversation_manager(): ) # Assert that the second agent was initialized properly, and that the messages of both agents are equal assert agent.messages == agent_2.messages - # Asser the conversation manager was initialized properly + # Assert the conversation manager was initialized properly assert agent.conversation_manager.removed_message_count == agent_2.conversation_manager.removed_message_count diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index fd88954e8..6db9897f1 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -78,6 +78,7 @@ def conversation_manager(request): ], ), # 5 - Remove dangling assistant message with tool use and user message without tool result + # Must start with a user message, so we skip the assistant message ( {"window_size": 3}, [ @@ -87,7 +88,6 @@ def conversation_manager(request): {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "tool1", "input": {}}}]}, ], [ - {"role": "assistant", "content": [{"text": "First response"}]}, {"role": "user", "content": [{"text": "Use a tool"}]}, {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "tool1", "input": {}}}]}, ], @@ -107,19 +107,22 @@ def conversation_manager(request): ], ), # 7 - Message count above max window size - Preserve tool use/tool result pairs + # Cannot start with assistant or orphaned toolResult, so trim 
advances to next plain user message ( {"window_size": 2}, [ - {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "content": [], "status": "success"}}]}, + {"role": "user", "content": [{"text": "Hello"}]}, {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "tool1", "input": {}}}]}, - {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "content": [], "status": "success"}}]}, + {"role": "assistant", "content": [{"text": "Done"}]}, + {"role": "user", "content": [{"text": "Next"}]}, ], [ - {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "tool1", "input": {}}}]}, - {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, + {"role": "user", "content": [{"text": "Next"}]}, ], ), # 8 - Test sliding window behavior - preserve tool use/result pairs across cut boundary + # Must start with user message (not assistant, not orphaned toolResult), so trim advances to plain user msg ( {"window_size": 3}, [ @@ -127,14 +130,14 @@ def conversation_manager(request): {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "tool1", "input": {}}}]}, {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "content": [], "status": "success"}}]}, {"role": "assistant", "content": [{"text": "Response after tool use"}]}, + {"role": "user", "content": [{"text": "Follow up"}]}, ], [ - {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "tool1", "input": {}}}]}, - {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "content": [], "status": "success"}}]}, - {"role": "assistant", "content": [{"text": "Response after tool use"}]}, + {"role": "user", "content": [{"text": "Follow up"}]}, ], ), # 9 - Test sliding window with multiple tool pairs that need preservation + # Must start with user message; orphaned 
toolResult is skipped, lands on plain user text ( {"window_size": 4}, [ @@ -144,11 +147,10 @@ def conversation_manager(request): {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool2", "input": {}}}]}, {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, {"role": "assistant", "content": [{"text": "Final response"}]}, + {"role": "user", "content": [{"text": "Another question"}]}, ], [ - {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool2", "input": {}}}]}, - {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, - {"role": "assistant", "content": [{"text": "Final response"}]}, + {"role": "user", "content": [{"text": "Another question"}]}, ], ), ], @@ -161,6 +163,43 @@ def test_apply_management(conversation_manager, messages, expected_messages): assert messages == expected_messages +def test_sliding_window_forces_user_message_start(): + """Test that trimmed conversation always starts with a user message (GitHub #2085).""" + manager = SlidingWindowConversationManager(window_size=3, should_truncate_results=False) + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi"}]}, + {"role": "user", "content": [{"text": "How are you?"}]}, + {"role": "assistant", "content": [{"text": "Good"}]}, + {"role": "user", "content": [{"text": "Great"}]}, + ] + test_agent = Agent(messages=messages) + manager.apply_management(test_agent) + + assert len(messages) == 3 + assert messages[0]["role"] == "user" + assert messages[0]["content"] == [{"text": "How are you?"}] + + +def test_sliding_window_happy_path_preserves_window_size(): + """In a typical user/assistant conversation, trimming preserves close to window_size messages.""" + manager = SlidingWindowConversationManager(window_size=4, should_truncate_results=False) + messages = [ + {"role": "user", "content": [{"text": 
"First"}]}, + {"role": "assistant", "content": [{"text": "First response"}]}, + {"role": "user", "content": [{"text": "Second"}]}, + {"role": "assistant", "content": [{"text": "Second response"}]}, + {"role": "user", "content": [{"text": "Third"}]}, + {"role": "assistant", "content": [{"text": "Third response"}]}, + ] + test_agent = Agent(messages=messages) + manager.apply_management(test_agent) + + assert len(messages) == 4 + assert messages[0]["role"] == "user" + assert messages[0]["content"] == [{"text": "Second"}] + + def test_sliding_window_conversation_manager_with_untrimmable_history_raises_context_window_overflow_exception(): manager = SlidingWindowConversationManager(1, False) messages = [ @@ -171,7 +210,22 @@ def test_sliding_window_conversation_manager_with_untrimmable_history_raises_con test_agent = Agent(messages=messages) with pytest.raises(ContextWindowOverflowException): - manager.apply_management(test_agent) + manager.reduce_context(test_agent, e=RuntimeError("context overflow")) + + assert messages == original_messages + + +def test_sliding_window_no_valid_trim_point_without_error_does_not_raise(): + """When no valid trim point exists during routine management (no error), messages are left unchanged.""" + manager = SlidingWindowConversationManager(1, False) + messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]}, + ] + original_messages = messages.copy() + test_agent = Agent(messages=messages) + + manager.apply_management(test_agent) assert messages == original_messages diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 8f4581655..d1f1df3b3 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -982,9 +982,7 @@ def model_dump_with_warning(): events = await alist(response) # Verify no 
Pydantic serialization warnings were emitted - pydantic_warnings = [ - w for w in caught_warnings if "PydanticSerializationUnexpectedValue" in str(w.message) - ] + pydantic_warnings = [w for w in caught_warnings if "PydanticSerializationUnexpectedValue" in str(w.message)] assert len(pydantic_warnings) == 0, f"Unexpected Pydantic warnings: {pydantic_warnings}" # Verify the message_stop event was still processed correctly From 46937d244b3de278c382859e5fdbdc7f1f6c5136 Mon Sep 17 00:00:00 2001 From: Agent of mkmeral Date: Wed, 8 Apr 2026 11:15:35 -0400 Subject: [PATCH 214/279] fix: forward meta to MCP task-augmented tool calls (#2081) Co-authored-by: agent-of-mkmeral --- src/strands/tools/mcp/mcp_client.py | 7 +- .../tools/mcp/test_mcp_client_tasks.py | 74 +++++++++++++++++++ 2 files changed, 80 insertions(+), 1 deletion(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 7574f4b65..11ed9c75e 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -592,7 +592,9 @@ def _create_call_tool_coroutine( async def _call_as_task() -> MCPCallToolResult: # When task-augmented execution is used, use the read_timeout_seconds parameter # (which is a timedelta) for the polling timeout. - return await self._call_tool_as_task_and_poll_async(name, arguments, poll_timeout=read_timeout_seconds) + return await self._call_tool_as_task_and_poll_async( + name, arguments, poll_timeout=read_timeout_seconds, meta=meta + ) return _call_as_task() else: @@ -1100,6 +1102,7 @@ async def _call_tool_as_task_and_poll_async( arguments: dict[str, Any] | None = None, ttl: timedelta | None = None, poll_timeout: timedelta | None = None, + meta: dict[str, Any] | None = None, ) -> MCPCallToolResult: """Call a tool using task-augmented execution and poll until completion. @@ -1113,6 +1116,7 @@ async def _call_tool_as_task_and_poll_async( arguments: Optional arguments to pass to the tool. ttl: Task time-to-live. 
Uses configured value if not specified. poll_timeout: Timeout for polling. Uses configured value if not specified. + meta: Optional metadata to pass to the tool call per MCP spec (_meta). Returns: MCPCallToolResult: The final tool result after task completion. @@ -1133,6 +1137,7 @@ async def _call_tool_as_task_and_poll_async( name=name, arguments=arguments, ttl=ttl_ms, + meta=meta, ) task_id = create_result.task.taskId self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task created", name, task_id) diff --git a/tests/strands/tools/mcp/test_mcp_client_tasks.py b/tests/strands/tools/mcp/test_mcp_client_tasks.py index 01d3b2763..c21db9e28 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tasks.py +++ b/tests/strands/tools/mcp/test_mcp_client_tasks.py @@ -214,3 +214,77 @@ async def poll(task_id): result = await client.call_tool_async(tool_use_id="t", name="success_tool", arguments={}) assert result["status"] == "success" assert "Done" in result["content"][0].get("text", "") + + +class TestTaskMetaForwarding: + """Tests for meta parameter forwarding in task-augmented execution.""" + + def _setup_task_tool_with_meta(self, mock_session, tool_name: str) -> MagicMock: + """Helper to set up a mock task-enabled tool and return the experimental mock.""" + mock_session.get_server_capabilities = MagicMock(return_value=create_server_capabilities(True)) + mock_tool = MCPTool( + name=tool_name, + description="A test tool", + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport="optional"), + ) + mock_session.list_tools = AsyncMock(return_value=ListToolsResult(tools=[mock_tool], nextCursor=None)) + mock_create_result = MagicMock() + mock_create_result.task.taskId = "test-task-id" + mock_session.experimental = MagicMock() + mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result) + + async def successful_poll(task_id): + yield MagicMock(status="completed", statusMessage=None) + + mock_session.experimental.poll_task = 
successful_poll + mock_session.experimental.get_task_result = AsyncMock( + return_value=MCPCallToolResult(content=[MCPTextContent(type="text", text="Done")], isError=False) + ) + + return mock_session.experimental + + def test_call_tool_sync_forwards_meta_to_task(self, mock_transport, mock_session): + """Test that call_tool_sync forwards meta to call_tool_as_task.""" + experimental = self._setup_task_tool_with_meta(mock_session, "meta_tool") + meta = {"com.example/request_id": "abc-123"} + + with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: + client.list_tools_sync() + client.call_tool_sync( + tool_use_id="test-id", name="meta_tool", arguments={"param": "value"}, meta=meta + ) + + experimental.call_tool_as_task.assert_called_once() + call_kwargs = experimental.call_tool_as_task.call_args + assert call_kwargs.kwargs.get("meta") == meta + + @pytest.mark.asyncio + async def test_call_tool_async_forwards_meta_to_task(self, mock_transport, mock_session): + """Test that call_tool_async forwards meta to call_tool_as_task.""" + experimental = self._setup_task_tool_with_meta(mock_session, "meta_tool") + meta = {"com.example/trace_id": "xyz-456"} + + with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: + client.list_tools_sync() + await client.call_tool_async( + tool_use_id="test-id", name="meta_tool", arguments={"param": "value"}, meta=meta + ) + + experimental.call_tool_as_task.assert_called_once() + call_kwargs = experimental.call_tool_as_task.call_args + assert call_kwargs.kwargs.get("meta") == meta + + def test_call_tool_sync_forwards_none_meta_to_task(self, mock_transport, mock_session): + """Test that call_tool_sync forwards None meta to call_tool_as_task when not provided.""" + experimental = self._setup_task_tool_with_meta(mock_session, "no_meta_tool") + + with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: + client.list_tools_sync() + 
client.call_tool_sync( + tool_use_id="test-id", name="no_meta_tool", arguments={"param": "value"} + ) + + experimental.call_tool_as_task.assert_called_once() + call_kwargs = experimental.call_tool_as_task.call_args + assert call_kwargs.kwargs.get("meta") is None From 2f9ffb18324792700ca533ead553599835e7cc68 Mon Sep 17 00:00:00 2001 From: Gautam Sirdeshmukh <54588697+gautamsirdeshmukh@users.noreply.github.com> Date: Wed, 8 Apr 2026 11:55:55 -0400 Subject: [PATCH 215/279] fix: handle premature stream termination for Anthropic (#1868) (#2047) --- src/strands/models/anthropic.py | 8 ++- tests/strands/models/test_anthropic.py | 87 +++++++++++++++++++++----- 2 files changed, 78 insertions(+), 17 deletions(-) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 6195f9ccd..818a8f14c 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -419,8 +419,12 @@ async def stream( else: yield self.format_chunk(event.model_dump()) - usage = event.message.usage # type: ignore - yield self.format_chunk({"type": "metadata", "usage": usage.model_dump()}) + try: + message_snapshot = await stream.get_final_message() + except AssertionError as e: + logger.warning("error=<%s> | failed to retrieve message snapshot, usage metadata unavailable", e) + else: + yield self.format_chunk({"type": "metadata", "usage": message_snapshot.usage.model_dump()}) except anthropic.RateLimitError as error: raise ModelThrottledException(str(error)) from error diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index d1f1df3b3..78a5ea693 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -53,6 +53,24 @@ class TestOutputModel(pydantic.BaseModel): return TestOutputModel +def generate_mock_stream_context(events, final_message=None): + mock_stream = unittest.mock.AsyncMock() + + async def mock_aiter(self): + for event in events: + yield event + + 
mock_stream.__aiter__ = mock_aiter + if isinstance(final_message, Exception): + mock_stream.get_final_message.side_effect = final_message + elif final_message: + mock_stream.get_final_message.return_value = final_message + + mock_context = unittest.mock.AsyncMock() + mock_context.__aenter__.return_value = mock_stream + return mock_context + + def test__init__model_configs(anthropic_client, model_id, max_tokens): _ = anthropic_client @@ -693,7 +711,7 @@ def test_format_chunk_unknown(model): @pytest.mark.asyncio -async def test_stream(anthropic_client, model, agenerator, alist): +async def test_stream(anthropic_client, model, alist): mock_event_1 = unittest.mock.Mock( type="message_start", dict=lambda: {"type": "message_start"}, @@ -714,9 +732,14 @@ async def test_stream(anthropic_client, model, agenerator, alist): ), ) - mock_context = unittest.mock.AsyncMock() - mock_context.__aenter__.return_value = agenerator([mock_event_1, mock_event_2, mock_event_3]) - anthropic_client.messages.stream.return_value = mock_context + anthropic_client.messages.stream.return_value = generate_mock_stream_context( + [mock_event_1, mock_event_2, mock_event_3], + final_message=unittest.mock.Mock( + usage=unittest.mock.Mock( + model_dump=lambda: {"input_tokens": 1, "output_tokens": 2}, + ) + ), + ) messages = [{"role": "user", "content": [{"text": "hello"}]}] response = model.stream(messages, None, None) @@ -739,6 +762,42 @@ async def test_stream(anthropic_client, model, agenerator, alist): anthropic_client.messages.stream.assert_called_once_with(**expected_request) +@pytest.mark.asyncio +async def test_stream_early_termination(anthropic_client, model, alist, caplog): + caplog.set_level(logging.WARNING, logger="strands.models.anthropic") + mock_event = unittest.mock.Mock( + type="message_start", + model_dump=lambda: {"type": "message_start"}, + ) + + anthropic_client.messages.stream.return_value = generate_mock_stream_context( + [mock_event], + final_message=AssertionError("message 
snapshot is not available"), + ) + + messages = [{"role": "user", "content": [{"text": "hello"}]}] + tru_events = await alist(model.stream(messages, None, None)) + + assert len(tru_events) == 1 + assert "messageStart" in tru_events[0] + assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog.text + + +@pytest.mark.asyncio +async def test_stream_empty(anthropic_client, model, alist, caplog): + caplog.set_level(logging.WARNING, logger="strands.models.anthropic") + anthropic_client.messages.stream.return_value = generate_mock_stream_context( + [], + final_message=AssertionError("message snapshot is not available"), + ) + + messages = [{"role": "user", "content": [{"text": "hello"}]}] + tru_events = await alist(model.stream(messages, None, None)) + + assert tru_events == [] + assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog.text + + @pytest.mark.asyncio async def test_stream_rate_limit_error(anthropic_client, model, alist): anthropic_client.messages.stream.side_effect = anthropic.RateLimitError( @@ -781,7 +840,7 @@ async def test_stream_bad_request_error(anthropic_client, model): @pytest.mark.asyncio -async def test_structured_output(anthropic_client, model, test_output_model_cls, agenerator, alist): +async def test_structured_output(anthropic_client, model, test_output_model_cls, alist): messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] events = [ @@ -817,18 +876,16 @@ async def test_structured_output(anthropic_client, model, test_output_model_cls, return_value={"type": "message_stop", "message": {"stop_reason": "tool_use"}} ), ), - unittest.mock.Mock( - message=unittest.mock.Mock( - usage=unittest.mock.Mock( - model_dump=unittest.mock.Mock(return_value={"input_tokens": 0, "output_tokens": 0}) - ), - ), - ), ] - mock_context = unittest.mock.AsyncMock() - mock_context.__aenter__.return_value = agenerator(events) - anthropic_client.messages.stream.return_value = mock_context + 
anthropic_client.messages.stream.return_value = generate_mock_stream_context( + events, + final_message=unittest.mock.Mock( + usage=unittest.mock.Mock( + model_dump=unittest.mock.Mock(return_value={"input_tokens": 0, "output_tokens": 0}) + ), + ), + ) stream = model.structured_output(test_output_model_cls, messages) events = await alist(stream) From f58117c400712a7ac8ef849071db207b94e3cbf6 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 8 Apr 2026 12:06:38 -0400 Subject: [PATCH 216/279] ci: add weekly markdown link check workflow (#2088) --- .github/workflows/check-markdown-links.yml | 51 ++++++++++++++++++++++ .markdown-link-check.json | 6 +++ 2 files changed, 57 insertions(+) create mode 100644 .github/workflows/check-markdown-links.yml create mode 100644 .markdown-link-check.json diff --git a/.github/workflows/check-markdown-links.yml b/.github/workflows/check-markdown-links.yml new file mode 100644 index 000000000..2ac596190 --- /dev/null +++ b/.github/workflows/check-markdown-links.yml @@ -0,0 +1,51 @@ +name: Check Markdown Links + +on: + schedule: + - cron: '0 9 * * 1' # Every Monday at 9am UTC + workflow_dispatch: # Allow manual trigger + +jobs: + check-links: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: gaurav-nelson/github-action-markdown-link-check@3c3b66f1f7d0900e37b71eca45b63ea9eedfce31 # v1.0.17 + id: link-check + with: + use-quiet-mode: 'yes' + use-verbose-mode: 'yes' + config-file: '.markdown-link-check.json' + continue-on-error: true + + - name: Create issue if links are broken + if: steps.link-check.outcome == 'failure' + uses: actions/github-script@v7 + with: + script: | + const title = '🔗 Broken markdown links detected'; + const label = 'broken-links'; + + // Check for existing open issue to avoid duplicates + const existing = await github.rest.issues.listForRepo({ + owner: context.repo.owner, + repo: context.repo.repo, + state: 'open', + labels: label, + }); + + if (existing.data.length > 0) { + 
console.log(`Issue already exists: #${existing.data[0].number}`); + return; + } + + const runUrl = `${context.serverUrl}/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}`; + + await github.rest.issues.create({ + owner: context.repo.owner, + repo: context.repo.repo, + title, + body: `The weekly markdown link check found broken links.\n\nSee the [workflow run](${runUrl}) for details.`, + labels: [label], + }); diff --git a/.markdown-link-check.json b/.markdown-link-check.json new file mode 100644 index 000000000..a03e7e0a9 --- /dev/null +++ b/.markdown-link-check.json @@ -0,0 +1,6 @@ +{ + "retryOn429": true, + "retryCount": 3, + "fallbackRetryDelay": "30s", + "aliveStatusCodes": [200, 206] +} From 289e22aff8dd3bce58361a63b08ee5aeb2524851 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 8 Apr 2026 13:24:04 -0400 Subject: [PATCH 217/279] fix(test): update session integ test for sliding window conversation manager (#2092) --- tests_integ/test_session.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/tests_integ/test_session.py b/tests_integ/test_session.py index 0d4fe9fe1..6b50aa508 100644 --- a/tests_integ/test_session.py +++ b/tests_integ/test_session.py @@ -61,31 +61,42 @@ def test_agent_with_file_session(temp_dir): def test_agent_with_file_session_and_conversation_manager(temp_dir): - # Set up the session manager and add an agent + # Use window_size=2 because the sliding window now enforces that the first remaining + # message after trimming is a user message (#2087). With a simple (no-tool) turn producing + # [user, assistant], window_size=1 can never trim (the sole remaining message would be + # assistant). window_size=2 keeps a valid [user, assistant] pair after trimming. 
test_session_id = str(uuid4()) - # Create a session session_manager = FileSessionManager(session_id=test_session_id, storage_dir=temp_dir) try: agent = Agent( - session_manager=session_manager, conversation_manager=SlidingWindowConversationManager(window_size=1) + session_manager=session_manager, conversation_manager=SlidingWindowConversationManager(window_size=2) ) + # First call: 2 messages [user, assistant], fits in window — no trim agent("Hello!") + assert len(agent.messages) == 2 assert len(session_manager.list_messages(test_session_id, agent.agent_id)) == 2 - # Conversation Manager reduced messages - assert len(agent.messages) == 1 - # After agent is persisted and run, restore the agent and run it again + # Second call: 4 messages, exceeds window, trimmed back to 2 [user, assistant] + agent("Hi again!") + assert len(agent.messages) == 2 + assert agent.conversation_manager.removed_message_count == 2 + # Session manager persists ALL messages even though agent memory was trimmed + assert len(session_manager.list_messages(test_session_id, agent.agent_id)) == 4 + + # Restore agent from session — should load trimmed state session_manager_2 = FileSessionManager(session_id=test_session_id, storage_dir=temp_dir) agent_2 = Agent( - session_manager=session_manager_2, conversation_manager=SlidingWindowConversationManager(window_size=1) + session_manager=session_manager_2, conversation_manager=SlidingWindowConversationManager(window_size=2) ) - assert len(agent_2.messages) == 1 - assert agent_2.conversation_manager.removed_message_count == 1 + assert len(agent_2.messages) == 2 + assert agent_2.conversation_manager.removed_message_count == 2 + + # Third call on restored agent: triggers another trim agent_2("Hello!") - assert len(agent_2.messages) == 1 - assert len(session_manager_2.list_messages(test_session_id, agent_2.agent_id)) == 4 + assert len(agent_2.messages) == 2 + assert agent_2.conversation_manager.removed_message_count == 4 + assert 
len(session_manager_2.list_messages(test_session_id, agent_2.agent_id)) == 6 finally: - # Delete the session session_manager.delete_session(test_session_id) assert session_manager.read_session(test_session_id) is None From cd5da4f5249c4e72954177b7fa08069317713278 Mon Sep 17 00:00:00 2001 From: Agent of mkmeral Date: Wed, 8 Apr 2026 15:28:41 -0400 Subject: [PATCH 218/279] fix(test): fix anthropic stream test mock missing get_final_message (#2094) Co-authored-by: Murat Kaan Meral --- tests/strands/models/test_anthropic.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 78a5ea693..74037fc00 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -995,7 +995,7 @@ def test_format_request_filters_location_source_document(model, model_id, max_to @pytest.mark.asyncio -async def test_stream_message_stop_no_pydantic_warnings(anthropic_client, model, agenerator, alist): +async def test_stream_message_stop_no_pydantic_warnings(anthropic_client, model, alist): """Verify no Pydantic serialization warnings are emitted for message_stop events. Regression test for https://github.com/strands-agents/sdk-python/issues/1746. 
@@ -1018,16 +1018,12 @@ def model_dump_with_warning(): mock_message_stop.model_dump = model_dump_with_warning - mock_event_usage = unittest.mock.Mock( - message=unittest.mock.Mock( - usage=unittest.mock.Mock( - model_dump=lambda: {"input_tokens": 1, "output_tokens": 2}, - ) - ), + final_message = unittest.mock.Mock() + final_message.usage = unittest.mock.Mock( + model_dump=lambda: {"input_tokens": 1, "output_tokens": 2}, ) - mock_context = unittest.mock.AsyncMock() - mock_context.__aenter__.return_value = agenerator([mock_message_stop, mock_event_usage]) + mock_context = generate_mock_stream_context([mock_message_stop], final_message=final_message) anthropic_client.messages.stream.return_value = mock_context messages = [{"role": "user", "content": [{"text": "hello"}]}] From 70b0989b0d96ab1cb1a8ebdf01cb2cc65ae44f11 Mon Sep 17 00:00:00 2001 From: Agent of mkmeral Date: Thu, 9 Apr 2026 09:28:17 -0400 Subject: [PATCH 219/279] feat(hooks): accept callable hook callbacks in Agent constructor (#1992) Co-authored-by: agent-of-mkmeral Co-authored-by: agent-of-mkmeral <217235299+strands-agent@users.noreply.github.com> --- src/strands/agent/agent.py | 14 +++++-- tests/strands/agent/test_agent_hooks.py | 52 +++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 3 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 3a23133de..439471a84 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -132,7 +132,7 @@ def __init__( description: str | None = None, state: AgentState | dict | None = None, plugins: list[Plugin] | None = None, - hooks: list[HookProvider] | None = None, + hooks: list[HookProvider | HookCallback] | None = None, session_manager: SessionManager | None = None, structured_output_prompt: str | None = None, tool_executor: ToolExecutor | None = None, @@ -187,7 +187,8 @@ def __init__( Plugins are initialized with the agent instance after construction and can register hooks, modify agent attributes, or perform 
other setup tasks. Defaults to None. - hooks: hooks to be added to the agent hook registry + hooks: Hooks to be added to the agent hook registry. Accepts HookProvider instances + or plain callable hook callbacks (functions with typed event parameters). Defaults to None. session_manager: Manager for handling agent sessions including conversation history and state. If provided, enables session-based persistence and state management. @@ -341,7 +342,14 @@ def __init__( if hooks: for hook in hooks: - self.hooks.add_hook(hook) + if isinstance(hook, HookProvider): + self.hooks.add_hook(hook) + elif callable(hook): + self.hooks.add_callback(None, hook) + else: + raise ValueError( + f"Invalid hook: {hook!r}. Must be a HookProvider instance or a callable hook callback." + ) # Register built-in plugins self._plugin_registry.add_and_init(_ModelPlugin()) diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 1da245d70..3a40d69a8 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -1021,3 +1021,55 @@ def interrupt_tool(event: BeforeToolCallEvent): assert result.stop_reason == "end_turn" assert result.message["content"][0]["text"] == "Final response" assert agent._interrupt_state.activated is False + + +def test_hooks_param_accepts_callable(): + """Verify that a plain callable can be passed via hooks parameter.""" + events_received = [] + + def my_callback(event: AgentInitializedEvent) -> None: + events_received.append(event) + + agent = Agent(hooks=[my_callback], callback_handler=None) + + assert len(events_received) == 1 + assert isinstance(events_received[0], AgentInitializedEvent) + assert events_received[0].agent is agent + + +def test_hooks_param_accepts_mixed_list(): + """Verify that a mix of HookProviders and callables can be passed.""" + callback_events = [] + + def my_callback(event: AgentInitializedEvent) -> None: + callback_events.append(event) + + provider = 
MockHookProvider(event_types=[AgentInitializedEvent]) + + agent = Agent(hooks=[provider, my_callback], callback_handler=None) + + assert len(callback_events) == 1 + assert callback_events[0].agent is agent + length, _ = provider.get_events() + assert length == 1 + + +def test_hooks_param_invalid_hook_raises_error(): + """Verify that passing an invalid hook raises ValueError.""" + with pytest.raises(ValueError, match="Invalid hook"): + Agent(hooks=["not_a_hook"], callback_handler=None) # type: ignore + + +def test_hooks_param_callable_invoked_during_lifecycle(): + """Verify callable hooks fire during agent lifecycle.""" + before_events = [] + + def on_before(event: BeforeInvocationEvent) -> None: + before_events.append(event) + + mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "Hello!"}]}]) + agent = Agent(model=mock_model, hooks=[on_before], callback_handler=None) + agent("test") + + assert len(before_events) == 1 + assert isinstance(before_events[0], BeforeInvocationEvent) From 762fba29a3bb954bd480be2604a2e3d52d1b8c1c Mon Sep 17 00:00:00 2001 From: Agent of mkmeral Date: Thu, 9 Apr 2026 09:42:26 -0400 Subject: [PATCH 220/279] fix: handle missing optional fields in non-streaming citation conversion (#2098) Co-authored-by: agent-of-mkmeral --- src/strands/models/bedrock.py | 16 +-- tests/strands/models/test_bedrock.py | 141 +++++++++++++++++++++++++++ 2 files changed, 150 insertions(+), 7 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 5b7a2f34e..bfb7b1ede 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -966,13 +966,15 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera } for citation in content["citationsContent"]["citations"]: - # Then emit citation metadata (for structure) - - citation_metadata: CitationsDelta = { - "title": citation["title"], - "location": citation["location"], - "sourceContent": citation["sourceContent"], 
- } + # Emit citation metadata, only including fields that are present + # Nova grounding may omit title/sourceContent + citation_metadata: CitationsDelta = {} + if "title" in citation: + citation_metadata["title"] = citation["title"] + if "location" in citation: + citation_metadata["location"] = citation["location"] + if "sourceContent" in citation: + citation_metadata["sourceContent"] = citation["sourceContent"] yield {"contentBlockDelta": {"delta": {"citation": citation_metadata}}} # Yield contentBlockStop event diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 9c565d4f4..cd7016488 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2823,3 +2823,144 @@ def test_guardrail_latest_message_disabled_does_not_wrap(model): assert "text" in formatted assert "guardContent" not in formatted + + +@pytest.mark.asyncio +async def test_non_streaming_citations_with_missing_optional_fields(bedrock_client, model, alist): + """Test that _convert_non_streaming_to_streaming handles citations missing optional fields. + + Nova grounding returns citations with only url/domain but no title field. The conversion + should not crash with KeyError when optional fields like title, location, or sourceContent + are missing from the citation response. 
+ """ + # Simulate a non-streaming response with citations missing the 'title' field + # This is what Nova grounding returns: url+domain in location, no title + non_streaming_response = { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "citationsContent": { + "content": [{"text": "Top shoe brands include Nike and Adidas."}], + "citations": [ + { + "location": { + "web": { + "url": "https://example.com/shoes", + "domain": "example.com", + } + }, + }, + ], + } + } + ], + } + }, + "stopReason": "end_turn", + "usage": {"inputTokens": 10, "outputTokens": 20}, + } + + events = list(model._convert_non_streaming_to_streaming(non_streaming_response)) + + # Should have: messageStart, contentBlockDelta (text + citation), contentBlockStop, messageStop, metadata + citation_deltas = [ + e for e in events if "contentBlockDelta" in e and "citation" in e.get("contentBlockDelta", {}).get("delta", {}) + ] + assert len(citation_deltas) == 1 + + citation = citation_deltas[0]["contentBlockDelta"]["delta"]["citation"] + # title should NOT be present since the source didn't have it + assert "title" not in citation + # location should be present + assert "location" in citation + # sourceContent should NOT be present since the source didn't have it + assert "sourceContent" not in citation + + +@pytest.mark.asyncio +async def test_non_streaming_citations_with_all_fields_present(bedrock_client, model, alist): + """Test that _convert_non_streaming_to_streaming correctly includes all fields when present.""" + non_streaming_response = { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "citationsContent": { + "content": [{"text": "Nike is a top shoe brand."}], + "citations": [ + { + "title": "Top Shoe Brands", + "location": { + "web": { + "url": "https://example.com/shoes", + "domain": "example.com", + } + }, + "sourceContent": [{"text": "Nike is a leading brand"}], + }, + ], + } + } + ], + } + }, + "stopReason": "end_turn", + "usage": 
{"inputTokens": 10, "outputTokens": 20}, + } + + events = list(model._convert_non_streaming_to_streaming(non_streaming_response)) + + citation_deltas = [ + e for e in events if "contentBlockDelta" in e and "citation" in e.get("contentBlockDelta", {}).get("delta", {}) + ] + assert len(citation_deltas) == 1 + + citation = citation_deltas[0]["contentBlockDelta"]["delta"]["citation"] + assert citation["title"] == "Top Shoe Brands" + assert citation["location"] == {"web": {"url": "https://example.com/shoes", "domain": "example.com"}} + assert citation["sourceContent"] == [{"text": "Nike is a leading brand"}] + + +@pytest.mark.asyncio +async def test_non_streaming_citations_with_only_location(bedrock_client, model, alist): + """Test citations with only location field (no title, no sourceContent).""" + non_streaming_response = { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "citationsContent": { + "citations": [ + { + "location": { + "web": { + "url": "https://example.com", + "domain": "example.com", + } + }, + }, + ], + } + } + ], + } + }, + "stopReason": "end_turn", + "usage": {"inputTokens": 5, "outputTokens": 10}, + } + + events = list(model._convert_non_streaming_to_streaming(non_streaming_response)) + + citation_deltas = [ + e for e in events if "contentBlockDelta" in e and "citation" in e.get("contentBlockDelta", {}).get("delta", {}) + ] + assert len(citation_deltas) == 1 + + citation = citation_deltas[0]["contentBlockDelta"]["delta"]["citation"] + assert citation["location"] == {"web": {"url": "https://example.com", "domain": "example.com"}} + assert "title" not in citation + assert "sourceContent" not in citation From ca6f599d403a0d190d4f102f608e21034c79a439 Mon Sep 17 00:00:00 2001 From: Giulio Leone Date: Thu, 9 Apr 2026 20:35:24 +0200 Subject: [PATCH 221/279] fix(telemetry): add common gen_ai attributes to event loop cycle spans (#1973) Co-authored-by: giulio-leone --- src/strands/telemetry/tracer.py | 7 ++++--- 
tests/strands/telemetry/test_tracer.py | 10 +++++++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 19a163f5c..1ae122db1 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -526,9 +526,10 @@ def start_event_loop_cycle_span( event_loop_cycle_id = str(invocation_state.get("event_loop_cycle_id")) parent_span = parent_span if parent_span else invocation_state.get("event_loop_parent_span") - attributes: dict[str, AttributeValue] = { - "event_loop.cycle_id": event_loop_cycle_id, - } + attributes: dict[str, AttributeValue] = self._get_common_attributes( + operation_name="execute_event_loop_cycle" + ) + attributes["event_loop.cycle_id"] = event_loop_cycle_id if custom_trace_attributes: attributes.update(custom_trace_attributes) diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 2d91b6216..e5f7b6472 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -743,6 +743,8 @@ def test_start_event_loop_cycle_span(mock_tracer): mock_span.set_attributes.assert_called_once_with( { + "gen_ai.operation.name": "execute_event_loop_cycle", + "gen_ai.system": "strands-agents", "event_loop.cycle_id": "cycle-123", "request_id": "req-456", "trace_level": "debug", @@ -772,7 +774,13 @@ def test_start_event_loop_cycle_span_latest_conventions(mock_tracer, monkeypatch mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "execute_event_loop_cycle" - mock_span.set_attributes.assert_called_once_with({"event_loop.cycle_id": "cycle-123"}) + mock_span.set_attributes.assert_called_once_with( + { + "gen_ai.operation.name": "execute_event_loop_cycle", + "gen_ai.provider.name": "strands-agents", + "event_loop.cycle_id": "cycle-123", + } + ) mock_span.add_event.assert_any_call( "gen_ai.client.inference.operation.details", attributes={ From 
d27b8ff79d76335fae8c683418ac50c10d1b7f55 Mon Sep 17 00:00:00 2001 From: en-yao <121856029+en-yao@users.noreply.github.com> Date: Fri, 10 Apr 2026 04:31:10 +0800 Subject: [PATCH 222/279] fix(telemetry): use per-invocation usage in agent span attributes (#2017) --- src/strands/telemetry/tracer.py | 31 ++++++++++----- tests/strands/telemetry/test_tracer.py | 52 ++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 10 deletions(-) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 1ae122db1..d5d399f95 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -83,8 +83,8 @@ class Tracer: When the OTEL_EXPORTER_OTLP_ENDPOINT environment variable is set, traces are sent to the OTLP endpoint. - Both attributes are controlled by including "gen_ai_latest_experimental" or "gen_ai_tool_definitions", - respectively, in the OTEL_SEMCONV_STABILITY_OPT_IN environment variable. + Both attributes are controlled by including "gen_ai_latest_experimental", "gen_ai_tool_definitions", + or "gen_ai_use_latest_invocation_tokens", respectively, in the OTEL_SEMCONV_STABILITY_OPT_IN environment variable. """ def __init__(self) -> None: @@ -100,6 +100,7 @@ def __init__(self) -> None: ## To-do: should not set below attributes directly, use env var instead self.use_latest_genai_conventions = "gen_ai_latest_experimental" in opt_in_values self._include_tool_definitions = "gen_ai_tool_definitions" in opt_in_values + self._use_latest_invocation_tokens = "gen_ai_use_latest_invocation_tokens" in opt_in_values def _parse_semconv_opt_in(self) -> set[str]: """Parse the OTEL_SEMCONV_STABILITY_OPT_IN environment variable. 
@@ -690,16 +691,26 @@ def end_agent_span( if hasattr(response, "metrics") and hasattr(response.metrics, "accumulated_usage"): if self.is_langfuse: attributes.update({"langfuse.observation.type": "span"}) - accumulated_usage = response.metrics.accumulated_usage + if self._use_latest_invocation_tokens: + latest_invocation = response.metrics.latest_agent_invocation + if latest_invocation is None: + logger.warning( + "latest_agent_invocation is None despite _use_latest_invocation_tokens being set" + ) + usage: Usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) + else: + usage = latest_invocation.usage + else: + usage = response.metrics.accumulated_usage attributes.update( { - "gen_ai.usage.prompt_tokens": accumulated_usage["inputTokens"], - "gen_ai.usage.completion_tokens": accumulated_usage["outputTokens"], - "gen_ai.usage.input_tokens": accumulated_usage["inputTokens"], - "gen_ai.usage.output_tokens": accumulated_usage["outputTokens"], - "gen_ai.usage.total_tokens": accumulated_usage["totalTokens"], - "gen_ai.usage.cache_read_input_tokens": accumulated_usage.get("cacheReadInputTokens", 0), - "gen_ai.usage.cache_write_input_tokens": accumulated_usage.get("cacheWriteInputTokens", 0), + "gen_ai.usage.prompt_tokens": usage["inputTokens"], + "gen_ai.usage.completion_tokens": usage["outputTokens"], + "gen_ai.usage.input_tokens": usage["inputTokens"], + "gen_ai.usage.output_tokens": usage["outputTokens"], + "gen_ai.usage.total_tokens": usage["totalTokens"], + "gen_ai.usage.cache_read_input_tokens": usage.get("cacheReadInputTokens", 0), + "gen_ai.usage.cache_write_input_tokens": usage.get("cacheWriteInputTokens", 0), } ) diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index e5f7b6472..6b622bb3e 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -1,4 +1,5 @@ import json +import logging import os from datetime import date, datetime, timezone from unittest import mock @@ 
-1053,6 +1054,57 @@ def test_end_agent_span_latest_conventions(mock_span, monkeypatch): mock_span.end.assert_called_once() +def test_end_agent_span_uses_per_invocation_usage_when_opted_in(mock_span, monkeypatch): + """Test that agent span reports per-invocation usage when gen_ai_use_latest_invocation_tokens is set.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_use_latest_invocation_tokens") + tracer = Tracer() + + mock_invocation = mock.MagicMock() + mock_invocation.usage = {"inputTokens": 100, "outputTokens": 50, "totalTokens": 150} + + mock_metrics = mock.MagicMock() + mock_metrics.accumulated_usage = {"inputTokens": 1000, "outputTokens": 500, "totalTokens": 1500} + mock_metrics.latest_agent_invocation = mock_invocation + + mock_response = mock.MagicMock() + mock_response.metrics = mock_metrics + mock_response.stop_reason = "end_turn" + mock_response.__str__ = mock.MagicMock(return_value="Agent response") + + tracer.end_agent_span(mock_span, mock_response) + + call_args = mock_span.set_attributes.call_args[0][0] + assert call_args["gen_ai.usage.input_tokens"] == 100 + assert call_args["gen_ai.usage.output_tokens"] == 50 + assert call_args["gen_ai.usage.total_tokens"] == 150 + assert call_args["gen_ai.usage.prompt_tokens"] == 100 + assert call_args["gen_ai.usage.completion_tokens"] == 50 + + +def test_end_agent_span_warns_when_opted_in_but_no_invocations(mock_span, monkeypatch, caplog): + """Test warning and zero usage when opted in but no agent invocations exist.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_use_latest_invocation_tokens") + tracer = Tracer() + + mock_metrics = mock.MagicMock() + mock_metrics.accumulated_usage = {"inputTokens": 200, "outputTokens": 100, "totalTokens": 300} + mock_metrics.latest_agent_invocation = None + + mock_response = mock.MagicMock() + mock_response.metrics = mock_metrics + mock_response.stop_reason = "end_turn" + mock_response.__str__ = mock.MagicMock(return_value="Agent response") + + 
with caplog.at_level(logging.WARNING): + tracer.end_agent_span(mock_span, mock_response) + + assert "latest_agent_invocation is None" in caplog.text + call_args = mock_span.set_attributes.call_args[0][0] + assert call_args["gen_ai.usage.input_tokens"] == 0 + assert call_args["gen_ai.usage.output_tokens"] == 0 + assert call_args["gen_ai.usage.total_tokens"] == 0 + + def test_end_model_invoke_span_with_cache_metrics(mock_span): """Test ending a model invoke span with cache metrics.""" tracer = Tracer() From 50b2c799236d3ebf0e659b8c4c0eaceec72564fe Mon Sep 17 00:00:00 2001 From: Agent of mkmeral Date: Fri, 10 Apr 2026 11:56:34 -0400 Subject: [PATCH 223/279] feat(a2a): add client_config param and deprecate a2a_client_factory (#2103) Co-authored-by: agent-of-mkmeral <217235299+strands-agent@users.noreply.github.com> Co-authored-by: agent-of-mkmeral --- src/strands/agent/a2a_agent.py | 61 +++- tests/strands/agent/test_a2a_agent.py | 392 +++++++++++++++++++++++--- 2 files changed, 395 insertions(+), 58 deletions(-) diff --git a/src/strands/agent/a2a_agent.py b/src/strands/agent/a2a_agent.py index e18da2f4a..eef47e3b4 100644 --- a/src/strands/agent/a2a_agent.py +++ b/src/strands/agent/a2a_agent.py @@ -6,7 +6,9 @@ A2AAgent can be used to get the Agent Card and interact with the agent. """ +import dataclasses import logging +import warnings from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Any @@ -38,6 +40,7 @@ def __init__( name: str | None = None, description: str | None = None, timeout: int = _DEFAULT_TIMEOUT, + client_config: ClientConfig | None = None, a2a_client_factory: ClientFactory | None = None, ): """Initialize A2A agent. @@ -47,15 +50,34 @@ def __init__( name: Agent name. If not provided, will be populated from agent card. description: Agent description. If not provided, will be populated from agent card. timeout: Timeout for HTTP operations in seconds (defaults to 300). 
- a2a_client_factory: Optional pre-configured A2A ClientFactory. If provided, - it will be used to create the A2A client after discovering the agent card. - Note: When providing a custom factory, you are responsible for managing - the lifecycle of any httpx client it uses. + client_config: A2A ``ClientConfig`` for authentication and transport settings. + The ``httpx_client`` configured here is used for both card discovery and + message sending, enabling authenticated endpoints (SigV4, OAuth, bearer tokens). + When providing an ``httpx_client``, you are responsible for configuring its timeout. + a2a_client_factory: Deprecated. Use ``client_config`` instead. + + Raises: + ValueError: If both ``client_config`` and ``a2a_client_factory`` are provided. """ + if client_config is not None and a2a_client_factory is not None: + raise ValueError( + "Cannot provide both client_config and a2a_client_factory. " + "Use client_config (recommended) or a2a_client_factory (deprecated), not both." + ) + + if a2a_client_factory is not None: + warnings.warn( + "a2a_client_factory is deprecated. Use client_config instead. " + "a2a_client_factory will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + self.endpoint = endpoint self.name = name self.description = description self.timeout = timeout + self._client_config: ClientConfig | None = client_config self._agent_card: AgentCard | None = None self._a2a_client_factory: ClientFactory | None = a2a_client_factory @@ -160,9 +182,11 @@ async def stream_async( async def get_agent_card(self) -> AgentCard: """Fetch and return the remote agent's card. - This method eagerly fetches the agent card from the remote endpoint, - populating name and description if not already set. The card is cached - after the first fetch. + Eagerly fetches the agent card from the remote endpoint, populating name and description + if not already set. The card is cached after the first fetch. 
+ + When ``client_config`` is provided with an ``httpx_client``, that client is used for + card resolution, enabling authenticated card discovery (e.g., SigV4, OAuth, bearer tokens). Returns: The remote agent's AgentCard containing name, description, capabilities, skills, etc. @@ -170,16 +194,20 @@ async def get_agent_card(self) -> AgentCard: if self._agent_card is not None: return self._agent_card - async with httpx.AsyncClient(timeout=self.timeout) as client: - resolver = A2ACardResolver(httpx_client=client, base_url=self.endpoint) + if self._client_config is not None and self._client_config.httpx_client is not None: + resolver = A2ACardResolver(httpx_client=self._client_config.httpx_client, base_url=self.endpoint) self._agent_card = await resolver.get_agent_card() + else: + async with httpx.AsyncClient(timeout=self.timeout) as client: + resolver = A2ACardResolver(httpx_client=client, base_url=self.endpoint) + self._agent_card = await resolver.get_agent_card() # Populate name from card if not set - if self.name is None and self._agent_card.name: + if self.name is None and self._agent_card.name is not None: self.name = self._agent_card.name # Populate description from card if not set - if self.description is None and self._agent_card.description: + if self.description is None and self._agent_card.description is not None: self.description = self._agent_card.description logger.debug("agent=<%s>, endpoint=<%s> | discovered agent card", self.name, self.endpoint) @@ -189,8 +217,9 @@ async def get_agent_card(self) -> AgentCard: async def _get_a2a_client(self) -> AsyncIterator[Any]: """Get A2A client for sending messages. - If a custom factory was provided, uses that (caller manages httpx lifecycle). - Otherwise creates a per-call httpx client with proper cleanup. + If a deprecated factory was provided, delegates to it for client creation. + If client_config was provided, uses it directly — ClientFactory handles defaults. 
+ Otherwise creates a managed httpx client with the agent's timeout. Yields: Configured A2A client instance. @@ -201,6 +230,12 @@ async def _get_a2a_client(self) -> AsyncIterator[Any]: yield self._a2a_client_factory.create(agent_card) return + if self._client_config is not None: + config = dataclasses.replace(self._client_config, streaming=True) + yield ClientFactory(config).create(agent_card) + return + + # No client_config — create a managed httpx client, consistent with get_agent_card() path async with httpx.AsyncClient(timeout=self.timeout) as httpx_client: config = ClientConfig(httpx_client=httpx_client, streaming=True) yield ClientFactory(config).create(agent_card) diff --git a/tests/strands/agent/test_a2a_agent.py b/tests/strands/agent/test_a2a_agent.py index 26a34476d..d918033e5 100644 --- a/tests/strands/agent/test_a2a_agent.py +++ b/tests/strands/agent/test_a2a_agent.py @@ -1,10 +1,12 @@ """Tests for A2AAgent class.""" +import warnings from contextlib import asynccontextmanager from unittest.mock import AsyncMock, MagicMock, patch from uuid import uuid4 import pytest +from a2a.client import ClientConfig from a2a.types import AgentCard, Message, Part, Role, TextPart from strands.agent.a2a_agent import A2AAgent @@ -58,6 +60,9 @@ async def mock_a2a_client_context(send_message_func): yield mock_httpx_class, mock_factory_class +# === Init Tests === + + def test_init_with_defaults(): """Test initialization with default parameters.""" agent = A2AAgent(endpoint="http://localhost:8000") @@ -81,11 +86,41 @@ def test_init_with_custom_timeout(): assert agent.timeout == 600 +def test_init_with_client_config(): + """Test initialization with client_config.""" + config = ClientConfig() + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + assert agent._client_config is config + + def test_init_with_external_a2a_client_factory(): - """Test initialization with external A2A client factory.""" + """Test initialization with external A2A client factory 
emits deprecation warning.""" external_factory = MagicMock() - agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=external_factory) - assert agent._a2a_client_factory is external_factory + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=external_factory) + assert agent._a2a_client_factory is external_factory + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "a2a_client_factory is deprecated" in str(w[0].message) + assert "client_config" in str(w[0].message) + + +def test_init_with_both_client_config_and_factory_raises(): + """Test that providing both client_config and factory raises ValueError.""" + config = ClientConfig() + factory = MagicMock() + with pytest.raises(ValueError, match="Cannot provide both client_config and a2a_client_factory"): + A2AAgent(endpoint="http://localhost:8000", client_config=config, a2a_client_factory=factory) + + +def test_init_no_asyncio_lock(): + """Test that A2AAgent does not create an asyncio.Lock in __init__.""" + agent = A2AAgent(endpoint="http://localhost:8000") + assert not hasattr(agent, "_card_lock") + + +# === Card Resolution Tests === @pytest.mark.asyncio @@ -147,6 +182,314 @@ async def test_get_agent_card_preserves_custom_name_and_description(mock_agent_c assert agent.description == "Custom description" +@pytest.mark.asyncio +async def test_get_agent_card_handles_empty_string_name_and_description(mock_httpx_client): + """Test that empty string name/description from card are preserved (not treated as None).""" + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "" + mock_card.description = "" + + agent = A2AAgent(endpoint="http://localhost:8000") + + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client): + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + 
mock_resolver.get_agent_card = AsyncMock(return_value=mock_card) + mock_resolver_class.return_value = mock_resolver + + await agent.get_agent_card() + + # Empty strings should be set (not treated as falsy/None) + assert agent.name == "" + assert agent.description == "" + + +@pytest.mark.asyncio +async def test_get_agent_card_with_client_config_uses_auth_client(): + """Test that client_config's httpx_client is used for card resolution (fixes auth bug).""" + mock_auth_client = MagicMock() + config = ClientConfig(httpx_client=mock_auth_client) + + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "test" + mock_card.description = "test" + + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + + resolver_httpx_client = None + + def track_resolver_init(*, httpx_client, base_url): + nonlocal resolver_httpx_client + resolver_httpx_client = httpx_client + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_card) + return mock_resolver + + with patch("strands.agent.a2a_agent.A2ACardResolver", side_effect=track_resolver_init): + await agent.get_agent_card() + + # CRITICAL: Verify the authenticated client was used for card resolution + assert resolver_httpx_client is mock_auth_client, ( + "Bug not fixed: authenticated httpx client was not used for card resolution" + ) + + +@pytest.mark.asyncio +async def test_get_agent_card_without_client_config_uses_default_httpx(mock_httpx_client): + """Test that card resolution uses bare httpx when no client_config is provided.""" + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "test" + mock_card.description = "test" + + agent = A2AAgent(endpoint="http://localhost:8000") + + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client) as mock_httpx_class: + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_card) + 
mock_resolver_class.return_value = mock_resolver + + await agent.get_agent_card() + + # Should use bare httpx with timeout + mock_httpx_class.assert_called_once_with(timeout=300) + + +@pytest.mark.asyncio +async def test_get_agent_card_factory_only_uses_default_httpx(mock_httpx_client): + """Test that deprecated factory without client_config still uses bare httpx for card resolution.""" + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "test" + mock_card.description = "test" + + mock_factory = MagicMock() + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=mock_factory) + + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client) as mock_httpx_class: + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_card) + mock_resolver_class.return_value = mock_resolver + + await agent.get_agent_card() + + # Factory alone does NOT provide auth for card resolution — uses bare httpx + mock_httpx_class.assert_called_once_with(timeout=300) + + +@pytest.mark.asyncio +async def test_get_agent_card_client_config_without_httpx_uses_default(mock_httpx_client): + """Test that client_config without httpx_client falls through to managed httpx (same as no config).""" + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "test" + mock_card.description = "test" + + config = ClientConfig(polling=True) # No httpx_client + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client) as mock_httpx_class: + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_card) + mock_resolver_class.return_value = 
mock_resolver + + await agent.get_agent_card() + + # Should use managed httpx with timeout (same as no config path) + mock_httpx_class.assert_called_once_with(timeout=300) + + +# === Client Creation Tests === + + +@pytest.mark.asyncio +async def test_get_a2a_client_with_client_config_preserves_user_settings(mock_agent_card): + """Test that _get_a2a_client preserves all user ClientConfig settings via dataclasses.replace.""" + mock_auth_client = MagicMock() + config = ClientConfig( + httpx_client=mock_auth_client, + streaming=False, # user set this to False + polling=True, + supported_transports=["jsonrpc"], + ) + + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + + with patch.object(agent, "get_agent_card", return_value=mock_agent_card): + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory = MagicMock() + mock_factory.create.return_value = MagicMock() + mock_factory_class.return_value = mock_factory + + async with agent._get_a2a_client(): + pass + + # Verify factory was created with a config that preserves user settings + mock_factory_class.assert_called_once() + created_config = mock_factory_class.call_args[0][0] + assert created_config.httpx_client is mock_auth_client + assert created_config.streaming is True # overridden to True + assert created_config.polling is True # preserved + assert created_config.supported_transports == ["jsonrpc"] # preserved + + +@pytest.mark.asyncio +async def test_get_a2a_client_with_client_config_does_not_mutate_original(mock_agent_card): + """Test that _get_a2a_client does not mutate the original client_config.""" + config = ClientConfig(streaming=False) + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + + with patch.object(agent, "get_agent_card", return_value=mock_agent_card): + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory = MagicMock() + mock_factory.create.return_value = MagicMock() + 
mock_factory_class.return_value = mock_factory + + async with agent._get_a2a_client(): + pass + + # Original config should NOT be mutated + assert config.streaming is False + + +@pytest.mark.asyncio +async def test_get_a2a_client_config_without_httpx_delegates_to_factory(mock_agent_card): + """Test that _get_a2a_client delegates to ClientFactory when config has no httpx_client. + + ClientFactory handles creating a default httpx client internally. We just pass + the config with streaming=True and let the factory do its job. + """ + config = ClientConfig(polling=True, supported_transports=["jsonrpc"]) + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config, timeout=600) + + with patch.object(agent, "get_agent_card", return_value=mock_agent_card): + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory = MagicMock() + mock_factory.create.return_value = MagicMock() + mock_factory_class.return_value = mock_factory + + async with agent._get_a2a_client(): + pass + + # Should pass config directly to ClientFactory — factory handles httpx defaults + created_config = mock_factory_class.call_args[0][0] + assert created_config.streaming is True + assert created_config.polling is True + assert created_config.supported_transports == ["jsonrpc"] + assert created_config.httpx_client is None # factory handles default + + +@pytest.mark.asyncio +async def test_send_message_uses_provided_factory(mock_agent_card): + """Test _send_message uses provided factory instead of creating per-call client.""" + external_factory = MagicMock() + mock_a2a_client = MagicMock() + + async def mock_send_message(*args, **kwargs): + yield MagicMock() + + mock_a2a_client.send_message = mock_send_message + external_factory.create.return_value = mock_a2a_client + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=external_factory) + + with 
patch.object(agent, "get_agent_card", return_value=mock_agent_card): + # Consume the async iterator + async for _ in agent._send_message("Hello"): + pass + + external_factory.create.assert_called_once_with(mock_agent_card) + + +@pytest.mark.asyncio +async def test_send_message_uses_client_config_httpx_client(mock_agent_card): + """Test _send_message uses client_config's httpx_client for client creation.""" + mock_auth_client = MagicMock() + config = ClientConfig(httpx_client=mock_auth_client) + + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + + mock_a2a_client = MagicMock() + + async def mock_send(*args, **kwargs): + yield MagicMock() + + mock_a2a_client.send_message = mock_send + + with patch.object(agent, "get_agent_card", return_value=mock_agent_card): + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory = MagicMock() + mock_factory.create.return_value = mock_a2a_client + mock_factory_class.return_value = mock_factory + + async for _ in agent._send_message("Hello"): + pass + + # Verify ClientFactory was created with config containing the auth client + mock_factory_class.assert_called_once() + call_args = mock_factory_class.call_args + created_config = call_args[0][0] + assert created_config.httpx_client is mock_auth_client + + +@pytest.mark.asyncio +async def test_send_message_creates_per_call_client(a2a_agent, mock_agent_card): + """Test _send_message creates a fresh httpx client for each call when no factory provided.""" + mock_response = Message( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + async def mock_send_message(*args, **kwargs): + yield mock_response + + with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message) as (mock_httpx_class, _): + # Consume the async iterator + async for _ in a2a_agent._send_message("Hello"): + pass + + # Verify httpx client 
was created with timeout + mock_httpx_class.assert_called_once_with(timeout=300) + + +@pytest.mark.asyncio +async def test_get_a2a_client_no_config_creates_managed_httpx(): + """Test that _get_a2a_client creates a managed httpx client when no config provided.""" + mock_card = MagicMock(spec=AgentCard) + agent = A2AAgent(endpoint="http://localhost:8000", timeout=600) + + with patch.object(agent, "get_agent_card", return_value=mock_card): + with patch("strands.agent.a2a_agent.httpx.AsyncClient") as mock_httpx_class: + mock_httpx = AsyncMock() + mock_httpx.__aenter__.return_value = mock_httpx + mock_httpx.__aexit__.return_value = None + mock_httpx_class.return_value = mock_httpx + + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory = MagicMock() + mock_factory.create.return_value = MagicMock() + mock_factory_class.return_value = mock_factory + + async with agent._get_a2a_client(): + pass + + # Verify httpx client was created with agent timeout + mock_httpx_class.assert_called_once_with(timeout=600) + # Verify ClientFactory was called with streaming=True + created_config = mock_factory_class.call_args[0][0] + assert created_config.streaming is True + + +# === Invoke/Stream Tests === + + @pytest.mark.asyncio async def test_invoke_async_success(a2a_agent, mock_agent_card): """Test successful async invocation.""" @@ -242,48 +585,7 @@ async def test_stream_async_no_prompt(a2a_agent): pass -@pytest.mark.asyncio -async def test_send_message_uses_provided_factory(mock_agent_card): - """Test _send_message uses provided factory instead of creating per-call client.""" - external_factory = MagicMock() - mock_a2a_client = MagicMock() - - async def mock_send_message(*args, **kwargs): - yield MagicMock() - - mock_a2a_client.send_message = mock_send_message - external_factory.create.return_value = mock_a2a_client - - agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=external_factory) - - with patch.object(agent, 
"get_agent_card", return_value=mock_agent_card): - # Consume the async iterator - async for _ in agent._send_message("Hello"): - pass - - external_factory.create.assert_called_once_with(mock_agent_card) - - -@pytest.mark.asyncio -async def test_send_message_creates_per_call_client(a2a_agent, mock_agent_card): - """Test _send_message creates a fresh httpx client for each call when no factory provided.""" - mock_response = Message( - message_id=uuid4().hex, - role=Role.agent, - parts=[Part(TextPart(kind="text", text="Response"))], - ) - - async def mock_send_message(*args, **kwargs): - yield mock_response - - with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): - async with mock_a2a_client_context(mock_send_message) as (mock_httpx_class, _): - # Consume the async iterator - async for _ in a2a_agent._send_message("Hello"): - pass - - # Verify httpx client was created with timeout - mock_httpx_class.assert_called_once_with(timeout=300) +# === Complete Event Tests === def test_is_complete_event_message(a2a_agent): From bb7f1886460ab70b9422e09552669d21910a1e56 Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Mon, 13 Apr 2026 12:15:46 -0400 Subject: [PATCH 224/279] fix: clear leaked running loop in MCP client background thread (#2111) --- src/strands/tools/mcp/mcp_client.py | 3 ++ .../tools/mcp/test_mcp_client_contextvar.py | 36 +++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 11ed9c75e..e81dc7130 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -835,6 +835,9 @@ def _background_task(self) -> None: This allows for a long-running event loop. 
""" self._log_debug_with_thread("setting up background task event loop") + # Clear any running-loop state leaked by OpenTelemetry's ThreadingInstrumentor, which wraps Thread.run() + # and can propagate the parent thread's event loop reference, causing run_until_complete() to fail. + asyncio._set_running_loop(None) self._background_thread_event_loop = asyncio.new_event_loop() asyncio.set_event_loop(self._background_thread_event_loop) self._background_thread_event_loop.run_until_complete(self._async_background_thread()) diff --git a/tests/strands/tools/mcp/test_mcp_client_contextvar.py b/tests/strands/tools/mcp/test_mcp_client_contextvar.py index 739796366..1770a050a 100644 --- a/tests/strands/tools/mcp/test_mcp_client_contextvar.py +++ b/tests/strands/tools/mcp/test_mcp_client_contextvar.py @@ -88,3 +88,39 @@ def capturing_background_task(self): ) # Verify it was indeed a different thread assert background_thread_value["thread_id"] != main_thread_id, "Background task should run in a different thread" + + +def test_mcp_client_clears_running_loop_in_background_thread(mock_transport, mock_session): + """Test that _background_task clears any leaked running event loop state. + + When OpenTelemetry's ThreadingInstrumentor is active, Thread.run() is wrapped to propagate + trace context, which can leak the parent thread's running event loop reference into child + threads. This causes "RuntimeError: Cannot run the event loop while another loop is running" + when the background thread calls run_until_complete(). + + This test simulates that scenario by setting a running loop before _background_task runs + and verifying it gets cleared. 
+ """ + import asyncio + + cleared_running_loop = {} + + original_background_task = MCPClient._background_task + + def simulating_otel_leak_background_task(self): + # Simulate OTEL ThreadingInstrumentor leaking the parent's running loop + fake_loop = asyncio.new_event_loop() + asyncio._set_running_loop(fake_loop) # type: ignore[attr-defined] + + # Call the real _background_task — it should clear the leaked loop and succeed + try: + return original_background_task(self) + finally: + cleared_running_loop["success"] = True + fake_loop.close() + + with patch.object(MCPClient, "_background_task", simulating_otel_leak_background_task): + with MCPClient(mock_transport["transport_callable"]) as client: + assert client._background_thread is not None + + assert cleared_running_loop.get("success"), "_background_task should have run successfully despite leaked loop" From 0930ca6d602f193a4bd72776cadfe25e16056330 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Mon, 13 Apr 2026 16:50:33 -0400 Subject: [PATCH 225/279] feat(openai): plumb through cache tokens in metadata events (#2116) Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> --- src/strands/models/openai.py | 17 +++++++---- tests/strands/models/test_openai.py | 46 ++++++++++++++++++++++++++++- 2 files changed, 57 insertions(+), 6 deletions(-) diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 73484e924..333f59c71 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -17,6 +17,7 @@ from typing_extensions import Unpack, override from ..types.content import ContentBlock, Messages, SystemContentBlock +from ..types.event_loop import Usage from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse @@ -546,13 +547,19 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: return 
{"messageStop": {"stopReason": "end_turn"}} case "metadata": + usage_data: Usage = { + "inputTokens": event["data"].prompt_tokens, + "outputTokens": event["data"].completion_tokens, + "totalTokens": event["data"].total_tokens, + } + + if tokens_details := getattr(event["data"], "prompt_tokens_details", None): + if cached := getattr(tokens_details, "cached_tokens", None): + usage_data["cacheReadInputTokens"] = cached + return { "metadata": { - "usage": { - "inputTokens": event["data"].prompt_tokens, - "outputTokens": event["data"].completion_tokens, - "totalTokens": event["data"].total_tokens, - }, + "usage": usage_data, "metrics": { "latencyMs": 0, # TODO }, diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 747e1123a..7af39032c 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -818,7 +818,12 @@ def test_format_request_with_tool_choice_tool(model, messages, tool_specs, syste ( { "chunk_type": "metadata", - "data": unittest.mock.Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150), + "data": unittest.mock.Mock( + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + prompt_tokens_details=None, + ), }, { "metadata": { @@ -847,6 +852,45 @@ def test_format_chunk_unknown_type(model): model.format_chunk(event) +def test_format_chunk_metadata_with_cache_tokens(model): + """Test format_chunk for metadata with cache tokens present.""" + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 100 + mock_usage.completion_tokens = 50 + mock_usage.total_tokens = 150 + + mock_tokens_details = unittest.mock.Mock() + mock_tokens_details.cached_tokens = 25 + mock_usage.prompt_tokens_details = mock_tokens_details + + event = {"chunk_type": "metadata", "data": mock_usage} + + result = model.format_chunk(event) + + assert result["metadata"]["usage"]["inputTokens"] == 100 + assert result["metadata"]["usage"]["outputTokens"] == 50 + assert 
result["metadata"]["usage"]["totalTokens"] == 150 + assert result["metadata"]["usage"]["cacheReadInputTokens"] == 25 + + +def test_format_chunk_metadata_with_zero_cached_tokens(model): + """Test format_chunk for metadata when cached_tokens is 0.""" + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 100 + mock_usage.completion_tokens = 50 + mock_usage.total_tokens = 150 + + mock_tokens_details = unittest.mock.Mock() + mock_tokens_details.cached_tokens = 0 + mock_usage.prompt_tokens_details = mock_tokens_details + + event = {"chunk_type": "metadata", "data": mock_usage} + + result = model.format_chunk(event) + + assert "cacheReadInputTokens" not in result["metadata"]["usage"] + + @pytest.mark.asyncio async def test_stream(openai_client, model_id, model, agenerator, alist): mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) From 2b8140107ab945ae973160df8a6aa4c65f22518e Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Tue, 14 Apr 2026 12:27:13 -0400 Subject: [PATCH 226/279] feat(agent): add take_snapshot() and load_snapshot() methods (#1948) Co-authored-by: Mackenzie Zastrow --- AGENTS.md | 1 + src/strands/__init__.py | 2 + src/strands/agent/agent.py | 80 +++++ src/strands/types/__init__.py | 3 +- src/strands/types/_snapshot.py | 145 +++++++++ src/strands/types/exceptions.py | 6 + tests/strands/agent/test_snapshot.py | 453 +++++++++++++++++++++++++++ 7 files changed, 689 insertions(+), 1 deletion(-) create mode 100644 src/strands/types/_snapshot.py create mode 100644 tests/strands/agent/test_snapshot.py diff --git a/AGENTS.md b/AGENTS.md index a9a2a5044..3615e713a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -105,6 +105,7 @@ strands-agents/ │ │ ├── event_loop.py # Event loop types │ │ ├── json_dict.py # JSON dict utilities │ │ ├── collections.py # Collection types +│ │ ├── _snapshot.py # Snapshot types and helpers │ │ ├── _events.py # Internal event types │ │ ├── a2a.py # A2A protocol types │ │ └── models/ 
# Model-specific types diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 2078f16ce..6625ac41f 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -6,6 +6,7 @@ from .event_loop._retry import ModelRetryStrategy from .plugins import Plugin from .tools.decorator import tool +from .types._snapshot import Snapshot from .types.tools import ToolContext from .vended_plugins.skills import AgentSkills, Skill @@ -18,6 +19,7 @@ "ModelRetryStrategy", "Plugin", "Skill", + "Snapshot", "tool", "ToolContext", "types", diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 439471a84..37fa5fc00 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -9,6 +9,7 @@ 2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")` """ +import copy import logging import threading import warnings @@ -29,6 +30,13 @@ from ..event_loop._retry import ModelRetryStrategy from ..event_loop.event_loop import INITIAL_DELAY, MAX_ATTEMPTS, MAX_DELAY, event_loop_cycle from ..tools._tool_helpers import generate_missing_tool_result_content +from ..types._snapshot import ( + SNAPSHOT_SCHEMA_VERSION, + Snapshot, + SnapshotField, + SnapshotPreset, + resolve_snapshot_fields, +) if TYPE_CHECKING: from ..tools import ToolProvider @@ -1103,6 +1111,78 @@ async def _append_messages(self, *messages: Message) -> None: self.messages.append(message) await self.hooks.invoke_callbacks_async(MessageAddedEvent(agent=self, message=message)) + def take_snapshot( + self, + *, + preset: SnapshotPreset | None = None, + include: list[SnapshotField] | None = None, + exclude: list[SnapshotField] | None = None, + app_data: dict[str, Any] | None = None, + ) -> Snapshot: + """Capture current agent state as an in-memory snapshot. + + Args: + preset: Named preset of fields to capture. Currently only "session" is supported, + which captures messages, state, conversation_manager_state, and interrupt_state. 
+ include: Additional fields to capture on top of the preset. + exclude: Fields to remove after applying preset and include. + app_data: Application-owned arbitrary JSON stored verbatim in the snapshot. + + Returns: + A Snapshot containing the captured agent state. + + Raises: + SnapshotException: If no fields are resolved or an invalid field name is provided. + """ + fields = resolve_snapshot_fields(preset=preset, include=include, exclude=exclude) + + data: dict[str, Any] = {} + if "messages" in fields: + data["messages"] = copy.deepcopy(self.messages) + if "state" in fields: + data["state"] = self.state.get() + if "conversation_manager_state" in fields: + data["conversation_manager_state"] = self.conversation_manager.get_state() + if "interrupt_state" in fields: + data["interrupt_state"] = self._interrupt_state.to_dict() + if "system_prompt" in fields: + # Store the content-block representation so round-trips preserve caching hints and + # other block-level metadata. + data["system_prompt"] = copy.deepcopy(self._system_prompt_content) + + return Snapshot( + scope="agent", + schema_version=SNAPSHOT_SCHEMA_VERSION, + data=data, + app_data=copy.deepcopy(app_data) if app_data else {}, + ) + + def load_snapshot(self, snapshot: Snapshot) -> None: + """Restore agent state from a previously captured snapshot. + + Only fields present in snapshot.data are restored; absent fields are left unchanged. + + Args: + snapshot: The snapshot to restore from. + + Raises: + SnapshotException: If snapshot.schema_version is not "1.0". 
+ """ + snapshot.validate() + + data = snapshot.data + + if "messages" in data: + self.messages = copy.deepcopy(data["messages"]) + if "state" in data: + self.state = AgentState(data["state"]) + if "conversation_manager_state" in data: + self.conversation_manager.restore_from_session(data["conversation_manager_state"]) + if "interrupt_state" in data: + self._interrupt_state = _InterruptState.from_dict(data["interrupt_state"]) + if "system_prompt" in data: + self.system_prompt = copy.deepcopy(data["system_prompt"]) + def _redact_user_content(self, content: list[ContentBlock], redact_message: str) -> list[ContentBlock]: """Redact user content preserving toolResult blocks. diff --git a/src/strands/types/__init__.py b/src/strands/types/__init__.py index 7eef60cb4..60d6b3a17 100644 --- a/src/strands/types/__init__.py +++ b/src/strands/types/__init__.py @@ -1,5 +1,6 @@ """SDK type definitions.""" +from ._snapshot import Snapshot from .collections import PaginatedList -__all__ = ["PaginatedList"] +__all__ = ["PaginatedList", "Snapshot"] diff --git a/src/strands/types/_snapshot.py b/src/strands/types/_snapshot.py new file mode 100644 index 000000000..407b811f2 --- /dev/null +++ b/src/strands/types/_snapshot.py @@ -0,0 +1,145 @@ +"""Snapshot types, constants, and helpers for agent state capture.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Literal, TypedDict + +from .exceptions import SnapshotException + +SnapshotField = Literal["messages", "state", "conversation_manager_state", "interrupt_state", "system_prompt"] +SnapshotPreset = Literal["session"] +Scope = Literal["agent"] + +ALL_SNAPSHOT_FIELDS: tuple[SnapshotField, ...] = ( + "messages", + "state", + "conversation_manager_state", + "interrupt_state", + "system_prompt", +) + +VALID_SCOPES: tuple[Scope, ...] 
= ("agent",) + +SNAPSHOT_SCHEMA_VERSION = "1.0" + +SNAPSHOT_PRESETS: dict[str, tuple[SnapshotField, ...]] = { + "session": ("messages", "state", "conversation_manager_state", "interrupt_state"), +} + + +class TakeSnapshotOptions(TypedDict, total=False): + """Internal options for take_snapshot. Not exported publicly.""" + + preset: SnapshotPreset + include: list[SnapshotField] + exclude: list[SnapshotField] + app_data: dict[str, Any] + + +@dataclass +class Snapshot: + """Point-in-time capture of agent state as a versioned JSON-compatible object.""" + + scope: Scope + schema_version: str + data: dict[str, Any] + app_data: dict[str, Any] + created_at: str = field(default="") # ISO 8601 UTC; auto-filled if empty + + def __post_init__(self) -> None: + if not self.created_at: + self.created_at = _utc_now_iso() + + def validate(self) -> None: + """Validate that this snapshot can be loaded by the current SDK version. + + Raises: + SnapshotException: If schema_version is not "1.0" or scope is invalid. + """ + if self.schema_version != SNAPSHOT_SCHEMA_VERSION: + raise SnapshotException( + f"Unsupported snapshot schema version: {self.schema_version!r}. " + f"Current version: {SNAPSHOT_SCHEMA_VERSION}" + ) + if self.scope not in VALID_SCOPES: + raise SnapshotException(f"Invalid snapshot scope: {self.scope!r}. Valid scopes: {sorted(VALID_SCOPES)}") + + def to_dict(self) -> dict[str, Any]: + """Serialize to a plain JSON-compatible dict.""" + return { + "scope": self.scope, + "schema_version": self.schema_version, + "created_at": self.created_at, + "data": self.data, + "app_data": self.app_data, + } + + @classmethod + def from_dict(cls, d: dict[str, Any]) -> Snapshot: + """Reconstruct a Snapshot from a dict produced by to_dict(). + + Raises: + SnapshotException: If schema_version is not "1.0". 
+ """ + snapshot = cls( + scope=d.get("scope", "agent"), + schema_version=d.get("schema_version", ""), + created_at=d["created_at"], + data=d["data"], + app_data=d.get("app_data", {}), + ) + snapshot.validate() + return snapshot + + +def resolve_snapshot_fields( + *, + preset: SnapshotPreset | None = None, + include: list[SnapshotField] | None = None, + exclude: list[SnapshotField] | None = None, +) -> set[SnapshotField]: + """Resolve the set of fields to capture based on options. + + Applies: preset → include → exclude (in that order). + + Raises: + SnapshotException: If any field name is invalid or the resolved set is empty. + """ + valid = set(ALL_SNAPSHOT_FIELDS) + + # Validate include/exclude field names + for f in include or []: + if f not in valid: + raise SnapshotException(f"Invalid snapshot field: {f!r}. Valid fields: {sorted(valid)}") + for f in exclude or []: + if f not in valid: + raise SnapshotException(f"Invalid snapshot field: {f!r}. Valid fields: {sorted(valid)}") + + # Step 1: start with preset + if preset is not None: + fields: set[SnapshotField] = set(SNAPSHOT_PRESETS[preset]) + else: + fields = set() + + # Step 2: union with include + if include: + fields |= set(include) + + # Step 3: subtract exclude + if exclude: + fields -= set(exclude) + + if not fields: + raise SnapshotException( + "No snapshot fields resolved. Provide a preset or at least one field in 'include'. " + "Note: passing only 'exclude' without a preset or 'include' always results in an empty set." 
+ ) + + return fields + + +def _utc_now_iso() -> str: + """Return the current UTC time as an ISO 8601 string ending in 'Z'.""" + return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 1d1983abd..5db80a26e 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -77,6 +77,12 @@ class SessionException(Exception): pass +class SnapshotException(Exception): + """Exception raised when snapshot operations fail (e.g., unsupported schema version).""" + + pass + + class ToolProviderException(Exception): """Exception raised when a tool provider fails to load or cleanup tools.""" diff --git a/tests/strands/agent/test_snapshot.py b/tests/strands/agent/test_snapshot.py new file mode 100644 index 000000000..50e83a484 --- /dev/null +++ b/tests/strands/agent/test_snapshot.py @@ -0,0 +1,453 @@ +"""Tests for _snapshot.py — Snapshot dataclass and resolve_snapshot_fields.""" + +import json +import re +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from strands import Agent +from strands.types._snapshot import ( + ALL_SNAPSHOT_FIELDS, + SNAPSHOT_PRESETS, + SNAPSHOT_SCHEMA_VERSION, + VALID_SCOPES, + Snapshot, + resolve_snapshot_fields, +) +from strands.types.exceptions import SnapshotException + +# Helpers + +ISO_8601_UTC_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(\.\d+)?Z$") + + +def _make_snapshot(**kwargs: object) -> Snapshot: + defaults: dict[str, Any] = { + "scope": "agent", + "schema_version": SNAPSHOT_SCHEMA_VERSION, + "created_at": "2025-01-15T12:00:00.000000Z", + "data": {}, + "app_data": {}, + } + defaults.update(kwargs) + return Snapshot(**defaults) + + +def _make_agent(**kwargs) -> Agent: + """Create a minimal Agent with a mock model for testing.""" + mock_model = MagicMock() + mock_model.get_config.return_value = {} + return Agent(model=mock_model, callback_handler=None, **kwargs) + + +def 
test_snapshot_from_dict_bad_version_raises(): + d = {"schema_version": "99.0", "created_at": "2025-01-15T12:00:00Z", "data": {}, "app_data": {}} + with pytest.raises(SnapshotException, match="Unsupported snapshot schema version"): + Snapshot.from_dict(d) + + +def test_snapshot_to_dict_round_trip(): + s = _make_snapshot(data={"messages": []}, app_data={"x": 1}) + assert Snapshot.from_dict(s.to_dict()) == s + + +def test_resolve_snapshot_fields_invalid_include_raises(): + with pytest.raises(SnapshotException, match="Invalid snapshot field"): + resolve_snapshot_fields(include=["not_a_field"]) # type: ignore[list-item] + + +def test_resolve_snapshot_fields_invalid_exclude_raises(): + with pytest.raises(SnapshotException, match="Invalid snapshot field"): + resolve_snapshot_fields(preset="session", exclude=["not_a_field"]) # type: ignore[list-item] + + +def test_resolve_snapshot_fields_no_preset_no_include_raises(): + with pytest.raises(SnapshotException, match="No snapshot fields resolved"): + resolve_snapshot_fields() + + +def test_resolve_snapshot_fields_session_preset(): + assert resolve_snapshot_fields(preset="session") == set(SNAPSHOT_PRESETS["session"]) + + +def test_resolve_snapshot_fields_include_adds_to_preset(): + fields = resolve_snapshot_fields(preset="session", include=["system_prompt"]) + assert fields == set(SNAPSHOT_PRESETS["session"]) | {"system_prompt"} + + +def test_resolve_snapshot_fields_exclude_removes_from_preset(): + fields = resolve_snapshot_fields(preset="session", exclude=["messages"]) + assert "messages" not in fields + + +def test_resolve_snapshot_fields_all_excluded_raises(): + with pytest.raises(SnapshotException): + resolve_snapshot_fields(exclude=list(ALL_SNAPSHOT_FIELDS)) # type: ignore[list-item] + + +_ORDERING_CASES = [ + # (preset, include, exclude) + ("session", [], []), + ("session", ["system_prompt"], []), + ("session", [], ["messages"]), + ("session", ["system_prompt"], ["messages", "state"]), + (None, ["messages", "state"], []), 
+ (None, list(ALL_SNAPSHOT_FIELDS), []), + (None, list(ALL_SNAPSHOT_FIELDS), ["system_prompt"]), + ("session", ["system_prompt"], list(SNAPSHOT_PRESETS["session"])), # exclude all preset → only system_prompt +] + + +@pytest.mark.parametrize("preset,include,exclude", _ORDERING_CASES) +def test_resolve_snapshot_fields_ordering(preset, include, exclude): + expected = (set(SNAPSHOT_PRESETS[preset] if preset else []) | set(include)) - set(exclude) + + if not expected: + with pytest.raises(SnapshotException): + resolve_snapshot_fields(preset=preset, include=include or None, exclude=exclude or None) + else: + assert resolve_snapshot_fields(preset=preset, include=include or None, exclude=exclude or None) == expected + + +_STRUCTURAL_CASES = [ + ([], {}, None), + ([{"role": "user", "content": [{"text": "hi"}]}], {"k": "v"}, "system prompt"), + ([{"role": "user", "content": [{"text": "a"}]}, {"role": "user", "content": [{"text": "b"}]}], {}, None), + ([], {"num": 42, "flag": True}, "another prompt"), +] + + +@pytest.mark.parametrize("messages,state_dict,system_prompt", _STRUCTURAL_CASES) +def test_snapshot_structural_invariants(messages, state_dict, system_prompt): + agent = _make_agent(messages=messages, state=state_dict, system_prompt=system_prompt) + snapshot = agent.take_snapshot(preset="session") + + assert snapshot.schema_version == "1.0" + assert ISO_8601_UTC_RE.match(snapshot.created_at), f"created_at={snapshot.created_at!r} not ISO 8601 UTC" + assert isinstance(snapshot.data, dict) + assert isinstance(snapshot.app_data, dict) + for field in ("messages", "state", "conversation_manager_state", "interrupt_state"): + assert field in snapshot.data + assert "system_prompt" not in snapshot.data + + +_APP_DATA_CASES = [ + {"key": "value"}, + {"num": 42, "flag": True, "nothing": None}, + {"nested_str": "hello", "count": 0}, +] + + +@pytest.mark.parametrize("app_data", _APP_DATA_CASES) +def test_app_data_stored_verbatim(app_data): + agent = _make_agent() + snapshot = 
agent.take_snapshot(preset="session", app_data=app_data) + assert snapshot.app_data == app_data + + +_ROUND_TRIP_AGENT_CASES = [ + ([], {}), + ([{"role": "user", "content": [{"text": "hi"}]}], {"k": "v"}), + ( + [{"role": "user", "content": [{"text": "a"}]}, {"role": "user", "content": [{"text": "b"}]}], + {"num": 1, "flag": None}, + ), +] + + +@pytest.mark.parametrize("messages,state_dict", _ROUND_TRIP_AGENT_CASES) +def test_agent_state_round_trip(messages, state_dict): + agent = _make_agent(messages=messages, state=state_dict, system_prompt="original prompt") + snapshot = agent.take_snapshot(preset="session") + + fresh_agent = _make_agent(system_prompt="original prompt") + fresh_agent.load_snapshot(snapshot) + + assert fresh_agent.messages == messages + assert fresh_agent.state.get() == state_dict + assert fresh_agent.system_prompt == "original prompt" + assert fresh_agent.conversation_manager.get_state() == agent.conversation_manager.get_state() + assert fresh_agent._interrupt_state.to_dict() == agent._interrupt_state.to_dict() + + +@pytest.mark.parametrize("omitted_field", list(ALL_SNAPSHOT_FIELDS)) +def test_missing_fields_leave_agent_unchanged(omitted_field): + agent = _make_agent( + messages=[{"role": "user", "content": [{"text": "original"}]}], + state={"key": "original"}, + system_prompt="original prompt", + ) + + include_fields = [f for f in ALL_SNAPSHOT_FIELDS if f != omitted_field] + snapshot = agent.take_snapshot(include=include_fields) + # system_prompt field is stored under the key "system_prompt" in snapshot.data + data_key = "system_prompt" if omitted_field == "system_prompt" else omitted_field + assert data_key not in snapshot.data + + fresh_agent = _make_agent( + messages=list(agent.messages), + state=agent.state.get(), + system_prompt="original prompt", + ) + + if omitted_field == "messages": + before = list(fresh_agent.messages) + elif omitted_field == "state": + before = fresh_agent.state.get() + elif omitted_field == "system_prompt": + before 
= fresh_agent.system_prompt + elif omitted_field == "conversation_manager_state": + before = fresh_agent.conversation_manager.get_state() + elif omitted_field == "interrupt_state": + before = fresh_agent._interrupt_state.to_dict() + else: + pytest.fail(f"Unhandled field in test: {omitted_field!r}. Update this test when adding new snapshot fields.") + + fresh_agent.load_snapshot(snapshot) + + if omitted_field == "messages": + assert fresh_agent.messages == before + elif omitted_field == "state": + assert fresh_agent.state.get() == before + elif omitted_field == "system_prompt": + assert fresh_agent.system_prompt == before + elif omitted_field == "conversation_manager_state": + assert fresh_agent.conversation_manager.get_state() == before + elif omitted_field == "interrupt_state": + assert fresh_agent._interrupt_state.to_dict() == before + else: + pytest.fail(f"Unhandled field in test: {omitted_field!r}. Update this test when adding new snapshot fields.") + + +def test_snapshot_no_system_prompt_clears_target_agent_prompt(): + """Snapshot from agent with no system_prompt (field included) clears prompt on restore.""" + source_agent = _make_agent() # no system_prompt + snapshot = source_agent.take_snapshot(include=["system_prompt"]) + + assert "system_prompt" in snapshot.data + assert snapshot.data["system_prompt"] is None + + target_agent = _make_agent(system_prompt="existing prompt") + target_agent.load_snapshot(snapshot) + + assert target_agent.system_prompt is None + + +def test_snapshot_without_system_prompt_field_preserves_target_agent_prompt(): + """Snapshot taken without system_prompt field does not override target agent's prompt.""" + source_agent = _make_agent(system_prompt="source prompt") + snapshot = source_agent.take_snapshot(include=["messages"]) # system_prompt field excluded + + assert "system_prompt" not in snapshot.data + + target_agent = _make_agent(system_prompt="target prompt") + target_agent.load_snapshot(snapshot) + + assert 
target_agent.system_prompt == "target prompt" + + +def test_load_snapshot_messages_are_independent_copy(): + """Messages restored from a snapshot are a copy — mutating snapshot.data after load doesn't affect the agent.""" + agent = _make_agent(messages=[{"role": "user", "content": [{"text": "hello"}]}]) + snapshot = agent.take_snapshot(preset="session") + + fresh_agent = _make_agent() + fresh_agent.load_snapshot(snapshot) + + snapshot.data["messages"].append({"role": "user", "content": [{"text": "injected"}]}) + assert len(fresh_agent.messages) == 1 + + +def test_take_snapshot_messages_are_independent_copy(): + """Mutating agent messages after take_snapshot doesn't corrupt the snapshot.""" + msg = {"role": "user", "content": [{"text": "original"}]} + agent = _make_agent(messages=[msg]) + snapshot = agent.take_snapshot(preset="session") + + agent.messages[0]["content"][0]["text"] = "mutated" + assert snapshot.data["messages"][0]["content"][0]["text"] == "original" + + +def test_take_snapshot_app_data_is_independent_copy(): + """Mutating app_data after take_snapshot doesn't corrupt the snapshot.""" + app_data = {"key": "original"} + agent = _make_agent() + snapshot = agent.take_snapshot(preset="session", app_data=app_data) + + app_data["key"] = "mutated" + assert snapshot.app_data["key"] == "original" + + +# Scope validation + + +def test_valid_scopes_constant_matches_scope_type(): + """VALID_SCOPES contains exactly the values from the Scope Literal type.""" + assert set(VALID_SCOPES) == {"agent"} + + +def test_snapshot_validate_accepts_valid_scopes(): + """validate() should not raise for each valid scope value.""" + for scope in VALID_SCOPES: + snap = _make_snapshot(scope=scope) + snap.validate() # should not raise + + +def test_snapshot_validate_rejects_invalid_scope(): + """validate() should raise SnapshotException for an unrecognised scope.""" + snap = _make_snapshot(scope="invalid_scope") + with pytest.raises(SnapshotException, match="Invalid snapshot scope"): + 
snap.validate() + + +def test_snapshot_from_dict_rejects_invalid_scope(): + """from_dict() calls validate(), so an invalid scope should raise.""" + d = { + "scope": "bad_scope", + "schema_version": SNAPSHOT_SCHEMA_VERSION, + "created_at": "2025-01-15T12:00:00Z", + "data": {}, + "app_data": {}, + } + with pytest.raises(SnapshotException, match="Invalid snapshot scope"): + Snapshot.from_dict(d) + + +def test_snapshot_from_dict_defaults_scope_to_agent(): + """from_dict() defaults scope to 'agent' when the key is missing.""" + d = { + "schema_version": SNAPSHOT_SCHEMA_VERSION, + "created_at": "2025-01-15T12:00:00Z", + "data": {}, + "app_data": {}, + } + snap = Snapshot.from_dict(d) + assert snap.scope == "agent" + + +def test_load_snapshot_rejects_invalid_scope(): + """Agent.load_snapshot() should reject a snapshot with an invalid scope.""" + agent = _make_agent() + snap = _make_snapshot(scope="unknown") + with pytest.raises(SnapshotException, match="Invalid snapshot scope"): + agent.load_snapshot(snap) + + +def test_take_snapshot_always_produces_agent_scope(): + """take_snapshot() should always set scope to 'agent'.""" + agent = _make_agent() + snapshot = agent.take_snapshot(preset="session") + assert snapshot.scope == "agent" + + +# Individual field restore from a raw snapshot + + +def test_load_snapshot_restores_messages_only(): + """A snapshot containing only messages restores them on a fresh agent.""" + agent = _make_agent(messages=[{"role": "user", "content": [{"text": "existing"}]}]) + snap = _make_snapshot(data={"messages": [{"role": "user", "content": [{"text": "restored"}]}]}) + + agent.load_snapshot(snap) + + assert len(agent.messages) == 1 + assert agent.messages[0]["content"][0]["text"] == "restored" + + +def test_load_snapshot_restores_state_only(): + """A snapshot containing only state restores it on a fresh agent.""" + agent = _make_agent(state={"old": "value"}) + snap = _make_snapshot(data={"state": {"new_key": "new_value"}}) + + 
agent.load_snapshot(snap) + + assert agent.state.get() == {"new_key": "new_value"} + + +def test_load_snapshot_restores_system_prompt_only(): + """A snapshot containing only system_prompt restores it on a fresh agent.""" + agent = _make_agent(system_prompt="old prompt") + snap = _make_snapshot(data={"system_prompt": "restored prompt"}) + + agent.load_snapshot(snap) + + assert agent.system_prompt == "restored prompt" + + +def test_snapshot_json_string_round_trip(): + """Snapshot survives json.dumps / json.loads round-trip.""" + agent = _make_agent( + messages=[{"role": "user", "content": [{"text": "hello"}]}], + state={"k": "v"}, + system_prompt="test prompt", + ) + snapshot = agent.take_snapshot(preset="session", include=["system_prompt"]) + + json_str = json.dumps(snapshot.to_dict()) + restored = Snapshot.from_dict(json.loads(json_str)) + + assert restored == snapshot + + +def test_snapshot_json_store_and_restore_into_new_agent(): + """Simulate persisting a snapshot as JSON and restoring into a new agent.""" + agent = _make_agent( + messages=[{"role": "user", "content": [{"text": "test message"}]}], + state={"key": "value"}, + ) + snapshot = agent.take_snapshot(preset="session") + + stored = json.dumps(snapshot.to_dict()) + retrieved = Snapshot.from_dict(json.loads(stored)) + + new_agent = _make_agent() + new_agent.load_snapshot(retrieved) + + assert new_agent.messages == [{"role": "user", "content": [{"text": "test message"}]}] + assert new_agent.state.get() == {"key": "value"} + + +def test_snapshot_round_trip_with_tool_use_messages(): + """Snapshot preserves toolUse and toolResult content blocks through a round-trip.""" + tool_use_msg = { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "tool-123", "name": "calculator", "input": {"op": "add"}}}], + } + tool_result_msg = { + "role": "user", + "content": [{"toolResult": {"toolUseId": "tool-123", "status": "success", "content": [{"text": "6"}]}}], + } + agent = _make_agent(messages=[tool_use_msg, 
tool_result_msg]) + snapshot = agent.take_snapshot(include=["messages"]) + + fresh_agent = _make_agent() + fresh_agent.load_snapshot(snapshot) + + assert fresh_agent.messages == [tool_use_msg, tool_result_msg] + + +def test_take_snapshot_exclude_removes_field_from_data(): + """Excluding a field from take_snapshot omits it from snapshot.data while keeping others.""" + agent = _make_agent( + messages=[{"role": "user", "content": [{"text": "hi"}]}], + state={"k": "v"}, + ) + snapshot = agent.take_snapshot(preset="session", exclude=["messages"]) + + assert "messages" not in snapshot.data + assert "state" in snapshot.data + assert "conversation_manager_state" in snapshot.data + assert "interrupt_state" in snapshot.data + + +def test_take_snapshot_system_prompt_is_independent_copy(): + """Mutating agent system_prompt after take_snapshot doesn't corrupt the snapshot.""" + agent = _make_agent(system_prompt="original prompt") + snapshot = agent.take_snapshot(include=["system_prompt"]) + + original_content = snapshot.data["system_prompt"] + agent.system_prompt = "mutated prompt" + assert snapshot.data["system_prompt"] == original_content + assert snapshot.data["system_prompt"] != agent._system_prompt_content From 09902bd97482eed0aa352d539d15f058b1c124ae Mon Sep 17 00:00:00 2001 From: Davide Gallitelli Date: Wed, 15 Apr 2026 22:37:30 +0700 Subject: [PATCH 227/279] feat(skills): support loading skills from URLs (#2091) Co-authored-by: Davide Gallitelli Co-authored-by: Claude Opus 4.6 (1M context) --- .../vended_plugins/skills/agent_skills.py | 17 +++- src/strands/vended_plugins/skills/skill.py | 53 ++++++++++- .../skills/test_agent_skills.py | 89 ++++++++++++++++++ .../vended_plugins/skills/test_skill.py | 90 ++++++++++++++++++- 4 files changed, 242 insertions(+), 7 deletions(-) diff --git a/src/strands/vended_plugins/skills/agent_skills.py b/src/strands/vended_plugins/skills/agent_skills.py index 5e42b9230..23217e81c 100644 --- 
a/src/strands/vended_plugins/skills/agent_skills.py +++ b/src/strands/vended_plugins/skills/agent_skills.py @@ -86,6 +86,7 @@ def __init__( - A ``str`` or ``Path`` to a skill directory (containing SKILL.md) - A ``str`` or ``Path`` to a parent directory (containing skill subdirectories) - A ``Skill`` dataclass instance + - An ``https://`` URL pointing directly to raw SKILL.md content state_key: Key used to store plugin state in ``agent.state``. max_resource_files: Maximum number of resource files to list in skill responses. strict: If True, raise on skill validation issues. If False (default), warn and load anyway. @@ -176,8 +177,9 @@ def set_available_skills(self, skills: SkillSources) -> None: """Set the available skills, replacing any existing ones. Each element can be a ``Skill`` instance, a ``str`` or ``Path`` to a - skill directory (containing SKILL.md), or a ``str`` or ``Path`` to a - parent directory containing skill subdirectories. + skill directory (containing SKILL.md), a ``str`` or ``Path`` to a + parent directory containing skill subdirectories, or an ``https://`` + URL pointing directly to raw SKILL.md content. Note: this does not persist state or deactivate skills on any agent. Active skill state is managed per-agent and will be reconciled on the @@ -284,7 +286,8 @@ def _resolve_skills(self, sources: list[SkillSource]) -> dict[str, Skill]: """Resolve a list of skill sources into Skill instances. Each source can be a Skill instance, a path to a skill directory, - or a path to a parent directory containing multiple skills. + a path to a parent directory containing multiple skills, or an + HTTPS URL pointing to a SKILL.md file. Args: sources: List of skill sources to resolve. 
@@ -299,6 +302,14 @@ def _resolve_skills(self, sources: list[SkillSource]) -> dict[str, Skill]: if source.name in resolved: logger.warning("name=<%s> | duplicate skill name, overwriting previous skill", source.name) resolved[source.name] = source + elif isinstance(source, str) and source.startswith("https://"): + try: + skill = Skill.from_url(source, strict=self._strict) + if skill.name in resolved: + logger.warning("name=<%s> | duplicate skill name, overwriting previous skill", skill.name) + resolved[skill.name] = skill + except (RuntimeError, ValueError) as e: + logger.warning("url=<%s> | failed to load skill from URL: %s", source, e) else: path = Path(source).resolve() if not path.exists(): diff --git a/src/strands/vended_plugins/skills/skill.py b/src/strands/vended_plugins/skills/skill.py index 3e1b6bba5..a60c1cd6c 100644 --- a/src/strands/vended_plugins/skills/skill.py +++ b/src/strands/vended_plugins/skills/skill.py @@ -1,15 +1,17 @@ """Skill data model and loading utilities for AgentSkills.io skills. This module defines the Skill dataclass and provides classmethods for -discovering, parsing, and loading skills from the filesystem or raw content. -Skills are directories containing a SKILL.md file with YAML frontmatter -metadata and markdown instructions. +discovering, parsing, and loading skills from the filesystem, raw content, +or HTTPS URLs. Skills are directories containing a SKILL.md file with YAML +frontmatter metadata and markdown instructions. """ from __future__ import annotations import logging import re +import urllib.error +import urllib.request from dataclasses import dataclass, field from pathlib import Path from typing import Any @@ -222,6 +224,9 @@ class Skill: # Load all skills from a parent directory skills = Skill.from_directory("./skills/") + # From an HTTPS URL + skill = Skill.from_url("https://example.com/SKILL.md") + Attributes: name: Unique identifier for the skill (1-64 chars, lowercase alphanumeric + hyphens). 
description: Human-readable description of what the skill does. @@ -333,6 +338,48 @@ def from_content(cls, content: str, *, strict: bool = False) -> Skill: return _build_skill_from_frontmatter(frontmatter, body) + @classmethod + def from_url(cls, url: str, *, strict: bool = False) -> Skill: + """Load a skill by fetching its SKILL.md content from an HTTPS URL. + + Fetches the raw SKILL.md content over HTTPS and parses it using + :meth:`from_content`. The URL must point directly to the raw + file content (not an HTML page). + + Example:: + + skill = Skill.from_url( + "https://raw.githubusercontent.com/org/repo/main/SKILL.md" + ) + + Args: + url: An ``https://`` URL pointing directly to raw SKILL.md content. + strict: If True, raise on any validation issue. If False (default), + warn and load anyway. + + Returns: + A Skill instance populated from the fetched SKILL.md content. + + Raises: + ValueError: If ``url`` is not an ``https://`` URL. + RuntimeError: If the SKILL.md content cannot be fetched. + """ + if not url.startswith("https://"): + raise ValueError(f"url=<{url}> | not a valid HTTPS URL") + + logger.info("url=<%s> | fetching skill content", url) + + try: + req = urllib.request.Request(url, headers={"User-Agent": "strands-agents-sdk"}) # noqa: S310 + with urllib.request.urlopen(req, timeout=30) as response: # noqa: S310 + content: str = response.read().decode("utf-8") + except urllib.error.HTTPError as e: + raise RuntimeError(f"url=<{url}> | HTTP {e.code}: {e.reason}") from e + except urllib.error.URLError as e: + raise RuntimeError(f"url=<{url}> | failed to fetch skill: {e.reason}") from e + + return cls.from_content(content, strict=strict) + @classmethod def from_directory(cls, skills_dir: str | Path, *, strict: bool = False) -> list[Skill]: """Load all skills from a parent directory containing skill subdirectories. 
diff --git a/tests/strands/vended_plugins/skills/test_agent_skills.py b/tests/strands/vended_plugins/skills/test_agent_skills.py index 52802a6c1..db82355a9 100644 --- a/tests/strands/vended_plugins/skills/test_agent_skills.py +++ b/tests/strands/vended_plugins/skills/test_agent_skills.py @@ -661,6 +661,95 @@ def test_resolve_nonexistent_path(self, tmp_path): assert len(plugin._skills) == 0 +class TestResolveUrlSkills: + """Tests for _resolve_skills with URL sources.""" + + _SKILL_MODULE = "strands.vended_plugins.skills.skill" + _SAMPLE_CONTENT = "---\nname: url-skill\ndescription: A URL skill\n---\n# Instructions\n" + + def _mock_urlopen(self, content): + """Create a mock urlopen context manager returning the given content.""" + mock_response = MagicMock() + mock_response.read.return_value = content.encode("utf-8") + mock_response.__enter__ = MagicMock(return_value=mock_response) + mock_response.__exit__ = MagicMock(return_value=False) + return mock_response + + def test_resolve_url_source(self): + """Test resolving a URL string as a skill source.""" + from unittest.mock import patch + + with patch( + f"{self._SKILL_MODULE}.urllib.request.urlopen", return_value=self._mock_urlopen(self._SAMPLE_CONTENT) + ): + plugin = AgentSkills(skills=["https://example.com/SKILL.md"]) + + assert len(plugin.get_available_skills()) == 1 + assert plugin.get_available_skills()[0].name == "url-skill" + + def test_resolve_mixed_url_and_local(self, tmp_path): + """Test resolving a mix of URL and local filesystem sources.""" + from unittest.mock import patch + + _make_skill_dir(tmp_path, "local-skill") + + with patch( + f"{self._SKILL_MODULE}.urllib.request.urlopen", return_value=self._mock_urlopen(self._SAMPLE_CONTENT) + ): + plugin = AgentSkills( + skills=[ + "https://example.com/SKILL.md", + str(tmp_path / "local-skill"), + ] + ) + + assert len(plugin.get_available_skills()) == 2 + names = {s.name for s in plugin.get_available_skills()} + assert names == {"url-skill", "local-skill"} + 
+ def test_resolve_url_failure_skips_gracefully(self, caplog): + """Test that a failed URL fetch is skipped with a warning.""" + import logging + import urllib.error + from unittest.mock import patch + + with ( + patch( + f"{self._SKILL_MODULE}.urllib.request.urlopen", + side_effect=urllib.error.HTTPError( + url="https://example.com", code=404, msg="Not Found", hdrs=None, fp=None + ), + ), + caplog.at_level(logging.WARNING), + ): + plugin = AgentSkills(skills=["https://example.com/broken/SKILL.md"]) + + assert len(plugin.get_available_skills()) == 0 + assert "failed to load skill from URL" in caplog.text + + def test_resolve_duplicate_url_skills_warns(self, caplog): + """Test that duplicate skill names from URLs log a warning.""" + import logging + from unittest.mock import patch + + with ( + patch( + f"{self._SKILL_MODULE}.urllib.request.urlopen", + return_value=self._mock_urlopen(self._SAMPLE_CONTENT), + ), + caplog.at_level(logging.WARNING), + ): + plugin = AgentSkills( + skills=[ + "https://example.com/a/SKILL.md", + "https://example.com/b/SKILL.md", + ] + ) + + assert len(plugin.get_available_skills()) == 1 + assert "duplicate skill name" in caplog.text + + class TestImports: """Tests for module imports.""" diff --git a/tests/strands/vended_plugins/skills/test_skill.py b/tests/strands/vended_plugins/skills/test_skill.py index 53d6f3507..cb67d71a2 100644 --- a/tests/strands/vended_plugins/skills/test_skill.py +++ b/tests/strands/vended_plugins/skills/test_skill.py @@ -551,11 +551,99 @@ def test_strict_mode(self): Skill.from_content(content, strict=True) +class TestSkillFromUrl: + """Tests for Skill.from_url.""" + + _SKILL_MODULE = "strands.vended_plugins.skills.skill" + _SAMPLE_CONTENT = "---\nname: my-skill\ndescription: A remote skill\n---\nRemote instructions.\n" + + def _mock_urlopen(self, content): + """Create a mock urlopen context manager returning the given content.""" + from unittest.mock import MagicMock + + mock_response = MagicMock() + 
mock_response.read.return_value = content.encode("utf-8") + mock_response.__enter__ = MagicMock(return_value=mock_response) + mock_response.__exit__ = MagicMock(return_value=False) + return mock_response + + def test_from_url_returns_skill(self): + """Test loading a skill from a URL returns a single Skill.""" + from unittest.mock import patch + + mock_response = self._mock_urlopen(self._SAMPLE_CONTENT) + with patch(f"{self._SKILL_MODULE}.urllib.request.urlopen", return_value=mock_response): + skill = Skill.from_url("https://raw.githubusercontent.com/org/repo/main/SKILL.md") + + assert isinstance(skill, Skill) + assert skill.name == "my-skill" + assert skill.description == "A remote skill" + assert "Remote instructions." in skill.instructions + assert skill.path is None + + def test_from_url_invalid_url_raises(self): + """Test that a non-HTTPS URL raises ValueError.""" + with pytest.raises(ValueError, match="not a valid HTTPS URL"): + Skill.from_url("./local-path") + + def test_from_url_http_rejected(self): + """Test that http:// URLs are rejected.""" + with pytest.raises(ValueError, match="not a valid HTTPS URL"): + Skill.from_url("http://example.com/SKILL.md") + + def test_from_url_http_error_raises(self): + """Test that HTTP errors propagate as RuntimeError.""" + import urllib.error + from unittest.mock import patch + + with patch( + f"{self._SKILL_MODULE}.urllib.request.urlopen", + side_effect=urllib.error.HTTPError( + url="https://example.com", code=404, msg="Not Found", hdrs=None, fp=None + ), + ): + with pytest.raises(RuntimeError, match="HTTP 404"): + Skill.from_url("https://example.com/SKILL.md") + + def test_from_url_network_error_raises(self): + """Test that network errors propagate as RuntimeError.""" + import urllib.error + from unittest.mock import patch + + with patch( + f"{self._SKILL_MODULE}.urllib.request.urlopen", + side_effect=urllib.error.URLError("Connection refused"), + ): + with pytest.raises(RuntimeError, match="failed to fetch"): + 
Skill.from_url("https://example.com/SKILL.md") + + def test_from_url_strict_mode(self): + """Test that strict mode is forwarded to from_content.""" + from unittest.mock import patch + + bad_content = "---\nname: BAD_NAME\ndescription: Bad\n---\nBody." + + with patch(f"{self._SKILL_MODULE}.urllib.request.urlopen", return_value=self._mock_urlopen(bad_content)): + with pytest.raises(ValueError): + Skill.from_url("https://example.com/SKILL.md", strict=True) + + def test_from_url_invalid_content_raises(self): + """Test that non-SKILL.md content (e.g. HTML page) raises ValueError.""" + from unittest.mock import patch + + html_content = "Not a SKILL.md" + + with patch(f"{self._SKILL_MODULE}.urllib.request.urlopen", return_value=self._mock_urlopen(html_content)): + with pytest.raises(ValueError, match="frontmatter"): + Skill.from_url("https://example.com/SKILL.md") + + class TestSkillClassmethods: """Tests for Skill classmethod existence.""" def test_skill_classmethods_exist(self): - """Test that Skill has from_file, from_content, and from_directory classmethods.""" + """Test that Skill has from_file, from_content, from_directory, and from_url classmethods.""" assert callable(getattr(Skill, "from_file", None)) assert callable(getattr(Skill, "from_content", None)) assert callable(getattr(Skill, "from_directory", None)) + assert callable(getattr(Skill, "from_url", None)) From dd7a7d978fc0857ece3edd7b96a75dcdc0bd64c6 Mon Sep 17 00:00:00 2001 From: Liz <91279165+lizradway@users.noreply.github.com> Date: Wed, 15 Apr 2026 11:52:06 -0400 Subject: [PATCH 228/279] feat: add metadata field to messages for stateful context tracking (#2125) --- src/strands/agent/agent.py | 2 +- .../_recover_message_on_max_tokens_reached.py | 5 +- src/strands/event_loop/event_loop.py | 7 + src/strands/event_loop/streaming.py | 3 + src/strands/telemetry/tracer.py | 4 +- src/strands/types/content.py | 30 +++- .../strands/agent/hooks/test_agent_events.py | 13 +- tests/strands/agent/test_agent.py | 16 +- 
.../strands/agent/test_agent_cancellation.py | 3 +- tests/strands/agent/test_agent_hooks.py | 6 +- tests/strands/event_loop/test_event_loop.py | 17 ++- .../event_loop/test_event_loop_metadata.py | 141 ++++++++++++++++++ ...t_recover_message_on_max_tokens_reached.py | 28 ++++ .../tools/mcp/test_mcp_client_tasks.py | 8 +- tests/strands/types/test_message_metadata.py | 37 +++++ 15 files changed, 288 insertions(+), 32 deletions(-) create mode 100644 tests/strands/event_loop/test_event_loop_metadata.py create mode 100644 tests/strands/types/test_message_metadata.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 37fa5fc00..e8ea3c9bc 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -1025,7 +1025,7 @@ async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: # Check if all item in input list are dictionaries elif all(isinstance(item, dict) for item in prompt): # Check if all items are messages - if all(all(key in item for key in Message.__annotations__.keys()) for item in prompt): + if all(all(key in item for key in Message.__required_keys__) for item in prompt): # Messages input - add all messages to conversation messages = cast(Messages, prompt) diff --git a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py index ab6fb4abe..dc073ba07 100644 --- a/src/strands/event_loop/_recover_message_on_max_tokens_reached.py +++ b/src/strands/event_loop/_recover_message_on_max_tokens_reached.py @@ -68,4 +68,7 @@ def recover_message_on_max_tokens_reached(message: Message) -> Message: } ) - return {"content": valid_content, "role": message["role"]} + recovered: Message = {"content": valid_content, "role": message["role"]} + if "metadata" in message: + recovered["metadata"] = message["metadata"] + return recovered diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index b4af16058..bf1cc7a84 100644 
--- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -354,6 +354,13 @@ async def _handle_model_execution( stop_reason, message, usage, metrics = event["stop"] invocation_state.setdefault("request_state", {}) + # Attach metadata to the assistant message immediately so it's + # available to all downstream consumers (hooks, events, state). + message["metadata"] = { + "usage": usage, + "metrics": metrics, + } + after_model_call_event = AfterModelCallEvent( agent=agent, invocation_state=invocation_state, diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 0a1161135..76eda48bf 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -488,6 +488,9 @@ async def stream_messages( logger.debug("model=<%s> | streaming messages", model) messages = _normalize_messages(messages) + # Whitelist only role and content before sending to the model provider. + # This ensures metadata (and any future non-model fields) never leak to providers. 
+ messages = [Message(role=msg["role"], content=msg["content"]) for msg in messages] start_time = time.time() chunks = model.stream( diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index d5d399f95..37c16d3ae 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -527,9 +527,7 @@ def start_event_loop_cycle_span( event_loop_cycle_id = str(invocation_state.get("event_loop_cycle_id")) parent_span = parent_span if parent_span else invocation_state.get("event_loop_parent_span") - attributes: dict[str, AttributeValue] = self._get_common_attributes( - operation_name="execute_event_loop_cycle" - ) + attributes: dict[str, AttributeValue] = self._get_common_attributes(operation_name="execute_event_loop_cycle") attributes["event_loop.cycle_id"] = event_loop_cycle_id if custom_trace_attributes: diff --git a/src/strands/types/content.py b/src/strands/types/content.py index 2b0714bee..8db1d1d98 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -6,11 +6,12 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Literal +from typing import Any, Literal from typing_extensions import NotRequired, TypedDict from .citations import CitationsContentBlock +from .event_loop import Metrics, Usage from .media import DocumentContent, ImageContent, VideoContent from .tools import ToolResult, ToolUse @@ -177,17 +178,44 @@ class ContentBlockStop(TypedDict): """ +class MessageMetadata(TypedDict, total=False): + """Optional metadata attached to a message. + + Not sent to model providers — explicitly stripped before model calls. + Persisted alongside the message in session storage. + + Attributes: + usage: Token usage information from the model response. + metrics: Performance metrics from the model response. + custom: Arbitrary user/framework metadata (e.g. compression provenance). 
+ """ + + usage: Usage + metrics: Metrics + custom: dict[str, Any] + + class Message(TypedDict): """A message in a conversation with the agent. Attributes: content: The message content. role: The role of the message sender. + metadata: Optional metadata, stripped before model calls. """ content: list[ContentBlock] role: Role + metadata: NotRequired[MessageMetadata] Messages = list[Message] """A list of messages representing a conversation.""" + + +def get_message_metadata(message: Message) -> MessageMetadata: + """Get metadata for a message, returning empty dict if not present. + + Individual fields (usage, metrics, custom) may not be present. Use .get() to safely access them. + """ + return message.get("metadata", {}) diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index 02c367ccc..1f09579b0 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -147,6 +147,7 @@ async def test_stream_e2e_success(alist): {"toolUse": {"input": {}, "name": "normal_tool", "toolUseId": "123"}}, ], "role": "assistant", + "metadata": ANY, } }, { @@ -205,6 +206,7 @@ async def test_stream_e2e_success(alist): {"toolUse": {"input": {}, "name": "async_tool", "toolUseId": "1234"}}, ], "role": "assistant", + "metadata": ANY, } }, { @@ -263,6 +265,7 @@ async def test_stream_e2e_success(alist): {"toolUse": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"}}, ], "role": "assistant", + "metadata": ANY, } }, { @@ -307,11 +310,11 @@ async def test_stream_e2e_success(alist): }, {"event": {"contentBlockStop": {}}}, {"event": {"messageStop": {"stopReason": "end_turn"}}}, - {"message": {"content": [{"text": "I invoked the tools!"}], "role": "assistant"}}, + {"message": {"content": [{"text": "I invoked the tools!"}], "role": "assistant", "metadata": ANY}}, { "result": AgentResult( stop_reason="end_turn", - message={"content": [{"text": "I invoked the tools!"}], "role": 
"assistant"}, + message={"content": [{"text": "I invoked the tools!"}], "role": "assistant", "metadata": ANY}, metrics=ANY, state={}, ), @@ -371,11 +374,11 @@ async def test_stream_e2e_throttle_and_redact(alist, mock_sleep): }, {"event": {"contentBlockStop": {}}}, {"event": {"messageStop": {"stopReason": "guardrail_intervened"}}}, - {"message": {"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant"}}, + {"message": {"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant", "metadata": ANY}}, { "result": AgentResult( stop_reason="guardrail_intervened", - message={"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant"}, + message={"content": [{"text": "INPUT BLOCKED!"}], "role": "assistant", "metadata": ANY}, metrics=ANY, state={}, ), @@ -442,6 +445,7 @@ async def test_stream_e2e_reasoning_redacted_content(alist): {"text": "Response with redacted reasoning"}, ], "role": "assistant", + "metadata": ANY, } }, { @@ -453,6 +457,7 @@ async def test_stream_e2e_reasoning_redacted_content(alist): {"text": "Response with redacted reasoning"}, ], "role": "assistant", + "metadata": ANY, }, metrics=ANY, state={}, diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 0057c50a3..1e27274a1 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -336,7 +336,7 @@ def test_agent__call__( "stop_reason": result.stop_reason, } exp_result = { - "message": {"content": [{"text": "test text"}], "role": "assistant"}, + "message": {"content": [{"text": "test text"}], "role": "assistant", "metadata": unittest.mock.ANY}, "state": {}, "stop_reason": "end_turn", } @@ -781,6 +781,7 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): {"reasoningContent": {"reasoningText": {"text": "value", "signature": "value"}}}, {"text": "value"}, ], + "metadata": unittest.mock.ANY, }, ), unittest.mock.call( @@ -793,6 +794,7 @@ def test_agent__call__callback(mock_model, agent, callback_handler, 
agenerator): {"reasoningContent": {"reasoningText": {"text": "value", "signature": "value"}}}, {"text": "value"}, ], + "metadata": unittest.mock.ANY, }, metrics=unittest.mock.ANY, state={}, @@ -817,7 +819,7 @@ async def test_agent__call__in_async_context(mock_model, agent, agenerator): result = agent("test") tru_message = result.message - exp_message = {"content": [{"text": "abc"}], "role": "assistant"} + exp_message = {"content": [{"text": "abc"}], "role": "assistant", "metadata": unittest.mock.ANY} assert tru_message == exp_message @@ -837,7 +839,7 @@ async def test_agent_invoke_async(mock_model, agent, agenerator): result = await agent.invoke_async("test") tru_message = result.message - exp_message = {"content": [{"text": "abc"}], "role": "assistant"} + exp_message = {"content": [{"text": "abc"}], "role": "assistant", "metadata": unittest.mock.ANY} assert tru_message == exp_message @@ -1128,7 +1130,7 @@ async def test_stream_async_multi_modal_input(mock_model, agent, agenerator, ali tru_message = agent.messages exp_message = [ {"content": prompt, "role": "user"}, - {"content": [{"text": "I see text and an image"}], "role": "assistant"}, + {"content": [{"text": "I see text and an image"}], "role": "assistant", "metadata": unittest.mock.ANY}, ] assert tru_message == exp_message @@ -1966,7 +1968,11 @@ def shell(command: str): } # And that it continued to the LLM call - assert agent.messages[-1] == {"content": [{"text": "I invoked a tool!"}], "role": "assistant"} + assert agent.messages[-1] == { + "content": [{"text": "I invoked a tool!"}], + "role": "assistant", + "metadata": unittest.mock.ANY, + } def test_agent_string_system_prompt(): diff --git a/tests/strands/agent/test_agent_cancellation.py b/tests/strands/agent/test_agent_cancellation.py index 6af153f4a..756e96485 100644 --- a/tests/strands/agent/test_agent_cancellation.py +++ b/tests/strands/agent/test_agent_cancellation.py @@ -2,6 +2,7 @@ import asyncio import threading +from unittest.mock import ANY import 
pytest @@ -31,7 +32,7 @@ async def test_agent_cancel_before_invocation(): result = await agent.invoke_async("Hello") assert result.stop_reason == "cancelled" - assert result.message == {"role": "assistant", "content": [{"text": "Cancelled by user"}]} + assert result.message == {"role": "assistant", "content": [{"text": "Cancelled by user"}], "metadata": ANY} @pytest.mark.asyncio diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 3a40d69a8..2c61ee966 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -173,6 +173,7 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u message={ "content": [{"toolUse": tool_use}], "role": "assistant", + "metadata": ANY, }, stop_reason="tool_use", ), @@ -199,7 +200,7 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u agent=agent, invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( - message=mock_model.agent_responses[1], + message={"role": "assistant", "content": [{"text": "I invoked a tool!"}], "metadata": ANY}, stop_reason="end_turn", ), exception=None, @@ -246,6 +247,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m message={ "content": [{"toolUse": tool_use}], "role": "assistant", + "metadata": ANY, }, stop_reason="tool_use", ), @@ -272,7 +274,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m agent=agent, invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( - message=mock_model.agent_responses[1], + message={"role": "assistant", "content": [{"text": "I invoked a tool!"}], "metadata": ANY}, stop_reason="end_turn", ), exception=None, diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index f91f7c2af..871371f5f 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ 
b/tests/strands/event_loop/test_event_loop.py @@ -193,7 +193,7 @@ async def test_event_loop_cycle_text_response( tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test text"}]} + exp_message = {"role": "assistant", "content": [{"text": "test text"}], "metadata": ANY} exp_request_state = {} assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state @@ -225,7 +225,7 @@ async def test_event_loop_cycle_text_response_throttling( tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test text"}]} + exp_message = {"role": "assistant", "content": [{"text": "test text"}], "metadata": ANY} exp_request_state = {} assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state @@ -264,7 +264,7 @@ async def test_event_loop_cycle_exponential_backoff( # Verify the final response assert tru_stop_reason == "end_turn" - assert tru_message == {"role": "assistant", "content": [{"text": "test text"}]} + assert tru_message == {"role": "assistant", "content": [{"text": "test text"}], "metadata": ANY} assert tru_request_state == {} # Verify that sleep was called with increasing delays @@ -354,7 +354,7 @@ async def test_event_loop_cycle_tool_result( tru_stop_reason, tru_message, _, tru_request_state, _, _ = events[-1]["stop"] exp_stop_reason = "end_turn" - exp_message = {"role": "assistant", "content": [{"text": "test text"}]} + exp_message = {"role": "assistant", "content": [{"text": "test text"}], "metadata": ANY} exp_request_state = {} assert tru_stop_reason == exp_stop_reason and tru_message == exp_message and tru_request_state == exp_request_state @@ -389,7 +389,6 @@ async def test_event_loop_cycle_tool_result( }, ], }, - {"role": "assistant", 
"content": [{"text": "test text"}]}, ], tool_registry.get_all_tool_specs(), "p1", @@ -484,6 +483,7 @@ async def test_event_loop_cycle_stop( } } ], + "metadata": ANY, } exp_request_state = {"stop_event_loop": True} @@ -946,14 +946,14 @@ async def test_event_loop_cycle_exception_model_hooks(mock_sleep, agent, model, agent=agent, invocation_state=ANY, stop_response=AfterModelCallEvent.ModelStopResponse( - message={"content": [{"text": "test text"}], "role": "assistant"}, stop_reason="end_turn" + message={"content": [{"text": "test text"}], "role": "assistant", "metadata": ANY}, stop_reason="end_turn" ), exception=None, ) # Final message assert next(events) == MessageAddedEvent( - agent=agent, message={"content": [{"text": "test text"}], "role": "assistant"} + agent=agent, message={"content": [{"text": "test text"}], "role": "assistant", "metadata": ANY} ) @@ -997,6 +997,7 @@ def interrupt_callback(event): }, ], "role": "assistant", + "metadata": ANY, }, }, "interrupts": { @@ -1131,7 +1132,7 @@ async def test_invalid_tool_names_adds_tool_uses(agent, model, alist): # ensure that we got end_turn and not tool_use assert events[-1] == EventLoopStopEvent( stop_reason="end_turn", - message={"content": [{"text": "I invoked a tool!"}], "role": "assistant"}, + message={"content": [{"text": "I invoked a tool!"}], "role": "assistant", "metadata": ANY}, metrics=ANY, request_state={}, ) diff --git a/tests/strands/event_loop/test_event_loop_metadata.py b/tests/strands/event_loop/test_event_loop_metadata.py new file mode 100644 index 000000000..e6fe97f39 --- /dev/null +++ b/tests/strands/event_loop/test_event_loop_metadata.py @@ -0,0 +1,141 @@ +"""Tests for metadata population on assistant messages in the event loop.""" + +import threading +import unittest.mock + +import pytest + +import strands +import strands.event_loop.event_loop +from strands import Agent +from strands.event_loop._retry import ModelRetryStrategy +from strands.hooks import HookRegistry +from strands.interrupt 
import _InterruptState +from strands.telemetry.metrics import EventLoopMetrics +from strands.tools.executors import SequentialToolExecutor +from strands.tools.registry import ToolRegistry + + +@pytest.fixture +def model(): + return unittest.mock.Mock() + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "Hello"}]}] + + +@pytest.fixture +def hook_registry(): + registry = HookRegistry() + retry_strategy = ModelRetryStrategy() + retry_strategy.register_hooks(registry) + return registry + + +@pytest.fixture +def tool_registry(): + return ToolRegistry() + + +@pytest.fixture +def agent(model, messages, tool_registry, hook_registry): + mock = unittest.mock.Mock(name="agent") + mock.__class__ = Agent + mock.config.cache_points = [] + mock.model = model + mock.system_prompt = "test" + mock.messages = messages + mock.tool_registry = tool_registry + mock.thread_pool = None + mock.event_loop_metrics = EventLoopMetrics() + mock.event_loop_metrics.reset_usage_metrics() + mock.hooks = hook_registry + mock.tool_executor = SequentialToolExecutor() + mock._interrupt_state = _InterruptState() + mock._cancel_signal = threading.Event() + mock.trace_attributes = {} + mock.retry_strategy = ModelRetryStrategy() + return mock + + +@pytest.mark.asyncio +async def test_metadata_populated_on_assistant_message(agent, model, agenerator, alist): + """After a model response, the assistant message should have metadata with usage and metrics.""" + model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "response"}}}, + {"contentBlockStop": {}}, + { + "metadata": { + "usage": {"inputTokens": 42, "outputTokens": 10, "totalTokens": 52}, + "metrics": {"latencyMs": 200}, + } + }, + ] + ) + + stream = strands.event_loop.event_loop.event_loop_cycle(agent=agent, invocation_state={}) + await alist(stream) + + # The assistant message should be in agent.messages + assistant_msg = agent.messages[-1] + assert assistant_msg["role"] == "assistant" + 
assert "metadata" in assistant_msg + + meta = assistant_msg["metadata"] + assert meta["usage"]["inputTokens"] == 42 + assert meta["usage"]["outputTokens"] == 10 + assert meta["usage"]["totalTokens"] == 52 + assert meta["metrics"]["latencyMs"] == 200 + + +@pytest.mark.asyncio +async def test_metadata_has_default_usage_when_no_metadata_event(agent, model, agenerator, alist): + """When no metadata event is in the stream, metadata should still be set with defaults.""" + model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "response"}}}, + {"contentBlockStop": {}}, + ] + ) + + stream = strands.event_loop.event_loop.event_loop_cycle(agent=agent, invocation_state={}) + await alist(stream) + + assistant_msg = agent.messages[-1] + assert "metadata" in assistant_msg + assert assistant_msg["metadata"]["usage"]["inputTokens"] == 0 + assert assistant_msg["metadata"]["usage"]["outputTokens"] == 0 + assert assistant_msg["metadata"]["metrics"]["latencyMs"] == 0 + + +@pytest.mark.asyncio +async def test_metadata_stripped_before_model_call(agent, model, agenerator, alist): + """Metadata from previous messages should be stripped before sending to the model.""" + # Pre-populate a message with metadata (simulating a previous turn) + agent.messages.append( + { + "role": "assistant", + "content": [{"text": "previous response"}], + "metadata": {"usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}}, + } + ) + agent.messages.append({"role": "user", "content": [{"text": "follow up"}]}) + + model.stream.return_value = agenerator( + [ + {"contentBlockDelta": {"delta": {"text": "response"}}}, + {"contentBlockStop": {}}, + ] + ) + + stream = strands.event_loop.event_loop.event_loop_cycle(agent=agent, invocation_state={}) + await alist(stream) + + # Verify that messages passed to model.stream() have no metadata key + call_args = model.stream.call_args + messages_sent = call_args[0][0] + for msg in messages_sent: + assert "metadata" not in msg, 
f"metadata leaked to model: {msg}" diff --git a/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py index 402e90966..6dff0fc29 100644 --- a/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py +++ b/tests/strands/event_loop/test_recover_message_on_max_tokens_reached.py @@ -224,6 +224,34 @@ def test_recover_message_on_max_tokens_reached_multiple_incomplete_tools(): assert "incomplete due to maximum token limits" in result["content"][2]["text"] +def test_recover_message_on_max_tokens_reached_preserves_metadata(): + """Test that metadata is preserved through recovery.""" + message: Message = { + "role": "assistant", + "content": [ + {"toolUse": {"name": "calculator", "input": {}, "toolUseId": "123"}}, + ], + "metadata": {"usage": {"inputTokens": 42, "outputTokens": 10, "totalTokens": 52}}, + } + + result = recover_message_on_max_tokens_reached(message) + + assert "metadata" in result + assert result["metadata"]["usage"]["inputTokens"] == 42 + + +def test_recover_message_on_max_tokens_reached_without_metadata(): + """Test that recovery works fine when no metadata is present.""" + message: Message = { + "role": "assistant", + "content": [{"text": "some text"}], + } + + result = recover_message_on_max_tokens_reached(message) + + assert "metadata" not in result + + def test_recover_message_on_max_tokens_reached_preserves_user_role(): """Test that the function preserves the original message role.""" incomplete_message: Message = { diff --git a/tests/strands/tools/mcp/test_mcp_client_tasks.py b/tests/strands/tools/mcp/test_mcp_client_tasks.py index c21db9e28..d566ac6f5 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tasks.py +++ b/tests/strands/tools/mcp/test_mcp_client_tasks.py @@ -251,9 +251,7 @@ def test_call_tool_sync_forwards_meta_to_task(self, mock_transport, mock_session with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as 
client: client.list_tools_sync() - client.call_tool_sync( - tool_use_id="test-id", name="meta_tool", arguments={"param": "value"}, meta=meta - ) + client.call_tool_sync(tool_use_id="test-id", name="meta_tool", arguments={"param": "value"}, meta=meta) experimental.call_tool_as_task.assert_called_once() call_kwargs = experimental.call_tool_as_task.call_args @@ -281,9 +279,7 @@ def test_call_tool_sync_forwards_none_meta_to_task(self, mock_transport, mock_se with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: client.list_tools_sync() - client.call_tool_sync( - tool_use_id="test-id", name="no_meta_tool", arguments={"param": "value"} - ) + client.call_tool_sync(tool_use_id="test-id", name="no_meta_tool", arguments={"param": "value"}) experimental.call_tool_as_task.assert_called_once() call_kwargs = experimental.call_tool_as_task.call_args diff --git a/tests/strands/types/test_message_metadata.py b/tests/strands/types/test_message_metadata.py new file mode 100644 index 000000000..a7f93f690 --- /dev/null +++ b/tests/strands/types/test_message_metadata.py @@ -0,0 +1,37 @@ +"""Tests for MessageMetadata and get_message_metadata.""" + +from strands.types.content import Message, MessageMetadata, get_message_metadata + + +def test_message_without_metadata(): + msg: Message = {"role": "assistant", "content": [{"text": "hello"}]} + assert get_message_metadata(msg) == {} + + +def test_message_with_metadata(): + meta: MessageMetadata = { + "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}, + "metrics": {"latencyMs": 100}, + } + msg: Message = {"role": "assistant", "content": [{"text": "hello"}], "metadata": meta} + assert get_message_metadata(msg) == meta + assert get_message_metadata(msg)["usage"]["inputTokens"] == 10 + + +def test_message_with_custom_metadata(): + meta: MessageMetadata = { + "custom": {"source": "summarization", "original_turns": [5, 6, 7]}, + } + msg: Message = {"role": "assistant", "content": [{"text": 
"summary"}], "metadata": meta} + result = get_message_metadata(msg) + assert result["custom"]["source"] == "summarization" + + +def test_metadata_does_not_affect_role_and_content(): + msg: Message = { + "role": "assistant", + "content": [{"text": "hello"}], + "metadata": {"usage": {"inputTokens": 1, "outputTokens": 1, "totalTokens": 2}}, + } + assert msg["role"] == "assistant" + assert msg["content"] == [{"text": "hello"}] From 117da677fa398fffbce834d7df5099047b31eb82 Mon Sep 17 00:00:00 2001 From: Agent of mkmeral Date: Wed, 15 Apr 2026 14:04:15 -0400 Subject: [PATCH 229/279] feat(bidi): support request_state stop_event_loop flag (#1954) Co-authored-by: agent-of-mkmeral --- src/strands/experimental/bidi/__init__.py | 6 +- src/strands/experimental/bidi/agent/loop.py | 28 +++- src/strands/experimental/bidi/io/text.py | 2 +- .../experimental/bidi/tools/__init__.py | 15 +- .../bidi/tools/stop_conversation.py | 20 ++- .../experimental/bidi/agent/test_loop.py | 157 +++++++++++++++++- 6 files changed, 215 insertions(+), 13 deletions(-) diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py index 99fbacce1..8ce3ebc68 100644 --- a/src/strands/experimental/bidi/__init__.py +++ b/src/strands/experimental/bidi/__init__.py @@ -14,7 +14,7 @@ # Model interface (for custom implementations) from .models.model import BidiModel -# Built-in tools +# Built-in tools (deprecated - use strands_tools.stop instead) from .tools import stop_conversation # Event types - For type hints and event handling @@ -39,8 +39,6 @@ __all__ = [ # Main interface "BidiAgent", - # Built-in tools - "stop_conversation", # Input Event types "BidiTextInputEvent", "BidiAudioInputEvent", @@ -64,6 +62,8 @@ "ToolStreamEvent", # Model interface "BidiModel", + # Built-in tools (deprecated) + "stop_conversation", ] diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 2b883cf73..79818ae7c 100644 --- 
a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -5,6 +5,7 @@ import asyncio import logging +import warnings from typing import TYPE_CHECKING, Any, AsyncGenerator, cast from ....types._events import ToolInterruptEvent, ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent @@ -248,6 +249,10 @@ async def _run_tool(self, tool_use: ToolUse) -> None: tool_results: list[ToolResult] = [] + # Ensure request_state exists for tools like strands_tools.stop + if "request_state" not in self._invocation_state: + self._invocation_state["request_state"] = {} + invocation_state: dict[str, Any] = { **self._invocation_state, "agent": self._agent, @@ -282,16 +287,29 @@ async def _run_tool(self, tool_use: ToolUse) -> None: await self._event_queue.put(ToolResultMessageEvent(tool_result_message)) - # Check for stop_conversation before sending to model - if tool_use["name"] == "stop_conversation": - logger.info("tool_name=<%s> | conversation stop requested, skipping model send", tool_use["name"]) + # Check for stop_event_loop flag (set by strands_tools.stop, stop_conversation, or any custom tool) + request_state = invocation_state.get("request_state", {}) + should_stop = request_state.get("stop_event_loop", False) + + # Backward compatibility: also check for stop_conversation by name (deprecated) + if not should_stop and tool_use["name"] == "stop_conversation": + warnings.warn( + "Stopping the event loop by tool name 'stop_conversation' is deprecated. 
" + "Use request_state['stop_event_loop'] = True instead.", + DeprecationWarning, + stacklevel=2, + ) + should_stop = True + + if should_stop: + logger.info("stop_event_loop= | stopping conversation") connection_id = getattr(self._agent.model, "_connection_id", "unknown") await self._event_queue.put( BidiConnectionCloseEvent(connection_id=connection_id, reason="user_request") ) - return # Skip the model send + return # Skip sending result to model - # Send result to model (all tools except stop_conversation) + # Send result to model await self.send(tool_result_event) except Exception as error: diff --git a/src/strands/experimental/bidi/io/text.py b/src/strands/experimental/bidi/io/text.py index f575c5606..00d999818 100644 --- a/src/strands/experimental/bidi/io/text.py +++ b/src/strands/experimental/bidi/io/text.py @@ -42,7 +42,7 @@ async def __call__(self, event: BidiOutputEvent) -> None: elif isinstance(event, BidiConnectionCloseEvent): if event.reason == "user_request": - print("user requested connection close using the stop_conversation tool.") + print("user requested connection close using the stop tool.") logger.debug("connection_id=<%s> | user requested connection close", event.connection_id) elif isinstance(event, BidiTranscriptStreamEvent): text = event["text"] diff --git a/src/strands/experimental/bidi/tools/__init__.py b/src/strands/experimental/bidi/tools/__init__.py index c665dc65a..de67040de 100644 --- a/src/strands/experimental/bidi/tools/__init__.py +++ b/src/strands/experimental/bidi/tools/__init__.py @@ -1,4 +1,17 @@ -"""Built-in tools for bidirectional agents.""" +"""Built-in tools for bidirectional agents. + +.. deprecated:: + The built-in ``stop_conversation`` tool is deprecated. Use ``strands_tools.stop`` or set + ``request_state["stop_event_loop"] = True`` in any custom tool instead. 
+ +To stop a bidirectional conversation, use the standard ``stop`` tool from strands_tools:: + + from strands_tools import stop + agent = BidiAgent(tools=[stop, ...]) + +The stop tool sets ``request_state["stop_event_loop"] = True``, which signals the +BidiAgent to gracefully close the connection. +""" from .stop_conversation import stop_conversation diff --git a/src/strands/experimental/bidi/tools/stop_conversation.py b/src/strands/experimental/bidi/tools/stop_conversation.py index 9c7e1c6cd..21b530552 100644 --- a/src/strands/experimental/bidi/tools/stop_conversation.py +++ b/src/strands/experimental/bidi/tools/stop_conversation.py @@ -1,4 +1,11 @@ -"""Tool to gracefully stop a bidirectional connection.""" +"""Tool to gracefully stop a bidirectional connection. + +.. deprecated:: + The ``stop_conversation`` tool is deprecated and will be removed in a future version. + Use ``strands_tools.stop`` or set ``request_state["stop_event_loop"] = True`` in any custom tool instead. +""" + +import warnings from ....tools.decorator import tool @@ -7,10 +14,19 @@ def stop_conversation() -> str: """Stop the bidirectional conversation gracefully. + .. deprecated:: + Use ``strands_tools.stop`` or set ``request_state["stop_event_loop"] = True`` in a custom tool instead. + Use ONLY when user says "stop conversation" exactly. Do NOT use for: "stop", "goodbye", "bye", "exit", "quit", "end" or other farewells or phrases. Returns: - Success message confirming the conversation will end + Success message confirming the conversation will end. """ + warnings.warn( + "stop_conversation is deprecated and will be removed in a future version. 
" + "Use strands_tools.stop or set request_state['stop_event_loop'] = True in any custom tool instead.", + DeprecationWarning, + stacklevel=2, + ) return "Ending conversation" diff --git a/tests/strands/experimental/bidi/agent/test_loop.py b/tests/strands/experimental/bidi/agent/test_loop.py index fac52658e..a8efd9a93 100644 --- a/tests/strands/experimental/bidi/agent/test_loop.py +++ b/tests/strands/experimental/bidi/agent/test_loop.py @@ -1,4 +1,5 @@ import unittest.mock +import warnings import pytest import pytest_asyncio @@ -6,7 +7,7 @@ from strands import tool from strands.experimental.bidi import BidiAgent from strands.experimental.bidi.models import BidiModel, BidiModelTimeoutError -from strands.experimental.bidi.types.events import BidiConnectionRestartEvent, BidiTextInputEvent +from strands.experimental.bidi.types.events import BidiConnectionCloseEvent, BidiConnectionRestartEvent, BidiTextInputEvent from strands.types._events import ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent @@ -93,3 +94,157 @@ async def test_bidi_agent_loop_receive_tool_use(loop, agent, agenerator): assert tru_messages == exp_messages agent.model.send.assert_called_with(tool_result_event) + + +@pytest.mark.asyncio +async def test_bidi_agent_loop_request_state_initialized_for_tools(loop, agent, agenerator): + """Test that request_state is initialized in invocation_state before tool execution. + + This ensures request_state exists for tools that may need it via invocation_state, + even when invocation_state is not provided by the user. 
+ """ + tool_use = {"toolUseId": "t2", "name": "time_tool", "input": {}} + tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="") + + agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event])) + + # Start without providing invocation_state + await loop.start() + + tru_events = [] + async for event in loop.receive(): + tru_events.append(event) + if len(tru_events) >= 3: + break + + # Verify tool executed successfully + tool_result_event = tru_events[1] + assert isinstance(tool_result_event, ToolResultEvent) + assert tool_result_event.tool_result["status"] == "success" + + # Verify request_state was initialized in invocation_state + assert "request_state" in loop._invocation_state + assert isinstance(loop._invocation_state["request_state"], dict) + + +@pytest.mark.asyncio +async def test_bidi_agent_loop_stop_event_loop_flag(agent, agenerator): + """Test that the stop_event_loop flag in request_state gracefully closes the connection. + + This simulates a tool (like strands_tools.stop) setting the flag via invocation_state. 
+ """ + # Use a tool that modifies invocation_state to set the stop flag + # We'll mock the tool executor to simulate this behavior + loop = agent._loop + + tool_use = {"toolUseId": "t3", "name": "time_tool", "input": {}} + tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="") + tool_result = {"toolUseId": "t3", "status": "success", "content": [{"text": "12:00"}]} + + agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event])) + + # Start with request_state that already has stop_event_loop=True + # This simulates a tool having set it during execution + await loop.start(invocation_state={"request_state": {"stop_event_loop": True}}) + + tru_events = [] + async for event in loop.receive(): + tru_events.append(event) + + # Should receive: tool_use_event, tool_result_event, tool_result_message, connection_close + assert len(tru_events) == 4 + + # Verify tool executed successfully + tool_result_event = tru_events[1] + assert isinstance(tool_result_event, ToolResultEvent) + assert tool_result_event.tool_result["status"] == "success" + + # Verify connection close event was emitted + connection_close_event = tru_events[3] + assert isinstance(connection_close_event, BidiConnectionCloseEvent) + assert connection_close_event["reason"] == "user_request" + + # Verify model.send was NOT called (tool result not sent to model) + agent.model.send.assert_not_called() + + +@pytest.mark.asyncio +async def test_bidi_agent_loop_stop_conversation_deprecated_but_works(loop, agent, agenerator): + """Test that stop_conversation tool still works but emits a deprecation warning. + + The stop_conversation tool is deprecated in favor of request_state["stop_event_loop"], + but should continue to work for backward compatibility via the name-based check. 
+ """ + from strands.experimental.bidi.tools import stop_conversation + + agent.tool_registry.register_tool(stop_conversation) + + tool_use = {"toolUseId": "t5", "name": "stop_conversation", "input": {}} + tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="") + + agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event])) + + await loop.start() + + tru_events = [] + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always") + async for event in loop.receive(): + tru_events.append(event) + + # Should receive: tool_use_event, tool_result_event, tool_result_message, connection_close + assert len(tru_events) == 4 + + # Verify tool executed successfully + tool_result_event = tru_events[1] + assert isinstance(tool_result_event, ToolResultEvent) + assert tool_result_event.tool_result["status"] == "success" + assert "Ending conversation" in tool_result_event.tool_result["content"][0]["text"] + + # Verify connection close event was emitted + connection_close_event = tru_events[3] + assert isinstance(connection_close_event, BidiConnectionCloseEvent) + assert connection_close_event["reason"] == "user_request" + + # Verify model.send was NOT called (tool result not sent to model) + agent.model.send.assert_not_called() + + # Verify deprecation warnings were emitted (from both the tool itself and the loop name check) + deprecation_warnings = [w for w in caught_warnings if issubclass(w.category, DeprecationWarning)] + assert len(deprecation_warnings) >= 1 + assert any("stop_conversation" in str(w.message).lower() for w in deprecation_warnings) + + +@pytest.mark.asyncio +async def test_bidi_agent_loop_request_state_preserved_with_invocation_state(agent, agenerator): + """Test that existing invocation_state is preserved when request_state is initialized.""" + + @tool(name="check_invocation_state") + async def check_invocation_state(custom_key: str) -> str: + return f"custom_key: {custom_key}" + + 
agent.tool_registry.register_tool(check_invocation_state) + + tool_use = {"toolUseId": "t4", "name": "check_invocation_state", "input": {"custom_key": "from_state"}} + tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="") + + agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event])) + + loop = agent._loop + # Start with custom invocation_state but no request_state + await loop.start(invocation_state={"custom_data": "preserved"}) + + tru_events = [] + async for event in loop.receive(): + tru_events.append(event) + if len(tru_events) >= 3: + break + + # Verify tool executed successfully + tool_result_event = tru_events[1] + assert isinstance(tool_result_event, ToolResultEvent) + assert tool_result_event.tool_result["status"] == "success" + + # Verify request_state was added without removing custom_data + assert "request_state" in loop._invocation_state + assert loop._invocation_state.get("custom_data") == "preserved" From 6697d12814b656f56e30ba7861c3b28302f20635 Mon Sep 17 00:00:00 2001 From: opieter-aws Date: Thu, 16 Apr 2026 12:45:42 -0400 Subject: [PATCH 230/279] fix: preserve Gemini thought_signature in LiteLLM multi-turn tool calls (#2129) Co-authored-by: giulio-leone --- src/strands/models/litellm.py | 77 ++++++++++++- tests/strands/models/test_litellm.py | 138 +++++++++++++++++++++++ tests_integ/conftest.py | 1 + tests_integ/models/test_model_litellm.py | 13 +++ 4 files changed, 226 insertions(+), 3 deletions(-) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index be5337f0d..36bdb5a05 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -19,12 +19,16 @@ from ..types.event_loop import Usage from ..types.exceptions import ContextWindowOverflowException from ..types.streaming import MetadataEvent, StreamEvent -from ..types.tools import ToolChoice, ToolSpec +from ..types.tools import ToolChoice, ToolSpec, ToolUse from ._validation import validate_config_keys from 
.openai import OpenAIModel logger = logging.getLogger(__name__) +# Separator used by LiteLLM to embed thought signatures inside tool call IDs. +# See: https://ai.google.dev/gemini-api/docs/thought-signatures +_THOUGHT_SIGNATURE_SEPARATOR = "__thought__" + T = TypeVar("T", bound=BaseModel) @@ -114,6 +118,61 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> return super().format_request_message_content(content) + @override + @classmethod + def format_request_message_tool_call(cls, tool_use: ToolUse, **kwargs: Any) -> dict[str, Any]: + """Format a LiteLLM compatible tool call, encoding thought signatures into the tool call ID. + + Gemini thinking models attach a thought_signature to each function call. LiteLLM's OpenAI-compatible + interface embeds this signature inside the tool call ID using the ``__thought__`` separator. When + ``reasoningSignature`` is present and the tool call ID does not already contain the separator, this + method encodes it so LiteLLM can reconstruct the Gemini-native format on the next request. + + Args: + tool_use: Tool use requested by the model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + LiteLLM compatible tool call dict with thought signature encoded in the ID when present. + """ + tool_call = super().format_request_message_tool_call(tool_use, **kwargs) + + reasoning_signature = tool_use.get("reasoningSignature") + if reasoning_signature and _THOUGHT_SIGNATURE_SEPARATOR not in tool_call["id"]: + tool_call["id"] = f"{tool_call['id']}{_THOUGHT_SIGNATURE_SEPARATOR}{reasoning_signature}" + + return tool_call + + @staticmethod + def _extract_thought_signature(data: Any) -> str | None: + """Extract thought signature from a tool call event data. + + LiteLLM surfaces Gemini thought signatures in two ways: + + 1. ``provider_specific_fields.thought_signature`` — a structured field set by LiteLLM's Gemini response + transformer. 
Checked first as it doesn't depend on matching an internal string constant. + 2. ``__thought__`` separator encoded in the tool call ID. Used as fallback since it relies on a copy of + LiteLLM's internal ``THOUGHT_SIGNATURE_SEPARATOR`` constant. + + Args: + data: Tool call event data object. + + Returns: + The extracted thought signature, or None if not present. + """ + # Preferred: structured field that doesn't depend on matching an internal separator string + psf = getattr(data, "provider_specific_fields", None) or {} + if isinstance(psf, dict) and psf.get("thought_signature"): + return str(psf["thought_signature"]) + + # Fallback: extract from encoded ID (relies on hardcoded copy of LiteLLM's separator) + tool_call_id = getattr(data, "id", None) or "" + if isinstance(tool_call_id, str) and _THOUGHT_SIGNATURE_SEPARATOR in tool_call_id: + _, signature = tool_call_id.split(_THOUGHT_SIGNATURE_SEPARATOR, 1) + return signature + + return None + def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> tuple[list[StreamEvent], str]: """Handle switching to a new content stream. @@ -200,8 +259,9 @@ def format_request_messages( def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: """Format a LiteLLM response event into a standardized message chunk. - This method overrides OpenAI's format_chunk to handle the metadata case - with prompt caching support. All other chunk types use the parent implementation. + Extends OpenAI's format_chunk to: + 1. Handle metadata with prompt caching support. + 2. Extract thought signatures that LiteLLM embeds in tool call IDs for Gemini thinking models. Args: event: A response event from the LiteLLM model. @@ -237,6 +297,17 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: usage=usage_data, ) ) + + # Extract thought signature from tool call content_start events. + # The full encoded ID is kept in toolUseId so that tool result messages continue to match. 
+ if event["chunk_type"] == "content_start" and event.get("data_type") == "tool": + signature = self._extract_thought_signature(event.get("data")) + chunk = super().format_chunk(event) + if signature: + tool_use_dict = cast(dict, chunk["contentBlockStart"]["start"]["toolUse"]) + tool_use_dict["reasoningSignature"] = signature + return chunk + # For all other cases, use the parent implementation return super().format_chunk(event) diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 9bb0e09ca..d35a1806e 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -848,3 +848,141 @@ def test_format_request_messages_with_tool_calls_no_content(): }, ] assert tru_result == exp_result + + +# --- Thought Signature Tests --- + + +def test_format_chunk_tool_start_extracts_thought_signature_from_id(): + """Test that format_chunk extracts thought_signature from LiteLLM-encoded tool call ID.""" + model = LiteLLMModel(model_id="test") + + mock_data = unittest.mock.Mock() + mock_data.id = "call_abc123__thought__dGhpcy1pcy1hLXNpZw==" + mock_data.function = unittest.mock.Mock() + mock_data.function.name = "get_weather" + mock_data.provider_specific_fields = None + + event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data} + result = model.format_chunk(event) + + tool_use = result["contentBlockStart"]["start"]["toolUse"] + assert tool_use["reasoningSignature"] == "dGhpcy1pcy1hLXNpZw==" + # toolUseId keeps the full encoded string so tool result IDs match + assert tool_use["toolUseId"] == "call_abc123__thought__dGhpcy1pcy1hLXNpZw==" + + +def test_format_chunk_tool_start_extracts_thought_signature_from_provider_specific_fields(): + """Test that format_chunk extracts thought_signature from provider_specific_fields.""" + model = LiteLLMModel(model_id="test") + + mock_data = unittest.mock.Mock() + mock_data.id = "call_abc123" # No __thought__ in ID + mock_data.function = 
unittest.mock.Mock() + mock_data.function.name = "get_weather" + mock_data.function.provider_specific_fields = None + mock_data.provider_specific_fields = {"thought_signature": "cHNmLXNpZw=="} + + event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data} + result = model.format_chunk(event) + + tool_use = result["contentBlockStart"]["start"]["toolUse"] + assert tool_use["reasoningSignature"] == "cHNmLXNpZw==" + assert tool_use["toolUseId"] == "call_abc123" + + +def test_format_chunk_tool_start_no_thought_signature(): + """Test that format_chunk works normally when no thought_signature is present.""" + model = LiteLLMModel(model_id="test") + + mock_data = unittest.mock.Mock() + mock_data.id = "call_plain123" + mock_data.function = unittest.mock.Mock() + mock_data.function.name = "get_weather" + mock_data.provider_specific_fields = None + mock_data.function.provider_specific_fields = None + + event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data} + result = model.format_chunk(event) + + tool_use = result["contentBlockStart"]["start"]["toolUse"] + assert tool_use["toolUseId"] == "call_plain123" + assert "reasoningSignature" not in tool_use + + +def test_format_request_message_tool_call_encodes_thought_signature(): + """Test that format_request_message_tool_call encodes reasoningSignature into the tool call ID.""" + tool_use = { + "toolUseId": "call_abc123", + "name": "get_weather", + "input": {"city": "Seattle"}, + "reasoningSignature": "dGhpcy1pcy1hLXNpZw==", + } + + result = LiteLLMModel.format_request_message_tool_call(tool_use) + + assert result["id"] == "call_abc123__thought__dGhpcy1pcy1hLXNpZw==" + assert result["function"]["name"] == "get_weather" + assert result["function"]["arguments"] == '{"city": "Seattle"}' + + +def test_format_request_message_tool_call_skips_encoding_when_already_present(): + """Test that format_request_message_tool_call does not double-encode the signature.""" + tool_use = { + "toolUseId": 
"call_abc123__thought__dGhpcy1pcy1hLXNpZw==", + "name": "get_weather", + "input": {"city": "Seattle"}, + "reasoningSignature": "dGhpcy1pcy1hLXNpZw==", + } + + result = LiteLLMModel.format_request_message_tool_call(tool_use) + + # Should NOT double-encode + assert result["id"] == "call_abc123__thought__dGhpcy1pcy1hLXNpZw==" + + +def test_format_request_message_tool_call_no_reasoning_signature(): + """Test that format_request_message_tool_call works normally without reasoningSignature.""" + tool_use = { + "toolUseId": "call_plain123", + "name": "get_weather", + "input": {"city": "Seattle"}, + } + + result = LiteLLMModel.format_request_message_tool_call(tool_use) + + assert result["id"] == "call_plain123" + assert "__thought__" not in result["id"] + + +def test_thought_signature_round_trip(): + """Test that thought signature is preserved through a full response -> internal -> request cycle.""" + model = LiteLLMModel(model_id="test") + signature = "R2VtaW5pVGhvdWdodFNpZw==" + tool_call_id = f"call_xyz789__thought__{signature}" + + # 1. Response path: format_chunk extracts the signature + mock_data = unittest.mock.Mock() + mock_data.id = tool_call_id + mock_data.function = unittest.mock.Mock() + mock_data.function.name = "current_time" + mock_data.provider_specific_fields = None + mock_data.function.provider_specific_fields = None + + event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_data} + chunk = model.format_chunk(event) + tool_use_data = chunk["contentBlockStart"]["start"]["toolUse"] + assert tool_use_data["reasoningSignature"] == signature + + # 2. Simulate internal storage (streaming layer stores reasoningSignature) + internal_tool_use = { + "toolUseId": tool_use_data["toolUseId"], + "name": tool_use_data["name"], + "input": {"timezone": "UTC"}, + "reasoningSignature": tool_use_data["reasoningSignature"], + } + + # 3. 
Request path: format_request_message_tool_call re-encodes the signature + tool_call = LiteLLMModel.format_request_message_tool_call(internal_tool_use) + assert "__thought__" in tool_call["id"] + assert signature in tool_call["id"] diff --git a/tests_integ/conftest.py b/tests_integ/conftest.py index 347b22a43..b7ae78ec3 100644 --- a/tests_integ/conftest.py +++ b/tests_integ/conftest.py @@ -202,6 +202,7 @@ def _load_api_keys_from_secrets_manager(): required_providers = { "ANTHROPIC_API_KEY", + "GOOGLE_API_KEY", "MISTRAL_API_KEY", "OPENAI_API_KEY", "WRITER_API_KEY", diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py index eb0737e0f..b09983d73 100644 --- a/tests_integ/models/test_model_litellm.py +++ b/tests_integ/models/test_model_litellm.py @@ -1,3 +1,4 @@ +import os import unittest.mock from uuid import uuid4 @@ -277,3 +278,15 @@ async def test_cache_read_tokens_multi_turn(model): assert result.metrics.accumulated_usage["cacheReadInputTokens"] > 0 assert result.metrics.accumulated_usage["cacheWriteInputTokens"] > 0 + + +def test_gemini_thinking_model_tool_call(tools): + """Test that Gemini thinking models preserve thought_signature through multi-turn tool calls. 
+ + Regression test for https://github.com/strands-agents/sdk-python/issues/1764 + """ + model = LiteLLMModel(model_id="gemini/gemini-2.5-flash", client_args={"api_key": os.environ.get("GOOGLE_API_KEY")}) + agent = Agent(model=model, tools=tools) + result = agent("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + assert all(string in text for string in ["12:00", "sunny"]) From 8e96ea878aed7fee2b64a845bd02cd6648d40e76 Mon Sep 17 00:00:00 2001 From: ghhamel Date: Thu, 16 Apr 2026 14:49:57 -0400 Subject: [PATCH 231/279] fix(bedrock): normalize empty toolResult content arrays in _format_bedrock_messages (#2123) --- src/strands/models/bedrock.py | 9 +++- tests/strands/models/test_bedrock.py | 81 ++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 1 deletion(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index bfb7b1ede..742cd82e9 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -601,8 +601,15 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html if "toolResult" in content: tool_result = content["toolResult"] + # Normalize empty toolResult content arrays. + # Some model providers (e.g., Nemotron) reject toolResult blocks with + # content: [] via the Converse API, while others (e.g., Claude) accept + # them. Replace empty content with a minimal text block to ensure + # cross-model compatibility. This follows the same pattern as the + # TypeScript SDK's _formatMessages in bedrock.ts. 
+ tool_result_content_list = tool_result.get("content") or [{"text": ""}] formatted_content: list[dict[str, Any]] = [] - for tool_result_content in tool_result["content"]: + for tool_result_content in tool_result_content_list: if "json" in tool_result_content: # Handle json field since not in ContentBlock but valid in ToolResultContent formatted_content.append({"json": tool_result_content["json"]}) diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index cd7016488..05f0fa92f 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1607,6 +1607,87 @@ def test_format_request_cleans_tool_result_content_blocks(model, model_id): assert "status" not in tool_result +def test_format_request_message_content_normalizes_empty_tool_result_content(model, model_id): + """Test that _format_request_message_content replaces empty toolResult content with a minimal text block. + + Some model providers (e.g., Nemotron) reject toolResult blocks with content: [] via the + Converse API, while others (e.g., Claude) accept them. The SDK should normalize empty + content arrays to ensure cross-model compatibility. 
+ + See: https://github.com/strands-agents/sdk-python/issues/2122 + """ + messages = [ + {"role": "user", "content": [{"text": "List tables"}]}, + { + "role": "assistant", + "content": [ + {"text": "Querying...\n"}, + {"toolUse": {"toolUseId": "tool_001", "name": "run_query", "input": {"sql": "SELECT 1"}}}, + ], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "tool_001", "content": []}}, + ], + }, + ] + + formatted_request = model._format_request(messages) + + tool_result = formatted_request["messages"][2]["content"][0]["toolResult"] + assert tool_result["content"] == [{"text": ""}], "Empty toolResult content should be normalized to [{'text': ''}]" + + +def test_format_request_message_content_does_not_mutate_empty_tool_result(model, model_id): + """Test that normalizing empty toolResult content does not mutate the original messages.""" + messages = [ + {"role": "user", "content": [{"text": "List tables"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "tool_001", "name": "run_query", "input": {"sql": "SELECT 1"}}}, + ], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "tool_001", "content": []}}, + ], + }, + ] + + original_content = messages[2]["content"][0]["toolResult"]["content"] + model._format_request(messages) + + assert original_content == [], "Original empty content list should not be mutated" + + +def test_format_request_message_content_preserves_nonempty_tool_result_content(model, model_id): + """Test that _format_request_message_content does not modify non-empty toolResult content.""" + messages = [ + {"role": "user", "content": [{"text": "List tables"}]}, + { + "role": "assistant", + "content": [ + {"text": "Querying...\n"}, + {"toolUse": {"toolUseId": "tool_001", "name": "run_query", "input": {"sql": "SELECT 1"}}}, + ], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "tool_001", "content": [{"text": "some result"}]}}, + ], + }, + ] + + formatted_request = 
model._format_request(messages) + + tool_result = formatted_request["messages"][2]["content"][0]["toolResult"] + assert tool_result["content"] == [{"text": "some result"}] + + def test_format_request_removes_status_field_when_configured(model, model_id): model.update_config(include_tool_result_status=False) From 4e3ad44195316f8e286deb07f9d30432740e2089 Mon Sep 17 00:00:00 2001 From: poshinchen Date: Fri, 17 Apr 2026 15:13:56 -0400 Subject: [PATCH 232/279] fix(telemetry): remove force_flush in tracer (#2142) --- src/strands/telemetry/tracer.py | 6 ------ tests/strands/telemetry/test_tracer.py | 15 --------------- 2 files changed, 21 deletions(-) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 37c16d3ae..a422d3cbf 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -218,12 +218,6 @@ def _end_span( logger.warning("error=<%s> | error while ending span", e, exc_info=True) finally: span.end() - # Force flush to ensure spans are exported - if self.tracer_provider and hasattr(self.tracer_provider, "force_flush"): - try: - self.tracer_provider.force_flush() - except Exception as e: - logger.warning("error=<%s> | failed to force flush tracer provider", e) def end_span_with_error(self, span: Span, error_message: str, exception: Exception | None = None) -> None: """End a span with error status. 
diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 6b622bb3e..8af7b782e 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -1221,21 +1221,6 @@ def test_end_span_with_exception_handling(mock_span): pytest.fail("_end_span should not raise exceptions") -def test_force_flush_with_error(mock_span, mock_get_tracer_provider): - """Test force flush with error handling.""" - # Setup the tracer with a provider that raises an exception on force_flush - tracer = Tracer() - - mock_tracer_provider = mock_get_tracer_provider.return_value - mock_tracer_provider.force_flush.side_effect = Exception("Force flush error") - - # Should not raise an exception - tracer._end_span(mock_span) - - # Verify force_flush was called - mock_tracer_provider.force_flush.assert_called_once() - - def test_end_tool_call_span_with_none(mock_span): """Test ending a tool call span with None result.""" tracer = Tracer() From 7b0337bc63dd945a9ae1ade15556a004992c11fd Mon Sep 17 00:00:00 2001 From: lufecadu Date: Tue, 21 Apr 2026 12:35:22 -0700 Subject: [PATCH 233/279] fix: add fallback trim point for tool-heavy conversations in SlidingWindowConversationManager (#2174) --- .../sliding_window_conversation_manager.py | 42 +++++++++++--- .../agent/test_conversation_manager.py | 58 +++++++++++++++++++ 2 files changed, 91 insertions(+), 9 deletions(-) diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 94446380b..f91d7a538 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -192,9 +192,24 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A # 1. Starts with a user message (required by most model providers) # 2. 
Does not start with an orphaned toolResult # 3. Does not start with a toolUse unless its toolResult immediately follows + # Falls back to an assistant(toolUse) + user(toolResult) boundary if no plain user message exists. + # This is acceptable because providers treat a complete toolUse/toolResult pair as a valid + # conversation continuation, and without this fallback tool-heavy conversations cannot be trimmed. + fallback_trim_index = None + while trim_index < len(messages): - # Must start with a user message + # Prefer starting with a user message if messages[trim_index]["role"] != "user": + # Track first valid assistant(toolUse) + user(toolResult) pair as fallback + if ( + fallback_trim_index is None + and any("toolUse" in content for content in messages[trim_index]["content"]) + and trim_index + 1 < len(messages) + and messages[trim_index + 1]["role"] == "user" + and any("toolResult" in content for content in messages[trim_index + 1]["content"]) + ): + fallback_trim_index = trim_index + trim_index += 1 continue @@ -216,15 +231,24 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A else: break else: - # If we didn't find a valid trim_index - if e is not None: + # No plain user message found — use assistant+toolResult fallback if available + if fallback_trim_index is not None: + logger.debug( + "trim_index=<%s> | no plain user message trim point found, " + "falling back to assistant(toolUse) + user(toolResult) boundary", + fallback_trim_index, + ) + trim_index = fallback_trim_index + elif e is not None: raise ContextWindowOverflowException("Unable to trim conversation context!") from e - logger.warning( - "window_size=<%s>, message_count=<%s> | unable to trim conversation context, no valid trim point found", - self.window_size, - len(messages), - ) - return + else: + logger.warning( + "window_size=<%s>, message_count=<%s> | unable to trim conversation context, " + "no valid trim point found", + self.window_size, + len(messages), + ) 
+ return # trim_index represents the number of messages being removed from the agents messages array self.removed_message_count += trim_index diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index 6db9897f1..c8b9df1cf 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -230,6 +230,64 @@ def test_sliding_window_no_valid_trim_point_without_error_does_not_raise(): assert messages == original_messages +def test_sliding_window_tool_heavy_conversation_falls_back_to_tool_pair_boundary(): + """Tool-heavy conversations trim to assistant(toolUse) + user(toolResult) boundary.""" + manager = SlidingWindowConversationManager(window_size=4, should_truncate_results=False) + messages = [ + {"role": "user", "content": [{"text": "Review this PR"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "get_diff", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "1", "content": [{"text": "diff"}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "2", "name": "get_file", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "2", "content": [{"text": "file"}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "3", "name": "get_tree", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "3", "content": [{"text": "tree"}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"text": "Here is my review"}]}, + ] + test_agent = Agent(messages=messages) + + manager.reduce_context(test_agent, e=Exception("context window overflow")) + + # Should trim to first assistant(toolUse) + user(toolResult) pair after trim_index + # With 8 messages and window_size=4, trim_index starts at 4. First fallback at index 5 (toolUseId "3"). 
+ assert len(messages) == 3 + assert messages[0]["role"] == "assistant" + assert messages[0]["content"][0]["toolUse"]["toolUseId"] == "3" + assert messages[1]["role"] == "user" + assert any("toolResult" in content for content in messages[1]["content"]) + + +def test_sliding_window_prefers_plain_user_message_over_tool_pair_fallback(): + """Plain user messages are preferred over assistant+toolResult fallback when both exist.""" + manager = SlidingWindowConversationManager(window_size=2, should_truncate_results=False) + messages = [ + {"role": "user", "content": [{"text": "First"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "1", "content": [{"text": "result"}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"text": "Response"}]}, + {"role": "user", "content": [{"text": "Plain user message"}]}, + {"role": "assistant", "content": [{"text": "Final response"}]}, + ] + test_agent = Agent(messages=messages) + + manager.apply_management(test_agent) + + # Should prefer the plain user message, not the assistant+toolResult fallback + assert messages[0]["role"] == "user" + assert messages[0]["content"] == [{"text": "Plain user message"}] + + def test_sliding_window_conversation_manager_with_tool_results_truncated(): large_text = "A" * 300 + "B" * 300 + "C" * 300 manager = SlidingWindowConversationManager(1) From 724b59138dcad67aaf41fe9ab301b26575219aea Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Tue, 21 Apr 2026 17:06:39 -0400 Subject: [PATCH 234/279] feat: introduce checkpoint in experimental (#2181) --- AGENTS.md | 2 + src/strands/experimental/__init__.py | 4 +- .../experimental/checkpoint/__init__.py | 12 +++ .../experimental/checkpoint/checkpoint.py | 94 +++++++++++++++++++ src/strands/types/event_loop.py | 2 + .../experimental/checkpoint/__init__.py | 0 
.../checkpoint/test_checkpoint.py | 53 +++++++++++ 7 files changed, 165 insertions(+), 2 deletions(-) create mode 100644 src/strands/experimental/checkpoint/__init__.py create mode 100644 src/strands/experimental/checkpoint/checkpoint.py create mode 100644 tests/strands/experimental/checkpoint/__init__.py create mode 100644 tests/strands/experimental/checkpoint/test_checkpoint.py diff --git a/AGENTS.md b/AGENTS.md index 3615e713a..8835b45c8 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -152,6 +152,8 @@ strands-agents/ │ │ │ ├── tools/ # Bidi tools │ │ │ ├── types/ # Bidi types │ │ │ └── _async/ # Async utilities +│ │ ├── checkpoint/ # Durable agent execution checkpoints +│ │ │ └── checkpoint.py # Checkpoint dataclass and serialization │ │ ├── hooks/ # Experimental hooks │ │ │ ├── events.py │ │ │ └── multiagent/ diff --git a/src/strands/experimental/__init__.py b/src/strands/experimental/__init__.py index 3c1d0ee46..cbd9a713e 100644 --- a/src/strands/experimental/__init__.py +++ b/src/strands/experimental/__init__.py @@ -3,7 +3,7 @@ This module implements experimental features that are subject to change in future revisions without notice. """ -from . import steering, tools +from . import checkpoint, steering, tools from .agent_config import config_to_agent -__all__ = ["config_to_agent", "tools", "steering"] +__all__ = ["checkpoint", "config_to_agent", "tools", "steering"] diff --git a/src/strands/experimental/checkpoint/__init__.py b/src/strands/experimental/checkpoint/__init__.py new file mode 100644 index 000000000..848cda6d6 --- /dev/null +++ b/src/strands/experimental/checkpoint/__init__.py @@ -0,0 +1,12 @@ +"""Experimental checkpoint types for durable agent execution. + +This module is experimental and subject to change in future revisions without notice. + +Checkpoints enable crash-resilient agent workflows by capturing agent state at +cycle boundaries in the agent loop. A durability provider (e.g. 
Temporal) can +persist checkpoints and resume from them after failures. +""" + +from .checkpoint import CHECKPOINT_SCHEMA_VERSION, Checkpoint, CheckpointPosition + +__all__ = ["CHECKPOINT_SCHEMA_VERSION", "Checkpoint", "CheckpointPosition"] diff --git a/src/strands/experimental/checkpoint/checkpoint.py b/src/strands/experimental/checkpoint/checkpoint.py new file mode 100644 index 000000000..f37e403c9 --- /dev/null +++ b/src/strands/experimental/checkpoint/checkpoint.py @@ -0,0 +1,94 @@ +"""Checkpoint system for durable agent execution. + +Checkpoints enable crash-resilient agent workflows by capturing agent state at +cycle boundaries in the agent loop. A durability provider (e.g. Temporal) can +persist checkpoints and resume from them after failures. + +Two checkpoint positions per ReAct cycle: +- after_model: model call completed, tools not yet executed. +- after_tools: all tools executed, next model call pending. + +Per-tool granularity is handled by the ToolExecutor abstraction (e.g. +TemporalToolExecutor routes each tool to a separate Temporal activity). +The SDK checkpoint operates at cycle boundaries. + +User-facing pattern (same as interrupts): +- Pause via stop_reason="checkpoint" on AgentResult +- State via AgentResult.checkpoint field +- Resume via checkpointResume content block in next agent() call + +V0 Known Limitations: +- Metrics reset on each resume call. The caller is responsible for aggregating + metrics across a durable run. EventLoopMetrics reflects only the current call. +- OpenAIResponsesModel(stateful=True) is not supported. The server-side + response_id (_model_state) is not captured in the snapshot. +- When position is "after_tools", AgentResult.message is the assistant message + that requested the tools; tool results are in the snapshot messages. +- BeforeInvocationEvent and AfterInvocationEvent fire on every resume call, + same as interrupts. Hooks counting invocations will see each resume as a + separate invocation. 
+- Per-tool granularity within a cycle requires a custom ToolExecutor + (e.g. TemporalToolExecutor). +""" + +import logging +from dataclasses import asdict, dataclass, field +from typing import Any, Literal + +logger = logging.getLogger(__name__) + +CHECKPOINT_SCHEMA_VERSION = "1.0" + +CheckpointPosition = Literal["after_model", "after_tools"] + + +@dataclass +class Checkpoint: + """Pause point in the agent loop. Treat as opaque — pass back to resume. + + Attributes: + position: What just completed (after_model or after_tools). + cycle_index: Which ReAct loop cycle (0-based). + snapshot: Serialized agent state as a dict, produced by ``Snapshot.to_dict()``. + Stored as ``dict[str, Any]`` (not a ``Snapshot`` object) because checkpoints + must be JSON-serializable for cross-process persistence. The consumer + reconstructs via ``Snapshot.from_dict()`` on resume. + app_data: Application-level internal state data. The SDK does not read + or modify this. Applications can store arbitrary data needed across + checkpoint boundaries (e.g. session context, workflow metadata). + Separate from ``Snapshot.app_data`` which captures agent-state-level + data managed by the SDK. + schema_version: Rejects mismatches on resume across schema versions. + """ + + position: CheckpointPosition + cycle_index: int = 0 + snapshot: dict[str, Any] = field(default_factory=dict) + app_data: dict[str, Any] = field(default_factory=dict) + schema_version: str = field(init=False, default=CHECKPOINT_SCHEMA_VERSION) + + def to_dict(self) -> dict[str, Any]: + """Serialize for persistence.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "Checkpoint": + """Reconstruct from a dict produced by to_dict(). + + Args: + data: Serialized checkpoint data. + + Raises: + ValueError: If schema_version doesn't match the current version. 
+ """ + version = data.get("schema_version", "") + if version != CHECKPOINT_SCHEMA_VERSION: + raise ValueError( + f"Checkpoints with schema version {version!r} are not compatible " + f"with current version {CHECKPOINT_SCHEMA_VERSION}." + ) + known_keys = {k for k in cls.__dataclass_fields__ if k != "schema_version"} + unknown_keys = set(data.keys()) - known_keys - {"schema_version"} + if unknown_keys: + logger.warning("unknown_keys=<%s> | ignoring unknown fields in checkpoint data", unknown_keys) + return cls(**{k: v for k, v in data.items() if k in known_keys}) diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index fca141327..73d4e2bc0 100644 --- a/src/strands/types/event_loop.py +++ b/src/strands/types/event_loop.py @@ -38,6 +38,7 @@ class Metrics(TypedDict, total=False): StopReason = Literal[ "cancelled", + "checkpoint", "content_filtered", "end_turn", "guardrail_intervened", @@ -49,6 +50,7 @@ class Metrics(TypedDict, total=False): """Reason for the model ending its response generation. 
- "cancelled": Agent execution was cancelled via agent.cancel() +- "checkpoint": Agent paused for durable checkpoint persistence - "content_filtered": Content was filtered due to policy violation - "end_turn": Normal completion of the response - "guardrail_intervened": Guardrail system intervened diff --git a/tests/strands/experimental/checkpoint/__init__.py b/tests/strands/experimental/checkpoint/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/checkpoint/test_checkpoint.py b/tests/strands/experimental/checkpoint/test_checkpoint.py new file mode 100644 index 000000000..4435fb3db --- /dev/null +++ b/tests/strands/experimental/checkpoint/test_checkpoint.py @@ -0,0 +1,53 @@ +"""Tests for strands.experimental.checkpoint — Checkpoint serialization.""" + +import pytest + +from strands.experimental.checkpoint import CHECKPOINT_SCHEMA_VERSION, Checkpoint + + +class TestCheckpoint: + """Checkpoint dataclass serialization tests.""" + + def test_round_trip(self): + checkpoint = Checkpoint( + position="after_model", + cycle_index=1, + snapshot={"messages": []}, + app_data={"workflow_id": "wf-123"}, + ) + data = checkpoint.to_dict() + restored = Checkpoint.from_dict(data) + + assert restored.position == checkpoint.position + assert restored.cycle_index == checkpoint.cycle_index + assert restored.snapshot == checkpoint.snapshot + assert restored.app_data == checkpoint.app_data + assert restored.schema_version == CHECKPOINT_SCHEMA_VERSION + + def test_schema_version_immutable(self): + checkpoint = Checkpoint(position="after_tools") + assert checkpoint.schema_version == CHECKPOINT_SCHEMA_VERSION + + def test_schema_version_mismatch_raises(self): + data = Checkpoint(position="after_model").to_dict() + data["schema_version"] = "0.0" + with pytest.raises(ValueError, match="not compatible with current version"): + Checkpoint.from_dict(data) + + def test_defaults(self): + checkpoint = Checkpoint(position="after_model") + assert 
checkpoint.cycle_index == 0 + assert checkpoint.snapshot == {} + assert checkpoint.app_data == {} + + def test_from_dict_warns_on_unknown_fields(self, caplog): + data = Checkpoint(position="after_tools").to_dict() + data["unknown_future_field"] = "something" + restored = Checkpoint.from_dict(data) + assert restored.position == "after_tools" + assert "unknown_future_field" in caplog.text + + def test_from_dict_missing_schema_version_raises(self): + data = {"position": "after_model", "cycle_index": 0, "snapshot": {}, "app_data": {}} + with pytest.raises(ValueError, match="not compatible with current version"): + Checkpoint.from_dict(data) From c723e5287621fcf3fd2a4bdb178df8bc9e12d3ae Mon Sep 17 00:00:00 2001 From: opieter-aws Date: Tue, 21 Apr 2026 17:12:47 -0400 Subject: [PATCH 235/279] feat: add context_window_limit to model configs (#2176) --- src/strands/models/__init__.py | 3 ++- src/strands/models/anthropic.py | 6 +++--- src/strands/models/bedrock.py | 6 +++--- src/strands/models/gemini.py | 6 +++--- src/strands/models/litellm.py | 5 +++-- src/strands/models/llamaapi.py | 6 +++--- src/strands/models/llamacpp.py | 5 ++--- src/strands/models/mistral.py | 6 +++--- src/strands/models/model.py | 23 ++++++++++++++++++++++- src/strands/models/ollama.py | 6 +++--- src/strands/models/openai.py | 6 +++--- src/strands/models/openai_responses.py | 4 ++-- src/strands/models/sagemaker.py | 3 ++- src/strands/models/writer.py | 6 +++--- tests/strands/models/test_bedrock.py | 9 +++++++++ tests/strands/models/test_gemini.py | 9 +++++++++ tests/strands/models/test_model.py | 21 +++++++++++++++++++++ tests/strands/models/test_openai.py | 9 +++++++++ 18 files changed, 105 insertions(+), 34 deletions(-) diff --git a/src/strands/models/__init__.py b/src/strands/models/__init__.py index 2c582d116..3a23e257a 100644 --- a/src/strands/models/__init__.py +++ b/src/strands/models/__init__.py @@ -7,11 +7,12 @@ from . 
import bedrock, model from .bedrock import BedrockModel -from .model import CacheConfig, Model +from .model import BaseModelConfig, CacheConfig, Model __all__ = [ "bedrock", "model", + "BaseModelConfig", "BedrockModel", "CacheConfig", "Model", diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 818a8f14c..526460184 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -8,7 +8,7 @@ import logging import mimetypes from collections.abc import AsyncGenerator -from typing import Any, TypedDict, TypeVar, cast +from typing import Any, TypeVar, cast import anthropic from pydantic import BaseModel @@ -21,7 +21,7 @@ from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec from ._validation import _has_location_source, validate_config_keys -from .model import Model +from .model import BaseModelConfig, Model logger = logging.getLogger(__name__) @@ -46,7 +46,7 @@ class AnthropicModel(Model): "input and output tokens exceed your context limit", } - class AnthropicConfig(TypedDict, total=False): + class AnthropicConfig(BaseModelConfig, total=False): """Configuration options for Anthropic models. 
Attributes: diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 742cd82e9..a4c4aaba0 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -15,7 +15,7 @@ from botocore.config import Config as BotocoreConfig from botocore.exceptions import ClientError from pydantic import BaseModel -from typing_extensions import TypedDict, Unpack, override +from typing_extensions import Unpack, override from strands.types.media import S3Location, SourceLocation @@ -31,7 +31,7 @@ from ..types.streaming import CitationsDelta, StreamEvent from ..types.tools import ToolChoice, ToolSpec from ._validation import validate_config_keys -from .model import CacheConfig, Model +from .model import BaseModelConfig, CacheConfig, Model logger = logging.getLogger(__name__) @@ -69,7 +69,7 @@ class BedrockModel(Model): - Context window overflow detection """ - class BedrockConfig(TypedDict, total=False): + class BedrockConfig(BaseModelConfig, total=False): """Configuration options for Bedrock models. Attributes: diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index c94570293..81c8bd76f 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -9,7 +9,7 @@ import mimetypes import secrets from collections.abc import AsyncGenerator -from typing import Any, TypedDict, TypeVar, cast +from typing import Any, TypeVar, cast import pydantic from google import genai @@ -20,7 +20,7 @@ from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec from ._validation import _has_location_source, validate_config_keys -from .model import Model +from .model import BaseModelConfig, Model logger = logging.getLogger(__name__) @@ -33,7 +33,7 @@ class GeminiModel(Model): - Docs: https://ai.google.dev/api """ - class GeminiConfig(TypedDict, total=False): + class GeminiConfig(BaseModelConfig, total=False): """Configuration options for Gemini models. 
Attributes: diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 36bdb5a05..04e39a66f 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -6,7 +6,7 @@ import json import logging from collections.abc import AsyncGenerator -from typing import Any, TypedDict, TypeVar, cast +from typing import Any, TypeVar, cast import litellm from litellm.exceptions import ContextWindowExceededError @@ -21,6 +21,7 @@ from ..types.streaming import MetadataEvent, StreamEvent from ..types.tools import ToolChoice, ToolSpec, ToolUse from ._validation import validate_config_keys +from .model import BaseModelConfig from .openai import OpenAIModel logger = logging.getLogger(__name__) @@ -35,7 +36,7 @@ class LiteLLMModel(OpenAIModel): """LiteLLM model provider implementation.""" - class LiteLLMConfig(TypedDict, total=False): + class LiteLLMConfig(BaseModelConfig, total=False): """Configuration options for LiteLLM models. Attributes: diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index b1ed4563a..71db9b78d 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -14,14 +14,14 @@ import llama_api_client from llama_api_client import LlamaAPIClient from pydantic import BaseModel -from typing_extensions import TypedDict, Unpack, override +from typing_extensions import Unpack, override from ..types.content import ContentBlock, Messages from ..types.exceptions import ModelThrottledException from ..types.streaming import StreamEvent, Usage from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported -from .model import Model +from .model import BaseModelConfig, Model logger = logging.getLogger(__name__) @@ -31,7 +31,7 @@ class LlamaAPIModel(Model): """Llama API model provider implementation.""" - class LlamaConfig(TypedDict, total=False): + class LlamaConfig(BaseModelConfig, 
total=False): """Configuration options for Llama API models. Attributes: diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index c52509816..d689e65ea 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -17,7 +17,6 @@ from collections.abc import AsyncGenerator from typing import ( Any, - TypedDict, TypeVar, cast, ) @@ -31,7 +30,7 @@ from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported -from .model import Model +from .model import BaseModelConfig, Model logger = logging.getLogger(__name__) @@ -86,7 +85,7 @@ class LlamaCppModel(Model): >>> response = agent(image_content) """ - class LlamaCppConfig(TypedDict, total=False): + class LlamaCppConfig(BaseModelConfig, total=False): """Configuration options for llama.cpp models. Attributes: diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index f44a11d30..c4a23b244 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -11,14 +11,14 @@ import mistralai from pydantic import BaseModel -from typing_extensions import TypedDict, Unpack, override +from typing_extensions import Unpack, override from ..types.content import ContentBlock, Messages from ..types.exceptions import ModelThrottledException from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported -from .model import Model +from .model import BaseModelConfig, Model logger = logging.getLogger(__name__) @@ -36,7 +36,7 @@ class MistralModel(Model): - System prompts """ - class MistralConfig(TypedDict, total=False): + class MistralConfig(BaseModelConfig, total=False): """Configuration parameters for Mistral models. 
Attributes: diff --git a/src/strands/models/model.py b/src/strands/models/model.py index f084d24d5..438d6a7ba 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -4,7 +4,7 @@ import logging from collections.abc import AsyncGenerator, AsyncIterable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal, TypeVar +from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar from pydantic import BaseModel @@ -22,6 +22,17 @@ T = TypeVar("T", bound=BaseModel) +class BaseModelConfig(TypedDict, total=False): + """Base configuration shared by all model providers. + + Attributes: + context_window_limit: Maximum context window size in tokens for the model. + This value represents the total token capacity shared between input and output. + """ + + context_window_limit: int | None + + @dataclass class CacheConfig: """Configuration for prompt caching. @@ -51,6 +62,16 @@ def stateful(self) -> bool: """ return False + @property + def context_window_limit(self) -> int | None: + """Maximum context window size in tokens, or None if not configured.""" + config = self.get_config() + return ( + config.get("context_window_limit") + if isinstance(config, dict) + else getattr(config, "context_window_limit", None) + ) + @abc.abstractmethod # pragma: no cover def update_config(self, **model_config: Any) -> None: diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 97cb7948a..41907e2e0 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -10,13 +10,13 @@ import ollama from pydantic import BaseModel -from typing_extensions import TypedDict, Unpack, override +from typing_extensions import Unpack, override from ..types.content import ContentBlock, Messages from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolChoice, ToolSpec from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported -from .model import Model 
+from .model import BaseModelConfig, Model logger = logging.getLogger(__name__) @@ -33,7 +33,7 @@ class OllamaModel(Model): - Tool/function calling """ - class OllamaConfig(TypedDict, total=False): + class OllamaConfig(BaseModelConfig, total=False): """Configuration parameters for Ollama models. Attributes: diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 333f59c71..c4be7d360 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -9,7 +9,7 @@ import mimetypes from collections.abc import AsyncGenerator, AsyncIterator from contextlib import asynccontextmanager -from typing import Any, Protocol, TypedDict, TypeVar, cast +from typing import Any, Protocol, TypeVar, cast import openai from openai.types.chat.parsed_chat_completion import ParsedChatCompletion @@ -22,7 +22,7 @@ from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse from ._validation import _has_location_source, validate_config_keys -from .model import Model +from .model import BaseModelConfig, Model logger = logging.getLogger(__name__) @@ -53,7 +53,7 @@ class OpenAIModel(Model): client: Client - class OpenAIConfig(TypedDict, total=False): + class OpenAIConfig(BaseModelConfig, total=False): """Configuration options for OpenAI models. 
Attributes: diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index 30e4e2fa1..f845c2688 100644 --- a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -59,7 +59,7 @@ from ..types.streaming import StreamEvent # noqa: E402 from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse # noqa: E402 from ._validation import validate_config_keys # noqa: E402 -from .model import Model # noqa: E402 +from .model import BaseModelConfig, Model # noqa: E402 logger = logging.getLogger(__name__) @@ -122,7 +122,7 @@ class OpenAIResponsesModel(Model): client: Client client_args: dict[str, Any] - class OpenAIResponsesConfig(TypedDict, total=False): + class OpenAIResponsesConfig(BaseModelConfig, total=False): """Configuration options for OpenAI Responses API models. Attributes: diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 775969290..0d206fd0b 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -17,6 +17,7 @@ from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec from ._validation import validate_config_keys, warn_on_tool_choice_not_supported +from .model import BaseModelConfig from .openai import OpenAIModel T = TypeVar("T", bound=BaseModel) @@ -116,7 +117,7 @@ class SageMakerAIPayloadSchema(TypedDict, total=False): tool_results_as_user_messages: bool | None additional_args: dict[str, Any] | None - class SageMakerAIEndpointConfig(TypedDict, total=False): + class SageMakerAIEndpointConfig(BaseModelConfig, total=False): """Configuration options for SageMaker models. 
Attributes: diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py index 94774b363..3e3276106 100644 --- a/src/strands/models/writer.py +++ b/src/strands/models/writer.py @@ -8,7 +8,7 @@ import logging import mimetypes from collections.abc import AsyncGenerator -from typing import Any, TypedDict, TypeVar, cast +from typing import Any, TypeVar, cast import writerai from pydantic import BaseModel @@ -19,7 +19,7 @@ from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported -from .model import Model +from .model import BaseModelConfig, Model logger = logging.getLogger(__name__) @@ -29,7 +29,7 @@ class WriterModel(Model): """Writer API model provider implementation.""" - class WriterConfig(TypedDict, total=False): + class WriterConfig(BaseModelConfig, total=False): """Configuration options for Writer API. Attributes: diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 05f0fa92f..a688a9962 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -288,6 +288,15 @@ def test__init__model_config(bedrock_client): assert tru_max_tokens == exp_max_tokens +def test__init__context_window_limit(bedrock_client): + _ = bedrock_client + + model = BedrockModel(context_window_limit=200_000) + + assert model.get_config().get("context_window_limit") == 200_000 + assert model.context_window_limit == 200_000 + + def test_update_config(model, model_id): model.update_config(model_id=model_id) diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index ba4b2b53f..361508327 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -70,6 +70,15 @@ def test__init__model_configs(gemini_client, model_id): assert tru_temperature == exp_temperature +def 
test__init__context_window_limit(gemini_client): + _ = gemini_client + + model = GeminiModel(model_id="gemini-2.5-flash", context_window_limit=1_048_576) + + assert model.get_config().get("context_window_limit") == 1_048_576 + assert model.context_window_limit == 1_048_576 + + def test_update_config(model, model_id): model.update_config(model_id=model_id) diff --git a/tests/strands/models/test_model.py b/tests/strands/models/test_model.py index 458e98645..97010a722 100644 --- a/tests/strands/models/test_model.py +++ b/tests/strands/models/test_model.py @@ -184,6 +184,27 @@ async def stream(self, messages, tool_specs=None, system_prompt=None, *, tool_ch assert events[1]["contentBlockDelta"]["delta"]["text"] == "No tool choice" +def test_context_window_limit_from_dict_config(): + class DictConfigModel(SAModel): + def update_config(self, **model_config): + pass + + def get_config(self): + return {"context_window_limit": 200_000} + + async def structured_output(self, output_model, prompt=None, system_prompt=None, **kwargs): + yield {} + + async def stream(self, messages, tool_specs=None, system_prompt=None): + yield {} + + assert DictConfigModel().context_window_limit == 200_000 + + +def test_context_window_limit_none_when_not_configured(model): + assert model.context_window_limit is None + + def test_stateful_false(model): """Model.stateful defaults to False.""" assert not model.stateful diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 7af39032c..94e4caa3f 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -89,6 +89,15 @@ def test_update_config(model, model_id): assert tru_model_id == exp_model_id +def test__init__context_window_limit(openai_client): + _ = openai_client + + model = OpenAIModel(model_id="gpt-4o", context_window_limit=128_000) + + assert model.get_config().get("context_window_limit") == 128_000 + assert model.context_window_limit == 128_000 + + 
@pytest.mark.parametrize( "content, exp_result", [ From 255b7674f012a55a6c4b236a1e256668696c73a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Minoru=20Onda=EF=BC=88=E3=81=BF=E3=81=AE=E3=82=8B=E3=82=93?= =?UTF-8?q?=EF=BC=89?= <74597894+minorun365@users.noreply.github.com> Date: Wed, 22 Apr 2026 12:39:20 +0900 Subject: [PATCH 236/279] fix(mcp): skip MCPClient cleanup during interpreter finalization (#2144) Co-authored-by: minorun365 --- src/strands/tools/mcp/mcp_client.py | 10 ++++++++ tests/strands/tools/mcp/test_mcp_client.py | 27 ++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index e81dc7130..3b0d656f9 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -12,6 +12,7 @@ import contextvars import json import logging +import sys import threading import uuid from asyncio import AbstractEventLoop @@ -343,6 +344,15 @@ def stop(self, exc_type: BaseException | None, exc_val: BaseException | None, ex """ self._log_debug_with_thread("exiting MCPClient context") + # Skip cleanup during interpreter finalization. On Python 3.14+, joining a + # non-daemon thread at shutdown raises PythonFinalizationError; even though + # our background thread is a daemon and will be reclaimed automatically, + # the join call itself produces noisy tracebacks on stderr when the GC + # reaches Agent.__del__ during finalization. See issue #2143. 
+ if sys.is_finalizing(): + self._log_debug_with_thread("interpreter is finalizing, skipping MCPClient cleanup") + return + # Only try to signal close future if we have a background thread if self._background_thread is not None: # Signal close future if event loop exists diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index bf0e7ce8e..f1bf7dd73 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -594,6 +594,33 @@ def test_stop_closes_event_loop(): assert client._background_thread_event_loop is None +def test_stop_skips_cleanup_during_interpreter_finalization(): + """Test that stop() is a no-op when the interpreter is finalizing. + + On Python 3.14+, threading.Thread.join() raises PythonFinalizationError at + shutdown. The background thread is a daemon and is reclaimed automatically, + so stop() should skip join() and event loop cleanup to avoid noisy + tracebacks surfaced via Agent.__del__ during GC. See issue #2143. + """ + client = MCPClient(MagicMock()) + + mock_thread = MagicMock() + mock_event_loop = MagicMock() + client._background_thread = mock_thread + client._background_thread_event_loop = mock_event_loop + + with patch("strands.tools.mcp.mcp_client.sys.is_finalizing", return_value=True): + # Must not raise, and must not touch the thread or event loop. + client.stop(None, None, None) + + mock_thread.join.assert_not_called() + mock_event_loop.close.assert_not_called() + # State is intentionally left alone during finalization — the interpreter + # is going away and cleanup is unnecessary. 
+ assert client._background_thread is mock_thread + assert client._background_thread_event_loop is mock_event_loop + + def test_mcp_client_state_reset_after_timeout(): """Test that all client state is properly reset after timeout.""" From 50439e01514c9a8bf59ca041a2699367a0263a17 Mon Sep 17 00:00:00 2001 From: afarntrog <47332252+afarntrog@users.noreply.github.com> Date: Wed, 22 Apr 2026 12:11:56 -0400 Subject: [PATCH 237/279] fix(tests): update retired claude-3-haiku model in integration tests (#2186) --- .../test_summarizing_conversation_manager_integration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests_integ/test_summarizing_conversation_manager_integration.py b/tests_integ/test_summarizing_conversation_manager_integration.py index d6508edce..b6ba8b854 100644 --- a/tests_integ/test_summarizing_conversation_manager_integration.py +++ b/tests_integ/test_summarizing_conversation_manager_integration.py @@ -34,7 +34,7 @@ def model(): client_args={ "api_key": os.getenv("ANTHROPIC_API_KEY"), }, - model_id="claude-3-haiku-20240307", # Using Haiku for faster/cheaper tests + model_id="claude-haiku-4-5-20251001", # Using Haiku for faster/cheaper tests max_tokens=1024, ) @@ -46,7 +46,7 @@ def summarization_model(): client_args={ "api_key": os.getenv("ANTHROPIC_API_KEY"), }, - model_id="claude-3-haiku-20240307", + model_id="claude-haiku-4-5-20251001", max_tokens=512, ) From 3e08d5e52fd5a713ab3dbaa6c89de5ba86c8a5ef Mon Sep 17 00:00:00 2001 From: Zelys Date: Wed, 22 Apr 2026 14:21:19 -0500 Subject: [PATCH 238/279] feat(mcp): preserve CallToolResult.isError flag in MCPToolResult (#2118) --- src/strands/tools/mcp/mcp_client.py | 2 ++ src/strands/tools/mcp/mcp_types.py | 6 ++++++ tests/strands/tools/mcp/test_mcp_client.py | 13 +++++++++++++ 3 files changed, 21 insertions(+) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 3b0d656f9..2ac632925 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ 
b/src/strands/tools/mcp/mcp_client.py @@ -754,6 +754,8 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes result["structuredContent"] = call_tool_result.structuredContent if call_tool_result.meta: result["metadata"] = call_tool_result.meta + if call_tool_result.isError is not None: + result["isError"] = call_tool_result.isError return result diff --git a/src/strands/tools/mcp/mcp_types.py b/src/strands/tools/mcp/mcp_types.py index 8fbf573be..09feb624f 100644 --- a/src/strands/tools/mcp/mcp_types.py +++ b/src/strands/tools/mcp/mcp_types.py @@ -61,7 +61,13 @@ class MCPToolResult(ToolResult): metadata: Optional arbitrary metadata returned by the MCP tool. This field allows MCP servers to attach custom metadata to tool results (e.g., token usage, performance metrics, or business-specific tracking information). + isError: Whether the MCP tool reported an application-level error via + ``CallToolResult.isError``. ``True`` means the tool executed but its logic + returned a failure. Absent when the tool succeeded or when the error was a + protocol/client exception rather than a tool-reported failure, letting + callers distinguish application errors from transport/protocol errors. 
""" structuredContent: NotRequired[dict[str, Any]] metadata: NotRequired[dict[str, Any]] + isError: NotRequired[bool] diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index f1bf7dd73..fe439c5d9 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -132,6 +132,8 @@ def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_ assert result["content"][0]["text"] == "Test message" # No structured content should be present when not provided by MCP assert result.get("structuredContent") is None + # isError mirrors the MCP server's explicit value; absent only for protocol/client exceptions + assert result.get("isError") is is_error def test_call_tool_sync_session_not_active(): @@ -261,6 +263,8 @@ async def mock_awaitable(): assert result["toolUseId"] == "test-123" assert len(result["content"]) == 1 assert result["content"][0]["text"] == "Test message" + # isError mirrors the MCP server's explicit value; absent only for protocol/client exceptions + assert result.get("isError") is is_error @pytest.mark.asyncio @@ -408,6 +412,15 @@ def test_mcp_tool_result_type(): assert result_with_structured["structuredContent"] == {"key": "value"} + # isError is optional — absent by default + assert "isError" not in result + + # isError can be set to flag tool-reported application errors + result_with_is_error = MCPToolResult( + status="error", toolUseId="test-789", content=[{"text": "Tool failed"}], isError=True + ) + assert result_with_is_error["isError"] is True + def test_call_tool_sync_without_structured_content(mock_transport, mock_session): """Test that call_tool_sync works correctly when no structured content is provided.""" From 5a6df59502dc618781b85e80b01706a19cd45828 Mon Sep 17 00:00:00 2001 From: Liz <91279165+lizradway@users.noreply.github.com> Date: Wed, 22 Apr 2026 15:52:22 -0400 Subject: [PATCH 239/279] feat: add `count_token` method to model with 
naive estimation using tiktoken (#2031) Co-authored-by: opieter-aws --- src/strands/models/model.py | 196 ++++++++++++++- tests/strands/models/test_model.py | 370 +++++++++++++++++++++++++++++ 2 files changed, 564 insertions(+), 2 deletions(-) diff --git a/src/strands/models/model.py b/src/strands/models/model.py index 438d6a7ba..e5b15ebaa 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -1,8 +1,11 @@ """Abstract base class for Agent model providers.""" import abc +import functools +import json import logging -from collections.abc import AsyncGenerator, AsyncIterable +import math +from collections.abc import AsyncGenerator, AsyncIterable, Callable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar @@ -10,7 +13,7 @@ from ..hooks.events import AfterInvocationEvent from ..plugins.plugin import Plugin -from ..types.content import Messages, SystemContentBlock +from ..types.content import ContentBlock, Messages, SystemContentBlock from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec @@ -21,6 +24,164 @@ T = TypeVar("T", bound=BaseModel) +_DEFAULT_ENCODING = "cl100k_base" + + +def _heuristic_estimate_text(text: str) -> int: + """Estimate token count from text using characters / 4 heuristic.""" + return math.ceil(len(text) / 4) + + +def _heuristic_estimate_json(obj: Any) -> int: + """Estimate token count from a JSON-serializable object using characters / 2 heuristic.""" + try: + return math.ceil(len(json.dumps(obj)) / 2) + except (TypeError, ValueError): + return 0 + + +@functools.lru_cache(maxsize=1) +def _get_encoding() -> Any: + """Get the default tiktoken encoding, caching to avoid repeated lookups. + + Returns: + The tiktoken encoding, or None if tiktoken is not installed. 
+ """ + try: + import tiktoken + + return tiktoken.get_encoding(_DEFAULT_ENCODING) + except ImportError: + logger.debug("tiktoken not available, falling back to heuristic token estimation") + return None + + +def _count_content_block_tokens( + block: ContentBlock, count_text: Callable[[str], int], count_json: Callable[[Any], int] +) -> int: + """Count tokens for a single content block. + + Args: + block: The content block to count tokens for. + count_text: Function that returns token count for a text string. + count_json: Function that returns token count for a JSON-serializable object. + """ + total = 0 + + if "text" in block: + total += count_text(block["text"]) + + if "toolUse" in block: + tool_use = block["toolUse"] + total += count_text(tool_use.get("name", "")) + total += count_json(tool_use.get("input", {})) + + if "toolResult" in block: + tool_result = block["toolResult"] + for item in tool_result.get("content", []): + if "text" in item: + total += count_text(item["text"]) + + if "reasoningContent" in block: + reasoning = block["reasoningContent"] + if "reasoningText" in reasoning: + reasoning_text = reasoning["reasoningText"] + if "text" in reasoning_text: + total += count_text(reasoning_text["text"]) + + if "guardContent" in block: + guard = block["guardContent"] + if "text" in guard and "text" in guard["text"]: + total += count_text(guard["text"]["text"]) + + if "citationsContent" in block: + citations = block["citationsContent"] + if "content" in citations: + for citation_item in citations["content"]: + if "text" in citation_item: + total += count_text(citation_item["text"]) + + return total + + +def _estimate_tokens_with_tiktoken( + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, +) -> int: + """Estimate tokens by serializing messages/tools to text and counting with tiktoken. 
+ + This is a best-effort fallback for providers that don't expose native counting. + Accuracy varies by model but is sufficient for threshold-based decisions. + + Raises: + ImportError: If tiktoken is not installed. + """ + encoding = _get_encoding() + if encoding is None: + raise ImportError("tiktoken is not available") + + def count_text(text: str) -> int: + return len(encoding.encode(text)) + + def count_json(obj: Any) -> int: + try: + return len(encoding.encode(json.dumps(obj))) + except (TypeError, ValueError): + return 0 + + total = 0 + + # Prefer system_prompt_content (structured) over system_prompt (plain string) to avoid double-counting, + # since providers wrap system_prompt into system_prompt_content when both are provided. + if system_prompt_content: + for block in system_prompt_content: + if "text" in block: + total += count_text(block["text"]) + elif system_prompt: + total += count_text(system_prompt) + + for message in messages: + for block in message["content"]: + total += _count_content_block_tokens(block, count_text, count_json) + + if tool_specs: + for spec in tool_specs: + total += count_json(spec) + + return total + + +def _estimate_tokens_with_heuristic( + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, +) -> int: + """Estimate tokens using character-based heuristics (text: chars/4, JSON: chars/2). + + Dependency-free fallback when tiktoken is not installed. 
+ """ + total = 0 + + if system_prompt_content: + for block in system_prompt_content: + if "text" in block: + total += _heuristic_estimate_text(block["text"]) + elif system_prompt: + total += _heuristic_estimate_text(system_prompt) + + for message in messages: + for block in message["content"]: + total += _count_content_block_tokens(block, _heuristic_estimate_text, _heuristic_estimate_json) + + if tool_specs: + for spec in tool_specs: + total += _heuristic_estimate_json(spec) + + return total + class BaseModelConfig(TypedDict, total=False): """Base configuration shared by all model providers. @@ -151,6 +312,37 @@ def stream( """ pass + async def count_tokens( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + ) -> int: + """Estimate token count for the given input before sending to the model. + + Used for proactive context management (e.g., triggering compression at a threshold). + Uses tiktoken's cl100k_base encoding when available, otherwise falls back to a + heuristic (characters / 4 for text, characters / 2 for JSON). Accuracy varies by + model provider. Not intended for billing or precise quota calculations. + + Subclasses may override this method to provide model-specific token counting + using native APIs for improved accuracy. + + Args: + messages: List of message objects to estimate tokens for. + tool_specs: List of tool specifications to include in the estimate. + system_prompt: Plain string system prompt. Ignored if system_prompt_content is provided. + system_prompt_content: Structured system prompt content blocks. Takes priority over system_prompt. + + Returns: + Estimated total input tokens. 
+ """ + try: + return _estimate_tokens_with_tiktoken(messages, tool_specs, system_prompt, system_prompt_content) + except ImportError: + return _estimate_tokens_with_heuristic(messages, tool_specs, system_prompt, system_prompt_content) + class _ModelPlugin(Plugin): """Plugin that manages model-related lifecycle hooks.""" diff --git a/tests/strands/models/test_model.py b/tests/strands/models/test_model.py index 97010a722..11d4c10b9 100644 --- a/tests/strands/models/test_model.py +++ b/tests/strands/models/test_model.py @@ -234,3 +234,373 @@ def test_model_plugin_preserves_messages_when_not_stateful(model_plugin): model_plugin._on_after_invocation(event) assert len(agent.messages) == 1 + + +@pytest.mark.asyncio +async def test_count_tokens_empty_messages(model): + assert await model.count_tokens(messages=[]) == 0 + + +@pytest.mark.asyncio +async def test_count_tokens_system_prompt_only(model): + result = await model.count_tokens(messages=[], system_prompt="You are a helpful assistant.") + assert result == 6 + + +@pytest.mark.asyncio +async def test_count_tokens_text_messages(model, messages): + result = await model.count_tokens(messages=messages) + assert result == 1 # "hello" + + +@pytest.mark.asyncio +async def test_count_tokens_with_tool_specs(model, messages, tool_specs): + without_tools = await model.count_tokens(messages=messages) + with_tools = await model.count_tokens(messages=messages, tool_specs=tool_specs) + assert without_tools == 1 # "hello" + assert with_tools == 49 # "hello" (1) + tool_spec (48) + + +@pytest.mark.asyncio +async def test_count_tokens_with_system_prompt(model, messages, system_prompt): + without_prompt = await model.count_tokens(messages=messages) + with_prompt = await model.count_tokens(messages=messages, system_prompt=system_prompt) + assert without_prompt == 1 # "hello" + assert with_prompt == 3 # "hello" (1) + "s1" (2) + + +@pytest.mark.asyncio +async def test_count_tokens_combined(model, messages, tool_specs, system_prompt): + 
result = await model.count_tokens(messages=messages, tool_specs=tool_specs, system_prompt=system_prompt) + assert result == 51 # "hello" (1) + tool_spec (48) + "s1" (2) + + +@pytest.mark.asyncio +async def test_count_tokens_tool_use_block(model): + messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "123", + "name": "my_tool", + "input": {"query": "test"}, + } + } + ], + } + ] + result = await model.count_tokens(messages=messages) + # name "my_tool" (2) + json.dumps(input) (6) = 8 + assert result == 8 + + +@pytest.mark.asyncio +async def test_count_tokens_tool_result_block(model): + messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "123", + "content": [{"text": "tool output here"}], + "status": "success", + } + } + ], + } + ] + result = await model.count_tokens(messages=messages) + assert result == 3 # "tool output here" + + +@pytest.mark.asyncio +async def test_count_tokens_reasoning_block(model): + messages = [ + { + "role": "assistant", + "content": [ + { + "reasoningContent": { + "reasoningText": { + "text": "Let me think about this step by step.", + } + } + } + ], + } + ] + result = await model.count_tokens(messages=messages) + assert result == 9 # "Let me think about this step by step." 
+ + +@pytest.mark.asyncio +async def test_count_tokens_skips_binary_content(model): + messages = [ + { + "role": "user", + "content": [{"image": {"format": "png", "source": {"bytes": b"fake image data"}}}], + } + ] + assert await model.count_tokens(messages=messages) == 0 + + +@pytest.mark.asyncio +async def test_count_tokens_tool_result_with_bytes_only(model): + messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "123", + "content": [{"image": {"format": "png", "source": {"bytes": b"image data"}}}], + "status": "success", + } + } + ], + } + ] + result = await model.count_tokens(messages=messages) + assert result == 0 + + +@pytest.mark.asyncio +async def test_count_tokens_tool_result_with_text_and_bytes(model): + messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "123", + "content": [ + {"text": "Here is the screenshot"}, + {"image": {"format": "png", "source": {"bytes": b"image data"}}}, + ], + "status": "success", + } + } + ], + } + ] + result = await model.count_tokens(messages=messages) + assert result > 0 + + +@pytest.mark.asyncio +async def test_count_tokens_guard_content_block(model): + messages = [ + { + "role": "assistant", + "content": [{"guardContent": {"text": {"text": "This content was filtered by guardrails."}}}], + } + ] + result = await model.count_tokens(messages=messages) + assert result == 8 # "This content was filtered by guardrails." 
+ + +@pytest.mark.asyncio +async def test_count_tokens_tool_use_with_bytes(model): + messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "123", + "name": "my_tool", + "input": {"data": b"binary data"}, + } + } + ], + } + ] + result = await model.count_tokens(messages=messages) + # Should still count the tool name even though input has non-serializable bytes + assert result == 2 # "my_tool" name only + + +@pytest.mark.asyncio +async def test_count_tokens_non_serializable_tool_spec(model, messages): + tool_specs = [ + { + "name": "test", + "description": "a tool", + "inputSchema": {"json": {"default": b"bytes"}}, + } + ] + result = await model.count_tokens(messages=messages, tool_specs=tool_specs) + # Should still count the message tokens even though tool spec fails + assert result == 1 # "hello" only, tool spec skipped + + +@pytest.mark.asyncio +async def test_count_tokens_citations_block(model): + messages = [ + { + "role": "assistant", + "content": [ + { + "citationsContent": { + "content": [{"text": "According to the document, the answer is 42."}], + "citations": [], + } + } + ], + } + ] + result = await model.count_tokens(messages=messages) + assert result == 11 # "According to the document, the answer is 42." + + +@pytest.mark.asyncio +async def test_count_tokens_system_prompt_content(model): + result = await model.count_tokens( + messages=[], + system_prompt_content=[{"text": "You are a helpful assistant."}], + ) + assert result == 6 # "You are a helpful assistant." 
+ + +@pytest.mark.asyncio +async def test_count_tokens_system_prompt_content_with_cache_point(model): + result = await model.count_tokens( + messages=[], + system_prompt_content=[ + {"text": "You are a helpful assistant."}, + {"cachePoint": {"type": "default"}}, + ], + ) + assert result == 6 # "You are a helpful assistant.", cachePoint adds 0 + + +@pytest.mark.asyncio +async def test_count_tokens_system_prompt_content_takes_priority(model): + content_only = await model.count_tokens( + messages=[], + system_prompt_content=[{"text": "Short."}], + ) + # When both are provided, system_prompt_content wins — system_prompt is ignored + both = await model.count_tokens( + messages=[], + system_prompt="This is a much longer system prompt that should have more tokens.", + system_prompt_content=[{"text": "Short."}], + ) + assert content_only == 2 # "Short." + assert content_only == both + + +@pytest.mark.asyncio +async def test_count_tokens_all_inputs(model): + messages = [ + {"role": "user", "content": [{"text": "hello world"}]}, + {"role": "assistant", "content": [{"text": "hi there"}]}, + ] + result = await model.count_tokens( + messages=messages, + tool_specs=[{"name": "test", "description": "a test tool", "inputSchema": {"json": {}}}], + system_prompt="Be helpful.", + system_prompt_content=[{"text": "Additional system context."}], + ) + # system_prompt_content (4) + "hello world" (2) + "hi there" (2) + tool_spec (23) = 31 + assert result == 31 + + +def test_get_encoding_falls_back_without_tiktoken(monkeypatch): + """Test that _get_encoding returns None and count_tokens falls back to heuristic.""" + import strands.models.model as model_module + + model_module._get_encoding.cache_clear() + original_import = __builtins__["__import__"] if isinstance(__builtins__, dict) else __builtins__.__import__ + + def _block_tiktoken(name, *args, **kwargs): + if name == "tiktoken": + raise ImportError("No module named 'tiktoken'") + return original_import(name, *args, **kwargs) + + 
monkeypatch.setattr("builtins.__import__", _block_tiktoken) + + try: + assert model_module._get_encoding() is None + + # _estimate_tokens_with_tiktoken should raise when tiktoken is unavailable + with pytest.raises(ImportError): + model_module._estimate_tokens_with_tiktoken( + messages=[{"role": "user", "content": [{"text": "hello world!"}]}], + ) + + # _estimate_tokens_with_heuristic uses chars/4 for text + result = model_module._estimate_tokens_with_heuristic( + messages=[{"role": "user", "content": [{"text": "hello world!"}]}], + ) + assert result == 3 # ceil(12 / 4) + finally: + model_module._get_encoding.cache_clear() + + +class TestHeuristicEstimation: + """Tests for _estimate_tokens_with_heuristic.""" + + def test_all_content_types(self): + """One call covering text, toolUse, toolResult, reasoning, guard, citations, system prompt, tool specs.""" + from strands.models.model import _estimate_tokens_with_heuristic + + messages = [ + {"role": "user", "content": [{"text": "hello world!"}]}, + {"role": "assistant", "content": [ + {"toolUse": {"toolUseId": "1", "name": "my_tool", "input": {"q": "test"}}}, + {"reasoningContent": {"reasoningText": {"text": "Let me think."}}}, + {"guardContent": {"text": {"text": "Filtered."}}}, + {"citationsContent": {"content": [{"text": "Citation."}]}}, + ]}, + {"role": "user", "content": [ + {"toolResult": {"toolUseId": "1", "content": [{"text": "tool output here"}]}}, + ]}, + ] + result = _estimate_tokens_with_heuristic( + messages=messages, + tool_specs=[{"name": "test", "description": "a tool"}], + system_prompt="ignored", + system_prompt_content=[{"text": "Be helpful."}], + ) + assert result > 0 + + def test_non_serializable_inputs(self): + """Heuristic gracefully handles non-serializable tool input and tool specs.""" + from strands.models.model import _estimate_tokens_with_heuristic + + result = _estimate_tokens_with_heuristic( + messages=[ + {"role": "assistant", "content": [ + {"toolUse": {"toolUseId": "1", "name": 
"my_tool", "input": {"data": b"bytes"}}}, + ]}, + ], + tool_specs=[{"name": "t", "inputSchema": {"json": {"default": b"bytes"}}}], + ) + assert result == 2 # only tool name counted: ceil(len("my_tool") / 4) + + @pytest.mark.asyncio + async def test_model_falls_back_to_heuristic(self, monkeypatch, model): + """Model.count_tokens falls back to heuristic when tiktoken unavailable.""" + import strands.models.model as model_module + + model_module._get_encoding.cache_clear() + original_import = __builtins__["__import__"] if isinstance(__builtins__, dict) else __builtins__.__import__ + + def _block_tiktoken(name, *args, **kwargs): + if name == "tiktoken": + raise ImportError("No module named 'tiktoken'") + return original_import(name, *args, **kwargs) + + monkeypatch.setattr("builtins.__import__", _block_tiktoken) + + try: + result = await model.count_tokens( + messages=[{"role": "user", "content": [{"text": "hello world!"}]}] + ) + assert result == 3 # ceil(12 / 4) + finally: + model_module._get_encoding.cache_clear() From 4e9ed269560530402ec9a42d32e567f8537c076b Mon Sep 17 00:00:00 2001 From: poshinchen Date: Thu, 23 Apr 2026 10:17:05 -0400 Subject: [PATCH 240/279] chore(log): added warning for default model awareness and is subject to change (#2164) --- src/strands/models/bedrock.py | 8 +++++++- tests/strands/models/test_bedrock.py | 22 ++++++++++++---------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index a4c4aaba0..7ff3024a8 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -1123,4 +1123,10 @@ def _get_default_model_with_warning(region_name: str, model_config: BedrockConfi stacklevel=2, ) - return _DEFAULT_BEDROCK_MODEL_ID.format(prefix_inference_map.get(prefix, prefix)) + default_model_id = _DEFAULT_BEDROCK_MODEL_ID.format(prefix_inference_map.get(prefix, prefix)) + warnings.warn( + f"You're using default model '{default_model_id}', which is subject 
to change. " + "Specify a model explicitly to pin the model target.", + stacklevel=2, + ) + return default_model_id diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index a688a9962..384ee05e1 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2169,25 +2169,25 @@ def test_get_default_model_with_warning_supported_regions_shows_no_warning(captu """Test get_model_prefix_with_warning doesn't warn for supported region prefixes.""" BedrockModel._get_default_model_with_warning("us-west-2") BedrockModel._get_default_model_with_warning("eu-west-2") - assert len(captured_warnings) == 0 + assert all("does not support" not in str(w.message) for w in captured_warnings) def test_get_default_model_for_supported_eu_region_returns_correct_model_id(captured_warnings): model_id = BedrockModel._get_default_model_with_warning("eu-west-1") assert model_id == "eu.anthropic.claude-sonnet-4-20250514-v1:0" - assert len(captured_warnings) == 0 + assert all("does not support" not in str(w.message) for w in captured_warnings) def test_get_default_model_for_supported_us_region_returns_correct_model_id(captured_warnings): model_id = BedrockModel._get_default_model_with_warning("us-east-1") assert model_id == "us.anthropic.claude-sonnet-4-20250514-v1:0" - assert len(captured_warnings) == 0 + assert all("does not support" not in str(w.message) for w in captured_warnings) def test_get_default_model_for_supported_gov_region_returns_correct_model_id(captured_warnings): model_id = BedrockModel._get_default_model_with_warning("us-gov-west-1") assert model_id == "us-gov.anthropic.claude-sonnet-4-20250514-v1:0" - assert len(captured_warnings) == 0 + assert all("does not support" not in str(w.message) for w in captured_warnings) def test_get_model_prefix_for_ap_region_converts_to_apac_endpoint(captured_warnings): @@ -2199,9 +2199,10 @@ def test_get_model_prefix_for_ap_region_converts_to_apac_endpoint(captured_warni def 
test_get_default_model_with_warning_unsupported_region_warns(captured_warnings): """Test _get_default_model_with_warning warns for unsupported regions.""" BedrockModel._get_default_model_with_warning("ca-central-1") - assert len(captured_warnings) == 1 - assert "This region ca-central-1 does not support" in str(captured_warnings[0].message) - assert "our default inference endpoint" in str(captured_warnings[0].message) + region_warnings = [w for w in captured_warnings if "does not support" in str(w.message)] + assert len(region_warnings) == 1 + assert "This region ca-central-1 does not support" in str(region_warnings[0].message) + assert "our default inference endpoint" in str(region_warnings[0].message) def test_get_default_model_with_warning_no_warning_with_custom_model_id(captured_warnings): @@ -2217,8 +2218,9 @@ def test_init_with_unsupported_region_warns(session_cls, captured_warnings): """Test BedrockModel initialization warns for unsupported regions.""" BedrockModel(region_name="ca-central-1") - assert len(captured_warnings) == 1 - assert "This region ca-central-1 does not support" in str(captured_warnings[0].message) + region_warnings = [w for w in captured_warnings if "does not support" in str(w.message)] + assert len(region_warnings) == 1 + assert "This region ca-central-1 does not support" in str(region_warnings[0].message) def test_init_with_unsupported_region_custom_model_no_warning(session_cls, captured_warnings): @@ -2237,7 +2239,7 @@ def test_no_override_uses_formatted_default_model_id(captured_warnings): model_id = BedrockModel._get_default_model_with_warning("us-east-1") assert model_id == "us.anthropic.claude-sonnet-4-20250514-v1:0" assert model_id != _DEFAULT_BEDROCK_MODEL_ID - assert len(captured_warnings) == 0 + assert all("does not support" not in str(w.message) for w in captured_warnings) def test_custom_model_id_not_overridden_by_region_formatting(session_cls): From b207e0383cbde2e9eb0ee8aa4fe86f4989fbddab Mon Sep 17 00:00:00 2001 From: 
Elliott <57333066+ElliottJW@users.noreply.github.com> Date: Thu, 23 Apr 2026 11:26:42 -0500 Subject: [PATCH 241/279] fix(litellm): forward ttl field from CachePoint in _format_system_messages (#2153) Co-authored-by: Elliott Jacobsen-Watts Co-authored-by: Claude Opus 4.7 --- src/strands/models/litellm.py | 7 ++++-- src/strands/types/content.py | 3 +++ tests/strands/models/test_litellm.py | 33 ++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 04e39a66f..9fbdff794 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -221,11 +221,14 @@ def _format_system_messages( for block in system_prompt_content or []: if "text" in block: system_content.append({"type": "text", "text": block["text"]}) - elif "cachePoint" in block and block["cachePoint"].get("type") == "default": + elif "cachePoint" in block and block["cachePoint"]["type"] == "default": # Apply cache control to the immediately preceding content block # for LiteLLM/Anthropic compatibility if system_content: - system_content[-1]["cache_control"] = {"type": "ephemeral"} + cache_control: dict[str, Any] = {"type": "ephemeral"} + if ttl := block["cachePoint"].get("ttl"): + cache_control["ttl"] = ttl + system_content[-1]["cache_control"] = cache_control # Create single system message with content array rather than mulitple system messages return [{"role": "system", "content": system_content}] if system_content else [] diff --git a/src/strands/types/content.py b/src/strands/types/content.py index 8db1d1d98..5f9cc1460 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -67,9 +67,12 @@ class CachePoint(TypedDict): Attributes: type: The type of cache point, typically "default". + ttl: Optional cache TTL duration (e.g. "5m", "1h"). Supported by providers + that accept Anthropic-compatible cache_control fields. 
""" type: str + ttl: NotRequired[str] class ContentBlock(TypedDict, total=False): diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index d35a1806e..96cf561cd 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -955,6 +955,39 @@ def test_format_request_message_tool_call_no_reasoning_signature(): assert "__thought__" not in result["id"] +def test_format_system_messages_preserves_cache_point_ttl(): + """CachePoint with ttl="1h" should produce cache_control with ttl field.""" + result = LiteLLMModel._format_system_messages( + system_prompt_content=[ + {"text": "You are a helpful assistant."}, + {"cachePoint": {"type": "default", "ttl": "1h"}}, + ] + ) + assert result[0]["content"][0]["cache_control"] == {"type": "ephemeral", "ttl": "1h"} + + +def test_format_system_messages_cache_point_without_ttl(): + """CachePoint without ttl should produce cache_control with no ttl key (backward compat).""" + result = LiteLLMModel._format_system_messages( + system_prompt_content=[ + {"text": "You are a helpful assistant."}, + {"cachePoint": {"type": "default"}}, + ] + ) + assert result[0]["content"][0]["cache_control"] == {"type": "ephemeral"} + assert "ttl" not in result[0]["content"][0]["cache_control"] + + +def test_format_system_messages_cache_point_with_no_preceding_content(): + """CachePoint with no preceding text block should be silently ignored.""" + result = LiteLLMModel._format_system_messages( + system_prompt_content=[ + {"cachePoint": {"type": "default", "ttl": "1h"}}, + ] + ) + assert result == [] + + def test_thought_signature_round_trip(): """Test that thought signature is preserved through a full response -> internal -> request cycle.""" model = LiteLLMModel(model_id="test") From 2eaff9c2ffa43e052fc4546da334ca3dc2538d5d Mon Sep 17 00:00:00 2001 From: mattdai01 <32076552+mattdai01@users.noreply.github.com> Date: Thu, 23 Apr 2026 15:18:00 -0700 Subject: [PATCH 242/279] 
=?UTF-8?q?fix(skills):=20preserve=20cache=20point?= =?UTF-8?q?s=20in=20system=20prompt=20during=20skills=20inj=E2=80=A6=20(#2?= =?UTF-8?q?134)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Matthew Dai Co-authored-by: Claude Sonnet 4.6 --- src/strands/agent/agent.py | 12 +++ .../vended_plugins/skills/agent_skills.py | 54 ++++++++---- tests/strands/agent/test_agent.py | 27 ++++++ .../skills/test_agent_skills.py | 84 +++++++++++++++++-- 4 files changed, 154 insertions(+), 23 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index e8ea3c9bc..965969961 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -428,6 +428,18 @@ def system_prompt(self, value: str | list[SystemContentBlock] | None) -> None: """ self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(value) + @property + def system_prompt_content(self) -> list[SystemContentBlock] | None: + """Get the system prompt as a list of content blocks. + + Returns the structured content block representation, preserving cache points + and other non-text blocks. Returns None if no system prompt is set. + + Returns: + The system prompt as a list of content blocks, or None if no system prompt is set. + """ + return list(self._system_prompt_content) if self._system_prompt_content is not None else None + @property def tool(self) -> _ToolCaller: """Call tool as a function. 
diff --git a/src/strands/vended_plugins/skills/agent_skills.py b/src/strands/vended_plugins/skills/agent_skills.py index 23217e81c..ded2afb79 100644 --- a/src/strands/vended_plugins/skills/agent_skills.py +++ b/src/strands/vended_plugins/skills/agent_skills.py @@ -15,6 +15,7 @@ from ...hooks.events import BeforeInvocationEvent from ...plugins import Plugin, hook from ...tools.decorator import tool +from ...types.content import SystemContentBlock from ...types.tools import ToolContext from .skill import Skill @@ -136,34 +137,51 @@ def skills(self, skill_name: str, tool_context: ToolContext) -> str: # noqa: D4 def _on_before_invocation(self, event: BeforeInvocationEvent) -> None: """Inject skill metadata into the system prompt before each invocation. - Removes the previously injected XML block (if any) via exact string - replacement, then appends a fresh one. Uses agent state to track the - injected XML per-agent, so a single plugin instance can be shared - across multiple agents safely. + Removes the previously injected XML block (if any) via exact match, + then appends a fresh one. Uses agent state to track the injected XML + per-agent, so a single plugin instance can be shared across multiple + agents safely. + + When the agent has a structured system prompt (list of SystemContentBlock), + the injection is done at the block level so that cache points and other + structured blocks are preserved. Otherwise falls back to string manipulation. Args: event: The before-invocation event containing the agent reference. 
""" agent = event.agent - current_prompt = agent.system_prompt or "" - - # Remove the previously injected XML block by exact match state_data = agent.state.get(self._state_key) last_injected_xml = state_data.get("last_injected_xml") if isinstance(state_data, dict) else None - if last_injected_xml is not None: - if last_injected_xml in current_prompt: - current_prompt = current_prompt.replace(last_injected_xml, "") - else: - logger.warning("unable to find previously injected skills XML in system prompt, re-appending") skills_xml = self._generate_skills_xml() - injection = f"\n\n{skills_xml}" - new_prompt = f"{current_prompt}{injection}" if current_prompt else skills_xml - - new_injected_xml = injection if current_prompt else skills_xml - self._set_state_field(agent, "last_injected_xml", new_injected_xml) - agent.system_prompt = new_prompt + content = agent.system_prompt_content + + if content is not None: + # Content-block path: preserve cache points and other structured blocks + blocks: list[SystemContentBlock] = list(content) + if last_injected_xml is not None: + injected_block: SystemContentBlock = {"text": last_injected_xml} + if injected_block in blocks: + blocks.remove(injected_block) + else: + logger.warning("unable to find previously injected skills XML in system prompt, re-appending") + blocks.append({"text": skills_xml}) + self._set_state_field(agent, "last_injected_xml", skills_xml) + agent.system_prompt = blocks + else: + # String path: legacy behaviour for plain-string system prompts + current_prompt = agent.system_prompt or "" + if last_injected_xml is not None: + if last_injected_xml in current_prompt: + current_prompt = current_prompt.replace(last_injected_xml, "") + else: + logger.warning("unable to find previously injected skills XML in system prompt, re-appending") + injection = f"\n\n{skills_xml}" + new_prompt = f"{current_prompt}{injection}" if current_prompt else skills_xml + new_injected_xml = injection if current_prompt else skills_xml + 
self._set_state_field(agent, "last_injected_xml", new_injected_xml) + agent.system_prompt = new_prompt def get_available_skills(self) -> list[Skill]: """Get the list of available skills. diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 1e27274a1..3b9258e0a 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1166,6 +1166,33 @@ def test_system_prompt_setter_none(): assert agent._system_prompt_content is None +def test_system_prompt_content_string(): + """Test that system_prompt_content returns content blocks for string prompt.""" + agent = Agent(system_prompt="hello") + assert agent.system_prompt_content == [{"text": "hello"}] + + +def test_system_prompt_content_structured(): + """Test that system_prompt_content returns structured blocks with cache points.""" + blocks = [{"text": "You are helpful"}, {"cachePoint": {"type": "default"}}] + agent = Agent(system_prompt=blocks) + assert agent.system_prompt_content == blocks + + +def test_system_prompt_content_none(): + """Test that system_prompt_content returns None when no prompt is set.""" + agent = Agent(system_prompt=None) + assert agent.system_prompt_content is None + + +def test_system_prompt_content_returns_copy(): + """Test that system_prompt_content returns a defensive copy.""" + agent = Agent(system_prompt="hello") + content = agent.system_prompt_content + content.append({"text": "injected"}) + assert agent.system_prompt_content == [{"text": "hello"}] + + @pytest.mark.asyncio async def test_stream_async_passes_invocation_state(agent, mock_model, mock_event_loop_cycle, agenerator, alist): mock_model.mock_stream.side_effect = [ diff --git a/tests/strands/vended_plugins/skills/test_agent_skills.py b/tests/strands/vended_plugins/skills/test_agent_skills.py index db82355a9..03f43ef2c 100644 --- a/tests/strands/vended_plugins/skills/test_agent_skills.py +++ b/tests/strands/vended_plugins/skills/test_agent_skills.py @@ -32,11 +32,12 @@ def 
_mock_agent(): agent._system_prompt = "You are an agent." agent._system_prompt_content = [{"text": "You are an agent."}] - # Make system_prompt property behave like the real Agent + # Make system_prompt and system_prompt_content properties behave like the real Agent type(agent).system_prompt = property( lambda self: self._system_prompt, lambda self, value: _set_system_prompt(self, value), ) + type(agent).system_prompt_content = property(lambda self: self._system_prompt_content) agent.hooks = HookRegistry() agent.add_hook = MagicMock( @@ -59,11 +60,15 @@ def _mock_tool_context(agent: MagicMock) -> ToolContext: return ToolContext(tool_use=tool_use, agent=agent, invocation_state={"agent": agent}) -def _set_system_prompt(agent: MagicMock, value: str | None) -> None: +def _set_system_prompt(agent: MagicMock, value: str | list | None) -> None: """Simulate the Agent.system_prompt setter.""" if isinstance(value, str): agent._system_prompt = value agent._system_prompt_content = [{"text": value}] + elif isinstance(value, list): + text_parts = [block["text"] for block in value if "text" in block] + agent._system_prompt = "\n".join(text_parts) if text_parts else None + agent._system_prompt_content = value elif value is None: agent._system_prompt = None agent._system_prompt_content = None @@ -417,9 +422,41 @@ def test_uses_public_system_prompt_setter(self): event = BeforeInvocationEvent(agent=agent) plugin._on_before_invocation(event) - # The public setter should have been used, so _system_prompt_content - # should be consistent with _system_prompt - assert agent._system_prompt_content == [{"text": agent._system_prompt}] + # The public setter should have been used via the content-block path: + # original block is preserved and the skills XML is appended as a new block. 
+ assert len(agent.system_prompt_content) == 2 + assert agent.system_prompt_content[0] == {"text": "Original."} + assert "" in agent.system_prompt_content[1]["text"] + + def test_preserves_cache_points_in_system_prompt(self): + """Test that cachePoint blocks in the system prompt are preserved after injection.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + agent._system_prompt = "Base instructions." + agent._system_prompt_content = [ + {"text": "Base instructions."}, + {"cachePoint": {"type": "default"}}, + ] + + expected_skills_xml = plugin._generate_skills_xml() + + event = BeforeInvocationEvent(agent=agent) + plugin._on_before_invocation(event) + + # Exact block structure: original text, cachePoint, skills XML + assert agent.system_prompt_content == [ + {"text": "Base instructions."}, + {"cachePoint": {"type": "default"}}, + {"text": expected_skills_xml}, + ] + + # Repeated invocation: identical result, no accumulation + plugin._on_before_invocation(event) + assert agent.system_prompt_content == [ + {"text": "Base instructions."}, + {"cachePoint": {"type": "default"}}, + {"text": expected_skills_xml}, + ] def test_warns_when_previous_xml_not_found(self, caplog): """Test that a warning is logged when the previously injected XML is missing from the prompt.""" @@ -441,6 +478,43 @@ def test_warns_when_previous_xml_not_found(self, caplog): assert "" in agent.system_prompt +class TestStringPathInjection: + """Tests for the string-path branch of _on_before_invocation (system_prompt_content is None).""" + + def test_string_path_replaces_previous_xml(self): + """Test that old injected XML is replaced when found in the string prompt.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + + old_xml = "\n\nxml" + agent._system_prompt = f"Base prompt.{old_xml}" + agent._system_prompt_content = None + agent.state.set(plugin._state_key, {"last_injected_xml": old_xml}) + + event = BeforeInvocationEvent(agent=agent) + 
plugin._on_before_invocation(event) + + assert "xml" not in agent.system_prompt + assert "" in agent.system_prompt + assert agent.system_prompt.startswith("Base prompt.") + + def test_string_path_warns_when_previous_xml_not_found(self, caplog): + """Test that a warning is logged when old XML is missing from the string prompt.""" + plugin = AgentSkills(skills=[_make_skill()]) + agent = _mock_agent() + + agent._system_prompt = "Totally new prompt." + agent._system_prompt_content = None + agent.state.set(plugin._state_key, {"last_injected_xml": "\n\nxml"}) + + event = BeforeInvocationEvent(agent=agent) + with caplog.at_level(logging.WARNING): + plugin._on_before_invocation(event) + + assert "unable to find previously injected skills XML in system prompt" in caplog.text + assert "" in agent.system_prompt + + class TestSkillsXmlGeneration: """Tests for _generate_skills_xml.""" From 513e67d3758ed835bf907498f571e3b94bb5edd5 Mon Sep 17 00:00:00 2001 From: Ratan Kokku <34803938+Ratansairohith@users.noreply.github.com> Date: Thu, 23 Apr 2026 16:19:57 -0700 Subject: [PATCH 243/279] fix(ollama): generate unique toolUseId instead of reusing tool name (#2053) --- src/strands/models/ollama.py | 6 +- tests/strands/models/test_ollama.py | 105 ++++++++++++++++++++++------ 2 files changed, 89 insertions(+), 22 deletions(-) diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 41907e2e0..54805ac16 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -5,6 +5,7 @@ import json import logging +import uuid from collections.abc import AsyncGenerator from typing import Any, TypeVar, cast @@ -124,7 +125,7 @@ def _format_request_message_contents(self, role: str, content: ContentBlock) -> "tool_calls": [ { "function": { - "name": content["toolUse"]["toolUseId"], + "name": content["toolUse"]["name"], "arguments": content["toolUse"]["input"], } } @@ -246,7 +247,8 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: return 
{"contentBlockStart": {"start": {}}} tool_name = event["data"].function.name - return {"contentBlockStart": {"start": {"toolUse": {"name": tool_name, "toolUseId": tool_name}}}} + tool_use_id = f"tooluse_{uuid.uuid4().hex[:24]}" + return {"contentBlockStart": {"start": {"toolUse": {"name": tool_name, "toolUseId": tool_use_id}}}} case "content_delta": if event["data_type"] == "text": diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py index 0d4fbb9e0..7a6bbf97c 100644 --- a/tests/strands/models/test_ollama.py +++ b/tests/strands/models/test_ollama.py @@ -1,5 +1,6 @@ import json import logging +import re import unittest.mock import pydantic @@ -127,7 +128,12 @@ def test_format_request_with_image(model, model_id): def test_format_request_with_tool_use(model, model_id): messages = [ - {"role": "assistant", "content": [{"toolUse": {"toolUseId": "calculator", "input": '{"expression": "2+2"}'}}]} + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "tool-use-id-123", "name": "calculator", "input": '{"expression": "2+2"}'}} + ], + } ] tru_request = model.format_request(messages) @@ -321,9 +327,11 @@ def test_format_chunk_content_start_tool(model): event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_function} tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "calculator"}}}} + tool_use = tru_chunk["contentBlockStart"]["start"]["toolUse"] - assert tru_chunk == exp_chunk + assert tool_use["name"] == "calculator" + assert tool_use["toolUseId"] != "calculator" + assert len(tool_use["toolUseId"]) > 0 def test_format_chunk_content_delta_text(model): @@ -499,24 +507,27 @@ async def test_stream_with_tool_calls(ollama_client, model, agenerator, alist): response = model.stream(messages) tru_events = await alist(response) - exp_events = [ - {"messageStart": {"role": "assistant"}}, - {"contentBlockStart": {"start": {}}}, - 
{"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "calculator"}}}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, - {"contentBlockStop": {}}, - {"contentBlockDelta": {"delta": {"text": "I'll calculate that for you"}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "tool_use"}}, - { - "metadata": { - "usage": {"inputTokens": 8, "outputTokens": 15, "totalTokens": 23}, - "metrics": {"latencyMs": 2.0}, - } - }, - ] - assert tru_events == exp_events + # Verify the tool use event has a unique ID (not equal to the tool name) + tool_start_event = tru_events[2] + tool_use = tool_start_event["contentBlockStart"]["start"]["toolUse"] + assert tool_use["name"] == "calculator" + assert tool_use["toolUseId"] != "calculator" + + # Verify all other events + assert tru_events[0] == {"messageStart": {"role": "assistant"}} + assert tru_events[1] == {"contentBlockStart": {"start": {}}} + assert tru_events[3] == {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}} + assert tru_events[4] == {"contentBlockStop": {}} + assert tru_events[5] == {"contentBlockDelta": {"delta": {"text": "I'll calculate that for you"}}} + assert tru_events[6] == {"contentBlockStop": {}} + assert tru_events[7] == {"messageStop": {"stopReason": "tool_use"}} + assert tru_events[8] == { + "metadata": { + "usage": {"inputTokens": 8, "outputTokens": 15, "totalTokens": 23}, + "metrics": {"latencyMs": 2.0}, + } + } expected_request = { "model": "m1", "messages": [{"role": "user", "content": "Calculate 2+2"}], @@ -625,3 +636,57 @@ def test_format_request_filters_location_source_document(model, caplog): user_message = formatted_messages[0] assert user_message["content"] == "analyze this document" assert "Location sources are not supported by Ollama" in caplog.text + + +def test_tool_use_id_is_unique_and_not_tool_name(model): + """Test that toolUseId is a unique UUID, not the tool name.""" + mock_function = 
unittest.mock.Mock() + mock_function.function.name = "calculator" + + event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_function} + + chunk1 = model.format_chunk(event) + chunk2 = model.format_chunk(event) + + tool_use1 = chunk1["contentBlockStart"]["start"]["toolUse"] + tool_use2 = chunk2["contentBlockStart"]["start"]["toolUse"] + + # toolUseId should not equal the tool name + assert tool_use1["toolUseId"] != "calculator" + assert tool_use2["toolUseId"] != "calculator" + + # toolUseId should be unique across calls + assert tool_use1["toolUseId"] != tool_use2["toolUseId"] + + # toolUseId should follow the tooluse_<24-hex> convention used by other providers + assert re.fullmatch(r"tooluse_[0-9a-f]{24}", tool_use1["toolUseId"]) + assert re.fullmatch(r"tooluse_[0-9a-f]{24}", tool_use2["toolUseId"]) + + # name should still be correct + assert tool_use1["name"] == "calculator" + assert tool_use2["name"] == "calculator" + + +def test_format_request_uses_tool_name_not_tool_use_id(model, model_id): + """Test that format_request uses the 'name' field, not 'toolUseId', for the function name.""" + messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "unique-id-abc-123", + "name": "calculator", + "input": '{"expression": "1+1"}', + } + } + ], + } + ] + + request = model.format_request(messages) + tool_call = request["messages"][0]["tool_calls"][0] + + # The function name in the request must come from "name", not "toolUseId" + assert tool_call["function"]["name"] == "calculator" + assert tool_call["function"]["name"] != "unique-id-abc-123" From da4c44e4a190b0750321263d9422f2f8bb07f030 Mon Sep 17 00:00:00 2001 From: Kien Pham <22681+kpx-dev@users.noreply.github.com> Date: Fri, 24 Apr 2026 06:29:16 -0700 Subject: [PATCH 244/279] feat(cache): add TTL support to CachePoint for prompt caching (#1660) Co-authored-by: Claude Sonnet 4.5 Co-authored-by: Murat Kaan Meral --- AGENTS.md | 1 + src/strands/models/bedrock.py | 8 +- 
tests/strands/models/test_bedrock.py | 47 ++++++ tests_integ/models/test_model_bedrock.py | 173 +++++++++++++++++++++++ 4 files changed, 227 insertions(+), 2 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 8835b45c8..69f1b8e9a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -427,6 +427,7 @@ hatch test --all # Test all Python versions (3.10-3.13) - Use `moto` for mocking AWS services - Use `pytest.mark.asyncio` for async tests - Keep tests focused and independent +- Import packages at the top of the test files ## MCP Tasks (Experimental) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 7ff3024a8..e781b952e 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -521,12 +521,16 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An """ # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CachePointBlock.html if "cachePoint" in content: - return {"cachePoint": {"type": content["cachePoint"]["type"]}} + cache_point = content["cachePoint"] + result: dict[str, Any] = {"type": cache_point["type"]} + if "ttl" in cache_point: + result["ttl"] = cache_point["ttl"] + return {"cachePoint": result} # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_DocumentBlock.html if "document" in content: document = content["document"] - result: dict[str, Any] = {} + result = {} # Handle required fields (all optional due to total=False) if "name" in document: diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 384ee05e1..99a745b07 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2132,6 +2132,53 @@ def test_format_request_filters_cache_point_content_blocks(model, model_id): assert "extraField" not in cache_point_block +def test_format_request_preserves_cache_point_ttl(model, model_id): + """Test that format_request preserves the ttl field in cachePoint content blocks.""" + messages 
= [ + { + "role": "user", + "content": [ + { + "cachePoint": { + "type": "default", + "ttl": "1h", + } + }, + ], + } + ] + + formatted_request = model._format_request(messages) + + cache_point_block = formatted_request["messages"][0]["content"][0]["cachePoint"] + expected = {"type": "default", "ttl": "1h"} + assert cache_point_block == expected + assert cache_point_block["ttl"] == "1h" + + +def test_format_request_cache_point_without_ttl(model, model_id): + """Test that cache points work without ttl field (backward compatibility).""" + messages = [ + { + "role": "user", + "content": [ + { + "cachePoint": { + "type": "default", + } + }, + ], + } + ] + + formatted_request = model._format_request(messages) + + cache_point_block = formatted_request["messages"][0]["content"][0]["cachePoint"] + expected = {"type": "default"} + assert cache_point_block == expected + assert "ttl" not in cache_point_block + + def test_config_validation_warns_on_unknown_keys(bedrock_client, captured_warnings): """Test that unknown config keys emit a warning.""" BedrockModel(model_id="test-model", invalid_param="test") diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index e4ef727ce..4020ce35e 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -1,3 +1,6 @@ +import time +import uuid + import pydantic import pytest @@ -344,3 +347,173 @@ def test_multi_prompt_system_content(): agent = Agent(system_prompt=system_prompt_content, load_tools_from_directory=False) # just verifying there is no failure agent("Hello!") + + +def test_prompt_caching_with_5m_ttl(): + """Test prompt caching with 5 minute TTL and verify cache metrics. + + This test verifies: + 1. First call creates cache (cacheWriteInputTokens > 0) + 2. Second call reads from cache (cacheReadInputTokens > 0) + + Uses Claude Haiku 4.5 which supports TTL in CachePointBlock on Bedrock. + Older models (e.g. 
Claude Sonnet 4) reject the TTL field with a ValidationException. + """ + model = BedrockModel( + model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", + streaming=False, + ) + + # Use unique identifier to avoid cache conflicts between test runs + unique_id = str(uuid.uuid4()) + # Minimum 4096 tokens required for caching with Haiku 4.5 + large_context = f"Background information for test {unique_id}: " + ("This is important context. " * 1000) + + system_prompt_with_cache = [ + {"text": large_context}, + {"cachePoint": {"type": "default", "ttl": "5m"}}, + {"text": "You are a helpful assistant."}, + ] + + agent = Agent( + model=model, + system_prompt=system_prompt_with_cache, + load_tools_from_directory=False, + ) + + # First call should create the cache (cache write) + result1 = agent("What is 2+2?") + assert len(str(result1)) > 0 + + # Verify cache write occurred on first call + assert result1.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, ( + "Expected cacheWriteInputTokens > 0 on first call" + ) + + # Second call should use the cached content (cache read) + result2 = agent("What is 3+3?") + assert len(str(result2)) > 0 + + # Verify cache read occurred on second call + assert result2.metrics.accumulated_usage.get("cacheReadInputTokens", 0) > 0, ( + "Expected cacheReadInputTokens > 0 on second call" + ) + + +def test_prompt_caching_with_1h_ttl(): + """Test prompt caching with 1 hour TTL and verify cache metrics. + + Uses Claude Haiku 4.5 which supports 1hr TTL. + Uses unique content per test run to avoid cache conflicts with concurrent CI runs. + Even with 1hr TTL, unique content ensures cache entries don't interfere across tests. 
+ """ + model = BedrockModel( + model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", + streaming=False, + ) + + # Use timestamp to ensure unique content per test run (avoids CI conflicts) + unique_id = str(int(time.time() * 1000000)) # microsecond timestamp + # Minimum 4096 tokens required for caching with Haiku 4.5 + large_context = f"Background information for test {unique_id}: " + ("This is important context. " * 1000) + + system_prompt_with_cache = [ + {"text": large_context}, + {"cachePoint": {"type": "default", "ttl": "1h"}}, + {"text": "You are a helpful assistant."}, + ] + + agent = Agent( + model=model, + system_prompt=system_prompt_with_cache, + load_tools_from_directory=False, + ) + + # First call should create the cache + result1 = agent("What is 2+2?") + assert len(str(result1)) > 0 + + # Verify cache write occurred + assert result1.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, ( + "Expected cacheWriteInputTokens > 0 on first call with 1h TTL" + ) + + # Second call should use the cached content + result2 = agent("What is 3+3?") + assert len(str(result2)) > 0 + + # Verify cache read occurred + assert result2.metrics.accumulated_usage.get("cacheReadInputTokens", 0) > 0, ( + "Expected cacheReadInputTokens > 0 on second call with 1h TTL" + ) + + +def test_prompt_caching_with_ttl_in_messages(): + """Test prompt caching with TTL in message content and verify cache metrics. + + Uses Claude Haiku 4.5 which supports TTL in CachePointBlock on Bedrock. + Older models (e.g. Claude Sonnet 4) reject the TTL field with a ValidationException. + """ + model = BedrockModel( + model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", + streaming=False, + ) + agent = Agent(model=model, load_tools_from_directory=False) + + unique_id = str(uuid.uuid4()) + # Minimum 4096 tokens required for caching with Haiku 4.5 + large_text = f"Important context for test {unique_id}: " + ("This is critical information. 
" * 1000) + + content_with_cache = [ + {"text": large_text}, + {"cachePoint": {"type": "default", "ttl": "5m"}}, + {"text": "Based on the context above, what is 5+5?"}, + ] + + # First call creates cache + result1 = agent(content_with_cache) + assert len(str(result1)) > 0 + + # Verify cache write in message content + assert result1.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, ( + "Expected cacheWriteInputTokens > 0 when caching message content" + ) + + # Subsequent call should use cache + result2 = agent("What about 10+10?") + assert len(str(result2)) > 0 + + # Verify cache read on subsequent call + assert result2.metrics.accumulated_usage.get("cacheReadInputTokens", 0) > 0, ( + "Expected cacheReadInputTokens > 0 on subsequent call" + ) + + +def test_prompt_caching_backward_compatibility_no_ttl(non_streaming_model): + """Test that prompt caching works without TTL (backward compatibility). + + Verifies that cache points work correctly when TTL is not specified, + maintaining backward compatibility with existing code. + """ + unique_id = str(uuid.uuid4()) + large_context = f"Background information for test {unique_id}: " + ("This is important context. 
" * 200) + + system_prompt_with_cache = [ + {"text": large_context}, + {"cachePoint": {"type": "default"}}, # No TTL specified + {"text": "You are a helpful assistant."}, + ] + + agent = Agent( + model=non_streaming_model, + system_prompt=system_prompt_with_cache, + load_tools_from_directory=False, + ) + + result = agent("Hello!") + assert len(str(result)) > 0 + + # Verify cache write occurred even without TTL + assert result.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, ( + "Expected cacheWriteInputTokens > 0 even without TTL specified" + ) From 22b3aaf350ce4672d8fceb6ee1e90cd82b928768 Mon Sep 17 00:00:00 2001 From: Osman-AGI Date: Fri, 24 Apr 2026 06:49:33 -0700 Subject: [PATCH 245/279] =?UTF-8?q?fix:=20use=20non-interactive=20flag=20f?= =?UTF-8?q?or=20Nova=20Sonic=20history=20and=20system=20promp=E2=80=A6=20(?= =?UTF-8?q?#2188)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../experimental/bidi/models/nova_sonic.py | 99 +++++++++++++++---- 1 file changed, 79 insertions(+), 20 deletions(-) diff --git a/src/strands/experimental/bidi/models/nova_sonic.py b/src/strands/experimental/bidi/models/nova_sonic.py index d836bde49..8ad5b2a83 100644 --- a/src/strands/experimental/bidi/models/nova_sonic.py +++ b/src/strands/experimental/bidi/models/nova_sonic.py @@ -96,6 +96,9 @@ NOVA_TEXT_CONFIG = {"mediaType": "text/plain"} NOVA_TOOL_CONFIG = {"mediaType": "application/json"} +_MAX_HISTORY_MESSAGE_BYTES = 50 * 1024 # 50KB per message +_MAX_HISTORY_TOTAL_BYTES = 200 * 1024 # 200KB total history + class BidiNovaSonicModel(BidiModel): """Nova Sonic implementation for bidirectional streaming. 
@@ -726,7 +729,7 @@ def _get_system_prompt_events(self, system_prompt: str | None) -> list[str]: """Generate system prompt events.""" content_name = str(uuid.uuid4()) return [ - self._get_text_content_start_event(content_name, "SYSTEM"), + self._get_text_content_start_event(content_name, "SYSTEM", interactive=False), self._get_text_input_event(content_name, system_prompt or ""), self._get_content_end_event(content_name), ] @@ -737,42 +740,98 @@ def _get_message_history_events(self, messages: Messages) -> list[str]: Converts agent message history to Nova Sonic format following the contentStart/textInput/contentEnd pattern for each message. + History messages are sent as non-interactive (interactive=False) so Nova Sonic + treats them as prior context rather than new inputs requiring a response. + + Individual messages are truncated to 50KB and total history is capped + at 200KB. When the limit is reached, the oldest messages are dropped + to prioritize recent conversation context. + Args: messages: List of conversation messages with role and content. Returns: List of JSON event strings for Nova Sonic. 
""" - events = [] + max_message_bytes = _MAX_HISTORY_MESSAGE_BYTES + max_total_bytes = _MAX_HISTORY_TOTAL_BYTES - for message in messages: - role = message["role"].upper() # Convert to ASSISTANT or USER + # First pass: extract and truncate text from each message, walking backwards + # to prioritize recent messages when the total size limit is hit + prepared: list[tuple[str, str]] = [] # (role, text) + total_bytes = 0 + + for message in reversed(messages): + role = message["role"].upper() content_blocks = message.get("content", []) - # Extract text content from content blocks text_parts = [] for block in content_blocks: if "text" in block: text_parts.append(block["text"]) - # Combine all text parts - if text_parts: - combined_text = "\n".join(text_parts) - content_name = str(uuid.uuid4()) - - # Add contentStart, textInput, and contentEnd events - events.extend( - [ - self._get_text_content_start_event(content_name, role), - self._get_text_input_event(content_name, combined_text), - self._get_content_end_event(content_name), - ] + if not text_parts: + continue + + combined_text = "\n".join(text_parts) + + # Truncate individual message + encoded = combined_text.encode("utf-8") + if len(encoded) > max_message_bytes: + encoded = encoded[:max_message_bytes] + combined_text = encoded.decode("utf-8", errors="ignore") + encoded = combined_text.encode("utf-8") + + msg_bytes = len(encoded) + + if total_bytes + msg_bytes > max_total_bytes: + logger.debug( + "total_bytes=<%d>, msg_bytes=<%d>, max_total_bytes=<%d> | dropping older messages to fit limit", + total_bytes, + msg_bytes, + max_total_bytes, ) + break + + total_bytes += msg_bytes + prepared.append((role, combined_text)) + + # Reverse back to chronological order + prepared.reverse() + + # Ensure the first message is from the user role — drop leading assistant messages + while prepared and prepared[0][0] != "USER": + dropped_role, dropped_text = prepared.pop(0) + logger.debug( + "role=<%s>, text_preview=<%s> | dropping 
leading non-user message from history", + dropped_role, + dropped_text[:100], + ) + + logger.debug("prepared_count=<%d>, total_bytes=<%d> | final history after trimming", len(prepared), total_bytes) + + # Second pass: build events + events: list[str] = [] + for role, text in prepared: + content_name = str(uuid.uuid4()) + events.extend( + [ + self._get_text_content_start_event(content_name, role, interactive=False), + self._get_text_input_event(content_name, text), + self._get_content_end_event(content_name), + ] + ) return events - def _get_text_content_start_event(self, content_name: str, role: str = "USER") -> str: - """Generate text content start event.""" + def _get_text_content_start_event(self, content_name: str, role: str = "USER", interactive: bool = True) -> str: + """Generate text content start event. + + Args: + content_name: Unique identifier for this content block. + role: Message role (USER, ASSISTANT, SYSTEM). + interactive: Whether this is a live input (True) or history context (False). 
+ """ return json.dumps( { "event": { @@ -781,7 +840,7 @@ def _get_text_content_start_event(self, content_name: str, role: str = "USER") - "contentName": content_name, "type": "TEXT", "role": role, - "interactive": True, + "interactive": interactive, "textInputConfiguration": NOVA_TEXT_CONFIG, } } From 609723a579808efc07af1b0d73e5a985187ca96a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 24 Apr 2026 09:51:34 -0400 Subject: [PATCH 246/279] ci: update litellm requirement from <=1.82.6,>=1.75.9 to >=1.75.9,<=1.83.13 (#2197) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e1ab0d7d4..d0f1074e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ dependencies = [ [project.optional-dependencies] anthropic = ["anthropic>=0.21.0,<1.0.0"] gemini = ["google-genai>=1.32.0,<2.0.0"] -litellm = ["litellm>=1.75.9,<=1.82.6", "openai>=1.68.0,<3.0.0"] +litellm = ["litellm>=1.75.9,<=1.83.13", "openai>=1.68.0,<3.0.0"] llamaapi = ["llama-api-client>=0.1.0,<1.0.0"] mistral = ["mistralai>=1.8.2,<2.0.0"] ollama = ["ollama>=0.4.8,<1.0.0"] From 5b6aa56a305e52dd2802c5a38139fd9fbbfeddf0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 24 Apr 2026 09:52:46 -0400 Subject: [PATCH 247/279] ci: update pre-commit requirement from <4.6.0,>=3.2.0 to >=3.2.0,<4.7.0 (#2185) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d0f1074e4..83a7bbf4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,7 +91,7 @@ dev = [ "hatch>=1.0.0,<2.0.0", "moto>=5.1.0,<6.0.0", "mypy>=1.15.0,<2.0.0", - 
"pre-commit>=3.2.0,<4.6.0", + "pre-commit>=3.2.0,<4.7.0", "pytest>=9.0.0,<10.0.0", "pytest-cov>=7.0.0,<8.0.0", "pytest-asyncio>=1.0.0,<1.4.0", @@ -174,7 +174,7 @@ features = ["all"] dependencies = [ "commitizen>=4.4.0,<5.0.0", "hatch>=1.0.0,<2.0.0", - "pre-commit>=3.2.0,<4.6.0", + "pre-commit>=3.2.0,<4.7.0", ] From 33b25cbd30bee039292c4104a9439692b023e816 Mon Sep 17 00:00:00 2001 From: Liz <91279165+lizradway@users.noreply.github.com> Date: Fri, 24 Apr 2026 12:46:50 -0400 Subject: [PATCH 248/279] feat: large tool result offload (#2162) --- AGENTS.md | 3 +- .../context_offloader/__init__.py | 46 ++ .../context_offloader/plugin.py | 336 +++++++++++ .../context_offloader/storage.py | 373 ++++++++++++ tests/strands/models/test_model.py | 39 +- .../context_offloader/__init__.py | 0 .../context_offloader/test_plugin.py | 539 ++++++++++++++++++ .../context_offloader/test_storage.py | 266 +++++++++ 8 files changed, 1585 insertions(+), 17 deletions(-) create mode 100644 src/strands/vended_plugins/context_offloader/__init__.py create mode 100644 src/strands/vended_plugins/context_offloader/plugin.py create mode 100644 src/strands/vended_plugins/context_offloader/storage.py create mode 100644 tests/strands/vended_plugins/context_offloader/__init__.py create mode 100644 tests/strands/vended_plugins/context_offloader/test_plugin.py create mode 100644 tests/strands/vended_plugins/context_offloader/test_storage.py diff --git a/AGENTS.md b/AGENTS.md index 69f1b8e9a..0b877ea98 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -141,7 +141,8 @@ strands-agents/ │ │ │ ├── context_providers/ # Context data providers (e.g., ledger) │ │ │ ├── core/ # Base classes, actions, context │ │ │ └── handlers/ # Handler implementations (e.g., LLM) -│ │ └── skills/ # AgentSkills.io integration (Skill, AgentSkills) +│ │ ├── skills/ # AgentSkills.io integration (Skill, AgentSkills) +│ │ └── context_offloader/ # Large tool result offloading plugin │ │ │ ├── experimental/ # Experimental features (API may 
change) │ │ ├── agent_config.py # Experimental agent config diff --git a/src/strands/vended_plugins/context_offloader/__init__.py b/src/strands/vended_plugins/context_offloader/__init__.py new file mode 100644 index 000000000..01ca6f1fc --- /dev/null +++ b/src/strands/vended_plugins/context_offloader/__init__.py @@ -0,0 +1,46 @@ +"""Context offloader plugin for Strands Agents. + +This module provides the ContextOffloader plugin which intercepts oversized +tool results, persists each content block to a storage backend, and replaces +the in-context result with a truncated preview and per-block references. + +Example Usage: + ```python + from strands import Agent + from strands.vended_plugins.context_offloader import ( + ContextOffloader, + InMemoryStorage, + FileStorage, + ) + + # In-memory storage + agent = Agent(plugins=[ + ContextOffloader(storage=InMemoryStorage()) + ]) + + # File storage with custom thresholds + agent = Agent(plugins=[ + ContextOffloader( + storage=FileStorage("./artifacts"), + max_result_tokens=5_000, + preview_tokens=2_000, + ) + ]) + ``` +""" + +from .plugin import ContextOffloader +from .storage import ( + FileStorage, + InMemoryStorage, + S3Storage, + Storage, +) + +__all__ = [ + "ContextOffloader", + "FileStorage", + "InMemoryStorage", + "S3Storage", + "Storage", +] diff --git a/src/strands/vended_plugins/context_offloader/plugin.py b/src/strands/vended_plugins/context_offloader/plugin.py new file mode 100644 index 000000000..0072d3934 --- /dev/null +++ b/src/strands/vended_plugins/context_offloader/plugin.py @@ -0,0 +1,336 @@ +"""ContextOffloader plugin for managing large tool outputs. + +This module provides the ContextOffloader plugin that intercepts oversized +tool results, persists each content block to a storage backend, and replaces +the in-context result with a truncated preview and per-block references. 
+ +Example: + ```python + from strands import Agent + from strands.vended_plugins.context_offloader import ( + ContextOffloader, + InMemoryStorage, + FileStorage, + ) + + # In-memory storage + agent = Agent(plugins=[ + ContextOffloader(storage=InMemoryStorage()) + ]) + + # File storage with custom thresholds and retrieval tool enabled + agent = Agent(plugins=[ + ContextOffloader( + storage=FileStorage("./artifacts"), + max_result_tokens=5_000, + preview_tokens=2_000, + include_retrieval_tool=True, + ) + ]) + ``` +""" + +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING + +from ...hooks.events import AfterToolCallEvent +from ...models.model import _get_encoding +from ...plugins import Plugin, hook +from ...tools.decorator import tool +from ...types.content import Message +from ...types.tools import ToolContext, ToolResult, ToolResultContent +from .storage import Storage + +if TYPE_CHECKING: + from ...agent.agent import Agent + +logger = logging.getLogger(__name__) + +_DEFAULT_MAX_RESULT_TOKENS = 2_500 +"""Default token threshold above which tool results are offloaded.""" + +_DEFAULT_PREVIEW_TOKENS = 1_000 +"""Default number of tokens to keep as a preview in context.""" + +_CHARS_PER_TOKEN = 4 +"""Approximate characters per token, fallback for preview slicing without tiktoken.""" + + +class ContextOffloader(Plugin): + """Plugin that offloads oversized tool results to reduce context consumption. + + When a tool result exceeds the configured token threshold, this plugin + stores each content block individually to a storage backend and replaces + the in-context result with a truncated text preview plus per-block references. + + Token estimation uses the agent's model ``count_tokens`` method, which + leverages tiktoken when available and falls back to character-based heuristics. 
+ + Content type handling: + + - **Text**: stored as ``text/plain``, replaced with a preview + - **JSON**: stored as ``application/json``, replaced with a preview + - **Image**: stored in its native format (e.g., ``image/png``), replaced with a + placeholder showing format and size + - **Document**: stored in its native format (e.g., ``application/pdf``), replaced + with a placeholder showing format, name, and size + - **Unknown types**: passed through unchanged + + This operates proactively at tool execution time via ``AfterToolCallEvent``, + before the result enters the conversation — unlike ``SlidingWindowConversationManager`` + which truncates reactively after context overflow. + + Args: + storage: Backend for storing offloaded content (required). + max_result_tokens: Offload results whose estimated token count exceeds this threshold. + preview_tokens: Number of tokens to keep as a text preview in context. + include_retrieval_tool: Whether to register the ``retrieve_offloaded_content`` tool. + Defaults to False. + + Example: + ```python + from strands import Agent + from strands.vended_plugins.context_offloader import ContextOffloader, InMemoryStorage + + agent = Agent(plugins=[ + ContextOffloader(storage=InMemoryStorage()) + ]) + ``` + """ + + name = "context_offloader" + + def __init__( + self, + storage: Storage, + max_result_tokens: int = _DEFAULT_MAX_RESULT_TOKENS, + preview_tokens: int = _DEFAULT_PREVIEW_TOKENS, + *, + include_retrieval_tool: bool = False, + ) -> None: + """Initialize the ContextOffloader plugin. + + Args: + storage: Backend for storing offloaded content. + max_result_tokens: Offload results whose estimated token count exceeds this + threshold. Defaults to ``_DEFAULT_MAX_RESULT_TOKENS`` (2,500). + preview_tokens: Number of tokens to keep as a text preview in context. + Uses tiktoken for exact slicing when available, falls back to + chars/4 heuristic. Defaults to ``_DEFAULT_PREVIEW_TOKENS`` (1,000). 
+ include_retrieval_tool: Whether to register the ``retrieve_offloaded_content`` + tool so the agent can fetch offloaded content. Defaults to False. + + Raises: + ValueError: If max_result_tokens is not positive, preview_tokens is negative, + or preview_tokens >= max_result_tokens. + """ + if max_result_tokens <= 0: + raise ValueError("max_result_tokens must be positive") + if preview_tokens < 0: + raise ValueError("preview_tokens must be non-negative") + if preview_tokens >= max_result_tokens: + raise ValueError("preview_tokens must be less than max_result_tokens") + + self._storage = storage + self._max_result_tokens = max_result_tokens + self._preview_tokens = preview_tokens + self._include_retrieval_tool = include_retrieval_tool + super().__init__() + + def init_agent(self, agent: Agent) -> None: + """Conditionally register the retrieval tool.""" + if not self._include_retrieval_tool: + # Remove the auto-discovered retrieval tool + self._tools = [t for t in self._tools if t.tool_name != "retrieve_offloaded_content"] + + @tool(context=True) + def retrieve_offloaded_content( + self, + reference: str, + tool_context: ToolContext, + ) -> dict | str: + """Retrieve offloaded content by reference. + + Use this tool when you see a placeholder with a reference (ref: ...) + and need the full content. + + Args: + reference: The reference string from the offload placeholder. + tool_context: Injected by the framework. Not user-facing. 
+ """ + try: + content_bytes, content_type = self._storage.retrieve(reference) + except KeyError: + return f"Error: reference not found: {reference}" + + if content_type.startswith("text/"): + return content_bytes.decode("utf-8") + + if content_type == "application/json": + return {"status": "success", "content": [{"json": json.loads(content_bytes)}]} + + if content_type.startswith("image/"): + img_format = content_type.split("/")[-1] + return { + "status": "success", + "content": [{"image": {"format": img_format, "source": {"bytes": content_bytes}}}], + } + + if content_type.startswith("application/"): + doc_format = content_type.split("/")[-1] + doc_block = {"format": doc_format, "name": reference, "source": {"bytes": content_bytes}} + return {"status": "success", "content": [{"document": doc_block}]} + + return content_bytes.decode("utf-8", errors="replace") + + @hook + async def _handle_tool_result(self, event: AfterToolCallEvent) -> None: + """Intercept oversized tool results, offload per-block, and replace with preview.""" + if event.cancel_message is not None: + return + + if self._include_retrieval_tool and event.tool_use.get("name") == self.retrieve_offloaded_content.tool_name: + return + + result = event.result + content = result["content"] + tool_use_id = event.tool_use["toolUseId"] + + # Estimate token count by wrapping the tool result as a message for count_tokens + tool_result_message: Message = {"role": "user", "content": [{"toolResult": result}]} + token_count = await event.agent.model.count_tokens([tool_result_message]) + + if token_count <= self._max_result_tokens: + return + + # Build text preview from text+JSON blocks. + # Empty text blocks are intentionally excluded — they add no content value. 
+ text_preview_parts: list[str] = [] + for block in content: + if block.get("text"): + text_preview_parts.append(block["text"]) + elif "json" in block: + text_preview_parts.append(json.dumps(block["json"], indent=2)) + + full_text = "\n".join(text_preview_parts) if text_preview_parts else "" + + # Store each content block individually + references: list[tuple[str, str, str]] = [] # (ref, content_type, description) + try: + for i, block in enumerate(content): + key = f"{tool_use_id}_{i}" + if block.get("text"): + ref = self._storage.store(key, block["text"].encode("utf-8"), "text/plain") + references.append((ref, "text/plain", f"text, {len(block['text']):,} chars")) + elif "json" in block: + json_bytes = json.dumps(block["json"], indent=2).encode("utf-8") + ref = self._storage.store(key, json_bytes, "application/json") + references.append((ref, "application/json", f"json, {len(json_bytes):,} bytes")) + elif "image" in block: + image = block["image"] + img_format = image.get("format", "unknown") + img_bytes = image.get("source", {}).get("bytes", b"") + if img_bytes: + ref = self._storage.store(key, img_bytes, f"image/{img_format}") + references.append((ref, f"image/{img_format}", f"image/{img_format}, {len(img_bytes):,} bytes")) + else: + references.append(("", f"image/{img_format}", f"image/{img_format}, 0 bytes")) + elif "document" in block: + doc = block["document"] + doc_format = doc.get("format", "unknown") + doc_name = doc.get("name", "unknown") + doc_bytes = doc.get("source", {}).get("bytes", b"") + if doc_bytes: + ref = self._storage.store(key, doc_bytes, f"application/{doc_format}") + references.append((ref, f"application/{doc_format}", f"{doc_name}, {len(doc_bytes):,} bytes")) + else: + references.append(("", f"application/{doc_format}", f"{doc_name}, 0 bytes")) + except Exception: + logger.warning( + "tool_use_id=<%s> | failed to offload tool result, keeping original", + tool_use_id, + exc_info=True, + ) + return + + logger.debug( + "tool_use_id=<%s>, 
blocks=<%d>, tokens=<%d> | tool result offloaded", + tool_use_id, + len(references), + token_count, + ) + + # Build preview text — use tiktoken for exact slicing when available + preview = self._slice_preview(full_text) if full_text else "" + ref_lines = "\n".join(f" {ref} ({desc})" for ref, _, desc in references if ref) + + guidance = ( + "Tool result was offloaded to external storage due to size.\n" + "Use the preview below to answer if possible.\n" + "Use your available tools to selectively access the data you need." + ) + if self._include_retrieval_tool: + guidance += "\nYou can also use retrieve_offloaded_content with a reference to get the full content." + + preview_text = ( + f"[Offloaded: {len(content)} blocks, ~{token_count:,} tokens]\n" + f"{guidance}\n\n" + f"{preview}\n\n" + f"[Stored references:]\n{ref_lines}" + ) + + # Build new content with preview + placeholders for non-text blocks + new_content: list[ToolResultContent] = [ToolResultContent(text=preview_text)] + for i, block in enumerate(content): + ref = references[i][0] if i < len(references) else "" + if "text" in block or "json" in block: + continue + elif "image" in block: + image = block["image"] + img_format = image.get("format", "unknown") + img_bytes = image.get("source", {}).get("bytes", b"") + placeholder = f"[image: {img_format}, {len(img_bytes) if img_bytes else 0} bytes" + if ref: + placeholder += f" | ref: {ref}" + placeholder += "]" + new_content.append(ToolResultContent(text=placeholder)) + elif "document" in block: + doc = block["document"] + doc_format = doc.get("format", "unknown") + doc_name = doc.get("name", "unknown") + doc_bytes = doc.get("source", {}).get("bytes", b"") + placeholder = f"[document: {doc_format}, {doc_name}, {len(doc_bytes) if doc_bytes else 0} bytes" + if ref: + placeholder += f" | ref: {ref}" + placeholder += "]" + new_content.append(ToolResultContent(text=placeholder)) + else: + new_content.append(block) + + event.result = ToolResult( + 
toolUseId=result["toolUseId"], + status=result["status"], + content=new_content, + ) + + def _slice_preview(self, text: str) -> str: + """Slice text to approximately preview_tokens. + + Uses tiktoken for exact token-level slicing when available, + falls back to characters (tokens * 4) otherwise. + + Args: + text: The full text to slice. + + Returns: + The preview text. + """ + encoding = _get_encoding() + if encoding is not None: + tokens = encoding.encode(text) + preview: str = encoding.decode(tokens[: self._preview_tokens]) + return preview + return text[: self._preview_tokens * _CHARS_PER_TOKEN] diff --git a/src/strands/vended_plugins/context_offloader/storage.py b/src/strands/vended_plugins/context_offloader/storage.py new file mode 100644 index 000000000..a12055a2e --- /dev/null +++ b/src/strands/vended_plugins/context_offloader/storage.py @@ -0,0 +1,373 @@ +"""Storage backends for offloaded tool result content. + +This module defines the Storage protocol and provides three built-in +implementations: file-based, in-memory, and S3 storage. Each content block +from a tool result is stored individually with its content type preserved. 
+ +Example: + ```python + from strands.vended_plugins.context_offloader import ( + FileStorage, + InMemoryStorage, + S3Storage, + ) + + # File-based storage + storage = FileStorage(artifact_dir="./artifacts") + ref = storage.store("tool_123_0", b"large output content...", "text/plain") + content, content_type = storage.retrieve(ref) + + # In-memory storage (useful for testing and serverless) + storage = InMemoryStorage() + + # S3 storage + storage = S3Storage(bucket="my-bucket", prefix="artifacts/") + ``` +""" + +import json +import re +import threading +import time +from pathlib import Path +from typing import Any, Protocol, runtime_checkable + +import boto3 +from botocore.config import Config as BotocoreConfig +from botocore.exceptions import ClientError + + +def _sanitize_id(raw_id: str) -> str: + """Sanitize an ID for safe use in filenames and object keys. + + Replaces path separators, parent directory references, and other + unsafe characters with underscores. + + Args: + raw_id: The raw ID string. + + Returns: + A sanitized string safe for use in filenames. + """ + sanitized = raw_id.replace("..", "_").replace("/", "_").replace("\\", "_") + sanitized = re.sub(r"[^\w\-.]", "_", sanitized) + return sanitized + + +@runtime_checkable +class Storage(Protocol): + """Backend for storing and retrieving offloaded content blocks. + + Each content block from a tool result is stored individually with its + content type preserved. The SDK ships three built-in implementations: + ``InMemoryStorage``, ``FileStorage``, and ``S3Storage``. Implement this + protocol to create custom storage backends (e.g., Redis, DynamoDB). + + Lifecycle: + This protocol intentionally does not include eviction or deletion methods. + Stored content accumulates for the lifetime of the storage instance. For + long-running agents, create a new storage instance per session or use a + backend with built-in lifecycle management (e.g., S3 lifecycle policies). 
+ """ + + def store(self, key: str, content: bytes, content_type: str = "text/plain") -> str: + """Store content and return a reference identifier. + + Args: + key: A unique key for this content block. + content: The raw content bytes to store. + content_type: MIME type of the content (e.g., "text/plain", + "application/json", "image/png", "application/pdf"). + + Returns: + A reference string that can be used to retrieve the content later. + """ + ... + + def retrieve(self, reference: str) -> tuple[bytes, str]: + """Retrieve stored content by reference. + + Args: + reference: The reference returned by a previous store() call. + + Returns: + A tuple of (content bytes, content type). + + Raises: + KeyError: If the reference is not found. + """ + ... + + +class FileStorage: + """Store offloaded content as files on disk. + + Files are written to the configured artifact directory with unique names. + File extensions are derived from the content type. A ``.metadata.json`` + sidecar file tracks content types so they survive process restarts. + + Args: + artifact_dir: Directory path where artifact files will be stored. + """ + + _METADATA_FILE = ".metadata.json" + + def __init__(self, artifact_dir: str = "./artifacts") -> None: + """Initialize file-based storage. + + Args: + artifact_dir: Directory path where artifact files will be stored. + """ + self._artifact_dir = Path(artifact_dir) + self._counter: int = 0 + self._lock = threading.Lock() + self._content_types: dict[str, str] = self._load_metadata() + + @staticmethod + def _extension_for(content_type: str) -> str: + """Return a file extension for the given content type.""" + if content_type == "text/plain": + return ".txt" + return f".{content_type.split('/')[-1]}" + + def store(self, key: str, content: bytes, content_type: str = "text/plain") -> str: + """Store content as a file and return the filename as reference. + + Args: + key: A unique key for this content block. + content: The raw content bytes to store. 
+ content_type: MIME type of the content. + + Returns: + The filename (not full path) used as the reference. + """ + self._artifact_dir.mkdir(parents=True, exist_ok=True) + + sanitized_key = _sanitize_id(key) + timestamp_ms = int(time.time() * 1000) + ext = self._extension_for(content_type) + with self._lock: + self._counter += 1 + counter = self._counter + filename = f"{timestamp_ms}_{counter}_{sanitized_key}{ext}" + self._content_types[filename] = content_type + self._save_metadata() + + file_path = self._artifact_dir / filename + file_path.write_bytes(content) + + return filename + + def retrieve(self, reference: str) -> tuple[bytes, str]: + """Retrieve content from a stored file. + + Args: + reference: The filename reference returned by store(). + + Returns: + A tuple of (content bytes, content type). + + Raises: + KeyError: If the file does not exist. + """ + file_path = (self._artifact_dir / reference).resolve() + if not file_path.is_relative_to(self._artifact_dir.resolve()): + raise KeyError(f"Reference not found: {reference}") + if not file_path.is_file(): + raise KeyError(f"Reference not found: {reference}") + content_type = self._content_types.get(reference, "application/octet-stream") + return file_path.read_bytes(), content_type + + def _load_metadata(self) -> dict[str, str]: + """Load content type metadata from the sidecar file.""" + metadata_path = self._artifact_dir / self._METADATA_FILE + if metadata_path.is_file(): + try: + result: dict[str, str] = json.loads(metadata_path.read_text(encoding="utf-8")) + return result + except (json.JSONDecodeError, OSError): + return {} + return {} + + def _save_metadata(self) -> None: + """Save content type metadata to the sidecar file.""" + metadata_path = self._artifact_dir / self._METADATA_FILE + metadata_path.write_text(json.dumps(self._content_types), encoding="utf-8") + + +class InMemoryStorage: + """Store offloaded content in memory. 
+ + Useful for testing and serverless environments where disk access + is not available or not desired. Thread-safe. + + Note: + Content accumulates for the lifetime of this instance. For long-running + agents, consider creating a new instance per session or switching to + ``FileStorage`` or ``S3Storage`` for persistent storage with external + lifecycle management. + """ + + def __init__(self) -> None: + """Initialize in-memory storage.""" + self._store: dict[str, tuple[bytes, str]] = {} + self._counter: int = 0 + self._lock = threading.Lock() + + def store(self, key: str, content: bytes, content_type: str = "text/plain") -> str: + """Store content in memory and return a reference. + + Args: + key: A unique key for this content block. + content: The raw content bytes to store. + content_type: MIME type of the content. + + Returns: + A unique reference string. + """ + with self._lock: + self._counter += 1 + reference = f"mem_{self._counter}_{key}" + self._store[reference] = (content, content_type) + return reference + + def retrieve(self, reference: str) -> tuple[bytes, str]: + """Retrieve content from memory. + + Args: + reference: The reference returned by store(). + + Returns: + A tuple of (content bytes, content type). + + Raises: + KeyError: If the reference is not found. + """ + with self._lock: + if reference not in self._store: + raise KeyError(f"Reference not found: {reference}") + return self._store[reference] + + def clear(self) -> None: + """Remove all stored content. + + Call this to free memory when offloaded results are no longer needed, + e.g., between sessions or after an invocation completes. + """ + with self._lock: + self._store.clear() + + +class S3Storage: + """Store offloaded content in Amazon S3. + + Objects are stored with unique keys under the configured prefix. + Content type is preserved as S3 object metadata. + + Args: + bucket: S3 bucket name. + prefix: S3 key prefix for organizing stored artifacts. 
+ boto_session: Optional boto3 session. If not provided, a new session + is created using the given region_name. + boto_client_config: Optional botocore client configuration. + region_name: AWS region. Used only when boto_session is not provided. + + Example: + ```python + from strands.vended_plugins.context_offloader import S3Storage + + storage = S3Storage( + bucket="my-agent-artifacts", + prefix="tool-results/", + ) + ``` + """ + + def __init__( + self, + bucket: str, + prefix: str = "", + boto_session: boto3.Session | None = None, + boto_client_config: BotocoreConfig | None = None, + region_name: str | None = None, + ) -> None: + """Initialize S3-based storage. + + Args: + bucket: S3 bucket name. + prefix: S3 key prefix for organizing stored artifacts. + boto_session: Optional boto3 session. If not provided, a new session + is created using the given region_name. + boto_client_config: Optional botocore client configuration. + region_name: AWS region. Used only when boto_session is not provided. + """ + self._bucket = bucket + self._prefix = prefix.strip("/") + if self._prefix: + self._prefix += "/" + + session = boto_session or boto3.Session(region_name=region_name) + + if boto_client_config: + existing_user_agent = getattr(boto_client_config, "user_agent_extra", None) + new_user_agent = f"{existing_user_agent} strands-agents" if existing_user_agent else "strands-agents" + client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent)) + else: + client_config = BotocoreConfig(user_agent_extra="strands-agents") + + self._client: Any = session.client(service_name="s3", config=client_config) + self._counter: int = 0 + self._lock = threading.Lock() + + def store(self, key: str, content: bytes, content_type: str = "text/plain") -> str: + """Store content as an S3 object and return the object key as reference. + + Args: + key: A unique key for this content block. + content: The raw content bytes to store. 
+ content_type: MIME type of the content. + + Returns: + The S3 object key used as the reference. + + Raises: + botocore.exceptions.ClientError: If the S3 operation fails (e.g., bucket + does not exist, permission denied). + """ + sanitized_key = _sanitize_id(key) + timestamp_ms = int(time.time() * 1000) + with self._lock: + self._counter += 1 + counter = self._counter + s3_key = f"{self._prefix}{timestamp_ms}_{counter}_{sanitized_key}" + + self._client.put_object( + Bucket=self._bucket, + Key=s3_key, + Body=content, + ContentType=content_type, + ) + + return s3_key + + def retrieve(self, reference: str) -> tuple[bytes, str]: + """Retrieve content from an S3 object. + + Args: + reference: The S3 object key returned by store(). + + Returns: + A tuple of (content bytes, content type). + + Raises: + KeyError: If the object does not exist. + """ + try: + response = self._client.get_object(Bucket=self._bucket, Key=reference) + content: bytes = response["Body"].read() + content_type: str = response.get("ContentType", "application/octet-stream") + return content, content_type + except ClientError as e: + if e.response["Error"]["Code"] == "NoSuchKey": + raise KeyError(f"Reference not found: {reference}") from e + raise diff --git a/tests/strands/models/test_model.py b/tests/strands/models/test_model.py index 11d4c10b9..2c685b43b 100644 --- a/tests/strands/models/test_model.py +++ b/tests/strands/models/test_model.py @@ -509,7 +509,7 @@ async def test_count_tokens_all_inputs(model): assert result == 31 -def test_get_encoding_falls_back_without_tiktoken(monkeypatch): +def test__get_encoding_falls_back_without_tiktoken(monkeypatch): """Test that _get_encoding returns None and count_tokens falls back to heuristic.""" import strands.models.model as model_module @@ -550,15 +550,21 @@ def test_all_content_types(self): messages = [ {"role": "user", "content": [{"text": "hello world!"}]}, - {"role": "assistant", "content": [ - {"toolUse": {"toolUseId": "1", "name": "my_tool", 
"input": {"q": "test"}}}, - {"reasoningContent": {"reasoningText": {"text": "Let me think."}}}, - {"guardContent": {"text": {"text": "Filtered."}}}, - {"citationsContent": {"content": [{"text": "Citation."}]}}, - ]}, - {"role": "user", "content": [ - {"toolResult": {"toolUseId": "1", "content": [{"text": "tool output here"}]}}, - ]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "1", "name": "my_tool", "input": {"q": "test"}}}, + {"reasoningContent": {"reasoningText": {"text": "Let me think."}}}, + {"guardContent": {"text": {"text": "Filtered."}}}, + {"citationsContent": {"content": [{"text": "Citation."}]}}, + ], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "1", "content": [{"text": "tool output here"}]}}, + ], + }, ] result = _estimate_tokens_with_heuristic( messages=messages, @@ -574,9 +580,12 @@ def test_non_serializable_inputs(self): result = _estimate_tokens_with_heuristic( messages=[ - {"role": "assistant", "content": [ - {"toolUse": {"toolUseId": "1", "name": "my_tool", "input": {"data": b"bytes"}}}, - ]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "1", "name": "my_tool", "input": {"data": b"bytes"}}}, + ], + }, ], tool_specs=[{"name": "t", "inputSchema": {"json": {"default": b"bytes"}}}], ) @@ -598,9 +607,7 @@ def _block_tiktoken(name, *args, **kwargs): monkeypatch.setattr("builtins.__import__", _block_tiktoken) try: - result = await model.count_tokens( - messages=[{"role": "user", "content": [{"text": "hello world!"}]}] - ) + result = await model.count_tokens(messages=[{"role": "user", "content": [{"text": "hello world!"}]}]) assert result == 3 # ceil(12 / 4) finally: model_module._get_encoding.cache_clear() diff --git a/tests/strands/vended_plugins/context_offloader/__init__.py b/tests/strands/vended_plugins/context_offloader/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/vended_plugins/context_offloader/test_plugin.py 
b/tests/strands/vended_plugins/context_offloader/test_plugin.py new file mode 100644 index 000000000..528d1f006 --- /dev/null +++ b/tests/strands/vended_plugins/context_offloader/test_plugin.py @@ -0,0 +1,539 @@ +"""Tests for the ContextOffloader plugin.""" + +import json +import logging +import math +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from strands.hooks.events import AfterToolCallEvent +from strands.types.tools import ToolContext, ToolUse +from strands.vended_plugins.context_offloader import ( + ContextOffloader, + InMemoryStorage, +) + + +@pytest.fixture +def storage(): + return InMemoryStorage() + + +@pytest.fixture +def plugin(storage): + return ContextOffloader( + storage=storage, + max_result_tokens=25, + preview_tokens=10, + ) + + +@pytest.fixture +def mock_agent(): + agent = MagicMock() + agent.model = MagicMock() + agent.model.count_tokens = AsyncMock(side_effect=_heuristic_count_tokens) + return agent + + +async def _heuristic_count_tokens(messages, **kwargs): + """Heuristic token counter for tests: chars / 4.""" + total = 0 + for msg in messages: + for block in msg.get("content", []): + if "toolResult" in block: + for content in block["toolResult"].get("content", []): + if "text" in content: + total += math.ceil(len(content["text"]) / 4) + elif "json" in content: + total += math.ceil(len(json.dumps(content["json"])) / 4) + elif "text" in block: + total += math.ceil(len(block["text"]) / 4) + return total + + +def _make_event(agent, text_content, status="success", tool_use_id="tool_123", cancel_message=None): + """Helper to create an AfterToolCallEvent with content.""" + if isinstance(text_content, str): + content = [{"text": text_content}] + else: + content = text_content + + result = { + "toolUseId": tool_use_id, + "status": status, + "content": content, + } + tool_use = {"toolUseId": tool_use_id, "name": "test_tool", "input": {}} + + return AfterToolCallEvent( + agent=agent, + selected_tool=None, + tool_use=tool_use, + 
invocation_state={}, + result=result, + cancel_message=cancel_message, + ) + + +class TestContextOffloader: + def test_plugin_name(self, plugin): + assert plugin.name == "context_offloader" + + def test_hooks_auto_discovered(self, plugin): + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "_handle_tool_result" + + def test_raises_on_non_positive_max_result_tokens(self): + with pytest.raises(ValueError, match="max_result_tokens must be positive"): + ContextOffloader(storage=InMemoryStorage(), max_result_tokens=0) + with pytest.raises(ValueError, match="max_result_tokens must be positive"): + ContextOffloader(storage=InMemoryStorage(), max_result_tokens=-1) + + def test_raises_on_negative_preview_tokens(self): + with pytest.raises(ValueError, match="preview_tokens must be non-negative"): + ContextOffloader(storage=InMemoryStorage(), preview_tokens=-1) + + def test_raises_on_preview_tokens_gte_max_result_tokens(self): + with pytest.raises(ValueError, match="preview_tokens must be less than max_result_tokens"): + ContextOffloader(storage=InMemoryStorage(), max_result_tokens=100, preview_tokens=100) + with pytest.raises(ValueError, match="preview_tokens must be less than max_result_tokens"): + ContextOffloader(storage=InMemoryStorage(), max_result_tokens=100, preview_tokens=200) + + @pytest.mark.asyncio + async def test_offloads_oversized_text(self, plugin, storage, mock_agent): + large_text = "a" * 200 + event = _make_event(mock_agent, large_text) + + await plugin._handle_tool_result(event) + + result_text = event.result["content"][0]["text"] + assert "[Offloaded:" in result_text + # Preview should be shorter than the full text + assert len(result_text) < len(large_text) + 500 # preview + metadata < original + overhead + + # Verify stored content + assert len(storage._store) == 1 + ref = list(storage._store.keys())[0] + content, content_type = storage.retrieve(ref) + assert content == large_text.encode("utf-8") + assert content_type == "text/plain" + 
+ @pytest.mark.asyncio + async def test_preserves_status_and_tool_use_id(self, plugin, mock_agent): + event = _make_event(mock_agent, "x" * 200, status="error", tool_use_id="my_tool_456") + + await plugin._handle_tool_result(event) + + assert event.result["status"] == "error" + assert event.result["toolUseId"] == "my_tool_456" + + @pytest.mark.asyncio + async def test_under_threshold_passes_through(self, plugin, mock_agent): + small_text = "x" * 50 # 12.5 tokens, under 25 + event = _make_event(mock_agent, small_text) + original_content = event.result["content"] + + await plugin._handle_tool_result(event) + + assert event.result["content"] is original_content + + @pytest.mark.asyncio + async def test_at_threshold_passes_through(self, plugin, mock_agent): + exact_text = "x" * 100 # exactly 25 tokens + event = _make_event(mock_agent, exact_text) + original_content = event.result["content"] + + await plugin._handle_tool_result(event) + + assert event.result["content"] is original_content + + @pytest.mark.asyncio + async def test_skips_cancelled_tool_calls(self, plugin, mock_agent): + large_text = "x" * 200 + event = _make_event(mock_agent, large_text, cancel_message="tool cancelled by user") + original_content = event.result["content"] + + await plugin._handle_tool_result(event) + + assert event.result["content"] is original_content + + @pytest.mark.asyncio + async def test_skips_retrieve_tool_results_when_enabled(self, storage, mock_agent): + plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10, include_retrieval_tool=True) + large_text = "x" * 200 + result = {"toolUseId": "tool_123", "status": "success", "content": [{"text": large_text}]} + tool_use = {"toolUseId": "tool_123", "name": plugin.retrieve_offloaded_content.tool_name, "input": {}} + event = AfterToolCallEvent( + agent=mock_agent, + selected_tool=None, + tool_use=tool_use, + invocation_state={}, + result=result, + ) + await plugin._handle_tool_result(event) + + assert 
event.result["content"][0]["text"] == large_text + + @pytest.mark.asyncio + async def test_does_not_skip_retrieve_tool_when_disabled(self, plugin, storage, mock_agent): + large_text = "x" * 200 + result = {"toolUseId": "tool_123", "status": "success", "content": [{"text": large_text}]} + tool_use = {"toolUseId": "tool_123", "name": "retrieve_offloaded_content", "input": {}} + event = AfterToolCallEvent( + agent=mock_agent, + selected_tool=None, + tool_use=tool_use, + invocation_state={}, + result=result, + ) + await plugin._handle_tool_result(event) + + # Tool is disabled, so the result should be offloaded normally + assert "[Offloaded:" in event.result["content"][0]["text"] + + @pytest.mark.asyncio + async def test_image_only_content_passes_through(self, plugin, mock_agent): + content = [{"image": {"format": "png", "source": {"bytes": b"fake"}}}] + event = _make_event(mock_agent, content) + original_content = event.result["content"] + + await plugin._handle_tool_result(event) + + assert event.result["content"] is original_content + + @pytest.mark.asyncio + async def test_image_stored_and_placeholder_has_ref(self, plugin, storage, mock_agent): + img_bytes = b"\x89PNG" + b"\x00" * 100 + content = [ + {"text": "x" * 200}, + {"image": {"format": "png", "source": {"bytes": img_bytes}}}, + ] + event = _make_event(mock_agent, content) + + await plugin._handle_tool_result(event) + + # Should have preview + image placeholder + assert len(event.result["content"]) == 2 + placeholder = event.result["content"][1]["text"] + assert "[image: png, 104 bytes" in placeholder + assert "ref:" in placeholder + + # Verify image was stored + assert len(storage._store) == 2 # text + image + img_ref = placeholder.split("ref: ")[1].rstrip("]") + img_content, img_type = storage.retrieve(img_ref) + assert img_content == img_bytes + assert img_type == "image/png" + + @pytest.mark.asyncio + async def test_document_stored_and_placeholder_has_ref(self, plugin, storage, mock_agent): + doc_bytes = 
b"%PDF-1.4" + b"\x00" * 100 + content = [ + {"text": "x" * 200}, + {"document": {"format": "pdf", "name": "report.pdf", "source": {"bytes": doc_bytes}}}, + ] + event = _make_event(mock_agent, content) + + await plugin._handle_tool_result(event) + + assert len(event.result["content"]) == 2 + placeholder = event.result["content"][1]["text"] + assert "[document: pdf, report.pdf, 108 bytes" in placeholder + assert "ref:" in placeholder + + # Verify document was stored + doc_ref = placeholder.split("ref: ")[1].rstrip("]") + doc_content, doc_type = storage.retrieve(doc_ref) + assert doc_content == doc_bytes + assert doc_type == "application/pdf" + + @pytest.mark.asyncio + async def test_multiple_text_blocks_stored_separately(self, plugin, storage, mock_agent): + content = [ + {"text": "a" * 60}, + {"text": "b" * 60}, + ] + event = _make_event(mock_agent, content) + + await plugin._handle_tool_result(event) + + # Two text blocks stored separately + assert len(storage._store) == 2 + refs = list(storage._store.keys()) + assert storage.retrieve(refs[0]) == (b"a" * 60, "text/plain") + assert storage.retrieve(refs[1]) == (b"b" * 60, "text/plain") + + @pytest.mark.asyncio + async def test_json_content_stored_as_json(self, plugin, storage, mock_agent): + large_json = {"data": [{"id": i, "value": "x" * 20} for i in range(10)]} + content = [{"json": large_json}] + event = _make_event(mock_agent, content) + + await plugin._handle_tool_result(event) + + assert len(storage._store) == 1 + ref = list(storage._store.keys())[0] + stored_content, content_type = storage.retrieve(ref) + assert content_type == "application/json" + assert json.loads(stored_content) == large_json + + @pytest.mark.asyncio + async def test_mixed_text_and_json(self, plugin, storage, mock_agent): + content = [ + {"text": "a" * 60}, + {"json": {"key": "b" * 60}}, + ] + event = _make_event(mock_agent, content) + + await plugin._handle_tool_result(event) + + # Both stored separately with correct types + assert 
len(storage._store) == 2 + refs = list(storage._store.keys()) + assert storage.retrieve(refs[0])[1] == "text/plain" + assert storage.retrieve(refs[1])[1] == "application/json" + + @pytest.mark.asyncio + async def test_small_json_passes_through(self, plugin, mock_agent): + content = [{"json": {"key": "value"}}] + event = _make_event(mock_agent, content) + original_content = event.result["content"] + + await plugin._handle_tool_result(event) + + assert event.result["content"] is original_content + + @pytest.mark.asyncio + async def test_error_status_still_offloaded(self, plugin, mock_agent): + large_text = "x" * 200 + event = _make_event(mock_agent, large_text, status="error") + + await plugin._handle_tool_result(event) + + assert "[Offloaded:" in event.result["content"][0]["text"] + assert event.result["status"] == "error" + + @pytest.mark.asyncio + async def test_storage_failure_keeps_original(self, mock_agent, caplog): + failing_storage = MagicMock() + failing_storage.store.side_effect = RuntimeError("disk full") + + plugin = ContextOffloader( + storage=failing_storage, + max_result_tokens=25, + preview_tokens=10, + ) + + large_text = "x" * 200 + event = _make_event(mock_agent, large_text) + + with caplog.at_level(logging.WARNING): + await plugin._handle_tool_result(event) + + assert event.result["content"][0]["text"] == large_text + assert "failed to offload" in caplog.text + + @pytest.mark.asyncio + async def test_partial_storage_failure_keeps_original(self, mock_agent, caplog): + storage = MagicMock() + call_count = 0 + + def store_then_fail(key, content, content_type="text/plain"): + nonlocal call_count + call_count += 1 + if call_count > 1: + raise RuntimeError("disk full on second block") + return f"ref_{call_count}" + + storage.store.side_effect = store_then_fail + + plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10) + + content = [ + {"text": "a" * 60}, + {"text": "b" * 60}, + ] + event = _make_event(mock_agent, content) + 
+ with caplog.at_level(logging.WARNING): + await plugin._handle_tool_result(event) + + assert event.result["content"][0]["text"] == "a" * 60 + assert event.result["content"][1]["text"] == "b" * 60 + assert "failed to offload" in caplog.text + + @pytest.mark.asyncio + async def test_empty_text_blocks_not_stored(self, plugin, storage, mock_agent): + content = [ + {"text": ""}, + {"text": "x" * 200}, + ] + event = _make_event(mock_agent, content) + + await plugin._handle_tool_result(event) + + # Empty text block is not in text_preview_parts but still iterated for storage + # The non-empty block triggers offloading + assert "[Offloaded:" in event.result["content"][0]["text"] + + @pytest.mark.asyncio + async def test_document_only_content_passes_through(self, plugin, mock_agent): + content = [{"document": {"format": "pdf", "name": "report.pdf", "source": {"bytes": b"pdf"}}}] + event = _make_event(mock_agent, content) + original_content = event.result["content"] + + await plugin._handle_tool_result(event) + + assert event.result["content"] is original_content + + @pytest.mark.asyncio + async def test_unknown_content_type_passed_through(self, plugin, mock_agent): + unknown_block = {"custom_type": {"data": "something"}} + content = [ + {"text": "x" * 200}, + unknown_block, + ] + event = _make_event(mock_agent, content) + + await plugin._handle_tool_result(event) + + # Unknown block should be passed through + assert event.result["content"][-1] is unknown_block + + @pytest.mark.asyncio + async def test_all_content_types_mixed(self, plugin, storage, mock_agent): + large_json = {"rows": [{"id": i} for i in range(20)]} + img_bytes = b"\x89PNG" + b"\x00" * 100 + doc_bytes = b"%PDF" + b"\x00" * 200 + content = [ + {"text": "a" * 60}, + {"json": large_json}, + {"image": {"format": "png", "source": {"bytes": img_bytes}}}, + {"document": {"format": "pdf", "name": "report.pdf", "source": {"bytes": doc_bytes}}}, + ] + event = _make_event(mock_agent, content) + + await 
plugin._handle_tool_result(event) + + result_content = event.result["content"] + # Preview + image placeholder + document placeholder = 3 blocks + assert len(result_content) == 3 + assert "[Offloaded:" in result_content[0]["text"] + assert "[image: png" in result_content[1]["text"] + assert "[document: pdf, report.pdf" in result_content[2]["text"] + + # All 4 blocks stored + assert len(storage._store) == 4 + + @pytest.mark.asyncio + async def test_image_without_bytes_not_stored(self, plugin, storage, mock_agent): + content = [ + {"text": "x" * 200}, + {"image": {"format": "png", "source": {}}}, + ] + event = _make_event(mock_agent, content) + + await plugin._handle_tool_result(event) + + # Only text stored, not the empty image + assert len(storage._store) == 1 + placeholder = event.result["content"][1]["text"] + assert "0 bytes" in placeholder + assert "ref:" not in placeholder + + +class TestRetrievalTool: + @pytest.fixture + def storage(self): + return InMemoryStorage() + + @pytest.fixture + def plugin(self, storage): + return ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10, include_retrieval_tool=True) + + @pytest.fixture + def mock_agent(self): + return MagicMock() + + @pytest.fixture + def tool_context(self, mock_agent): + tool_use = ToolUse(toolUseId="retrieve_1", name="retrieve_offloaded_content", input={}) + return ToolContext(tool_use=tool_use, agent=mock_agent, invocation_state={}) + + def test_retrieval_tool_registered_when_enabled(self, plugin): + tool_names = [t.tool_name for t in plugin.tools] + assert "retrieve_offloaded_content" in tool_names + + def test_retrieval_tool_not_registered_by_default(self): + plugin = ContextOffloader(storage=InMemoryStorage()) + plugin.init_agent(MagicMock()) + tool_names = [t.tool_name for t in plugin.tools] + assert "retrieve_offloaded_content" not in tool_names + + def test_retrieve_text_content(self, plugin, storage, tool_context): + ref = storage.store("key_1", b"hello world", 
"text/plain") + result = plugin.retrieve_offloaded_content(reference=ref, tool_context=tool_context) + assert result == "hello world" + + def test_retrieve_json_content(self, plugin, storage, tool_context): + ref = storage.store("key_1", b'{"key": "value"}', "application/json") + result = plugin.retrieve_offloaded_content(reference=ref, tool_context=tool_context) + assert result["content"][0]["json"] == {"key": "value"} + + def test_retrieve_large_text_returns_full_content(self, plugin, storage, tool_context): + large_text = "a" * 50_000 + ref = storage.store("key_1", large_text.encode("utf-8"), "text/plain") + result = plugin.retrieve_offloaded_content(reference=ref, tool_context=tool_context) + assert result == large_text + + def test_retrieve_missing_reference(self, plugin, tool_context): + result = plugin.retrieve_offloaded_content(reference="nonexistent", tool_context=tool_context) + assert "Error: reference not found" in result + + def test_retrieve_image_content(self, plugin, storage, tool_context): + img_bytes = b"\x89PNG\x00\x00" + ref = storage.store("key_1", img_bytes, "image/png") + result = plugin.retrieve_offloaded_content(reference=ref, tool_context=tool_context) + assert result["status"] == "success" + assert result["content"][0]["image"]["format"] == "png" + assert result["content"][0]["image"]["source"]["bytes"] == img_bytes + + def test_retrieve_document_content(self, plugin, storage, tool_context): + doc_bytes = b"%PDF-1.4 content" + ref = storage.store("key_1", doc_bytes, "application/pdf") + result = plugin.retrieve_offloaded_content(reference=ref, tool_context=tool_context) + assert result["status"] == "success" + assert result["content"][0]["document"]["format"] == "pdf" + assert result["content"][0]["document"]["source"]["bytes"] == doc_bytes + + +class TestInlineGuidance: + @pytest.fixture + def storage(self): + return InMemoryStorage() + + @pytest.fixture + def mock_agent(self): + agent = MagicMock() + agent.model = MagicMock() + 
agent.model.count_tokens = AsyncMock(side_effect=_heuristic_count_tokens) + return agent + + @pytest.mark.asyncio + async def test_guidance_mentions_retrieval_tool_when_enabled(self, storage, mock_agent): + plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10, include_retrieval_tool=True) + event = _make_event(mock_agent, "x" * 200) + await plugin._handle_tool_result(event) + result_text = event.result["content"][0]["text"] + assert "retrieve_offloaded_content" in result_text + + @pytest.mark.asyncio + async def test_guidance_does_not_mention_retrieval_tool_when_disabled(self, storage, mock_agent): + plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10) + event = _make_event(mock_agent, "x" * 200) + await plugin._handle_tool_result(event) + result_text = event.result["content"][0]["text"] + assert "retrieve_offloaded_content" not in result_text + assert "available tools" in result_text diff --git a/tests/strands/vended_plugins/context_offloader/test_storage.py b/tests/strands/vended_plugins/context_offloader/test_storage.py new file mode 100644 index 000000000..6b9b9e962 --- /dev/null +++ b/tests/strands/vended_plugins/context_offloader/test_storage.py @@ -0,0 +1,266 @@ +"""Tests for offload storage backends.""" + +import threading +from unittest.mock import MagicMock, patch + +import pytest +from botocore.exceptions import ClientError + +from strands.vended_plugins.context_offloader import ( + FileStorage, + InMemoryStorage, + S3Storage, +) + + +class TestInMemoryStorage: + def test_round_trip(self): + storage = InMemoryStorage() + ref = storage.store("key_1", b"hello world") + content, content_type = storage.retrieve(ref) + assert content == b"hello world" + assert content_type == "text/plain" + + def test_preserves_content_type(self): + storage = InMemoryStorage() + ref = storage.store("key_1", b'{"a": 1}', "application/json") + content, content_type = storage.retrieve(ref) + assert content == b'{"a": 
1}' + assert content_type == "application/json" + + def test_stores_binary_content(self): + storage = InMemoryStorage() + img_bytes = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100 + ref = storage.store("key_1", img_bytes, "image/png") + content, content_type = storage.retrieve(ref) + assert content == img_bytes + assert content_type == "image/png" + + def test_retrieve_missing_raises_key_error(self): + storage = InMemoryStorage() + with pytest.raises(KeyError, match="Reference not found"): + storage.retrieve("nonexistent_ref") + + def test_unique_references(self): + storage = InMemoryStorage() + ref1 = storage.store("key_1", b"content a") + ref2 = storage.store("key_1", b"content b") + assert ref1 != ref2 + assert storage.retrieve(ref1)[0] == b"content a" + assert storage.retrieve(ref2)[0] == b"content b" + + def test_reference_format(self): + storage = InMemoryStorage() + ref = storage.store("tool_abc", b"content") + assert ref.startswith("mem_") + assert "tool_abc" in ref + + def test_thread_safety(self): + storage = InMemoryStorage() + refs: list[str] = [] + errors: list[Exception] = [] + + def store_item(i: int): + try: + ref = storage.store(f"key_{i}", f"content_{i}".encode()) + refs.append(ref) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=store_item, args=(i,)) for i in range(50)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors + assert len(set(refs)) == 50 + + def test_stores_empty_content(self): + storage = InMemoryStorage() + ref = storage.store("key_1", b"") + assert storage.retrieve(ref) == (b"", "text/plain") + + def test_clear(self): + storage = InMemoryStorage() + ref = storage.store("key_1", b"content") + storage.clear() + with pytest.raises(KeyError): + storage.retrieve(ref) + + def test_clear_empty_storage(self): + storage = InMemoryStorage() + storage.clear() + + +class TestFileStorage: + def test_round_trip(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path / 
"artifacts")) + ref = storage.store("key_1", b"hello world") + content, content_type = storage.retrieve(ref) + assert content == b"hello world" + assert content_type == "text/plain" + + def test_preserves_content_type(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path)) + ref = storage.store("key_1", b'{"a": 1}', "application/json") + content, content_type = storage.retrieve(ref) + assert content == b'{"a": 1}' + assert content_type == "application/json" + + def test_stores_binary_content(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path)) + img_bytes = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100 + ref = storage.store("key_1", img_bytes, "image/png") + content, content_type = storage.retrieve(ref) + assert content == img_bytes + assert content_type == "image/png" + + def test_extension_from_content_type(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path)) + assert storage.store("k", b"text", "text/plain").endswith(".txt") + assert storage.store("k", b"{}", "application/json").endswith(".json") + assert storage.store("k", b"img", "image/png").endswith(".png") + assert storage.store("k", b"pdf", "application/pdf").endswith(".pdf") + + def test_auto_creates_directory(self, tmp_path): + artifact_dir = tmp_path / "nested" / "dir" / "artifacts" + assert not artifact_dir.exists() + storage = FileStorage(artifact_dir=str(artifact_dir)) + storage.store("key_1", b"content") + assert artifact_dir.exists() + + def test_retrieve_missing_raises_key_error(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path)) + with pytest.raises(KeyError, match="Reference not found"): + storage.retrieve("nonexistent.txt") + + def test_unique_references(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path)) + ref1 = storage.store("key_1", b"content a") + ref2 = storage.store("key_1", b"content b") + assert ref1 != ref2 + assert storage.retrieve(ref1)[0] == b"content a" + assert storage.retrieve(ref2)[0] == b"content b" + + def 
test_sanitizes_path_traversal(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path)) + ref = storage.store("../../etc/passwd", b"content") + assert ".." not in ref + assert "/" not in ref + + def test_metadata_survives_across_instances(self, tmp_path): + artifact_dir = str(tmp_path / "artifacts") + storage1 = FileStorage(artifact_dir=artifact_dir) + ref = storage1.store("key_1", b"hello", "image/png") + + storage2 = FileStorage(artifact_dir=artifact_dir) + content, content_type = storage2.retrieve(ref) + assert content == b"hello" + assert content_type == "image/png" + + def test_corrupt_metadata_fallback(self, tmp_path): + (tmp_path / ".metadata.json").write_text("not valid json", encoding="utf-8") + storage = FileStorage(artifact_dir=str(tmp_path)) + assert storage._content_types == {} + + def test_missing_metadata_fallback(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path)) + ref = storage.store("key_1", b"content", "image/png") + + storage._content_types.clear() + _, content_type = storage.retrieve(ref) + assert content_type == "application/octet-stream" + + def test_retrieve_rejects_path_traversal(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path)) + with pytest.raises(KeyError, match="Reference not found"): + storage.retrieve("../../etc/passwd") + + +class TestS3Storage: + @pytest.fixture + def mock_s3_client(self): + """Create a mock S3 client that stores objects in memory.""" + client = MagicMock() + objects: dict[str, tuple[bytes, str]] = {} + + def put_object(Bucket, Key, Body, ContentType="application/octet-stream", **kwargs): + objects[f"{Bucket}/{Key}"] = (Body, ContentType) + + def get_object(Bucket, Key, **kwargs): + full_key = f"{Bucket}/{Key}" + if full_key not in objects: + error_response = {"Error": {"Code": "NoSuchKey", "Message": "Not found"}} + raise ClientError(error_response, "GetObject") + body_bytes, ct = objects[full_key] + body = MagicMock() + body.read.return_value = body_bytes + return 
{"Body": body, "ContentType": ct} + + client.put_object.side_effect = put_object + client.get_object.side_effect = get_object + return client + + @pytest.fixture + def storage(self, mock_s3_client): + with patch("boto3.Session") as mock_session_cls: + mock_session = MagicMock() + mock_session.client.return_value = mock_s3_client + mock_session_cls.return_value = mock_session + return S3Storage(bucket="test-bucket", prefix="artifacts") + + def test_round_trip(self, storage): + ref = storage.store("key_1", b"hello world") + content, content_type = storage.retrieve(ref) + assert content == b"hello world" + assert content_type == "text/plain" + + def test_preserves_content_type(self, storage): + ref = storage.store("key_1", b"img", "image/png") + content, content_type = storage.retrieve(ref) + assert content == b"img" + assert content_type == "image/png" + + def test_retrieve_missing_raises_key_error(self, storage): + with pytest.raises(KeyError, match="Reference not found"): + storage.retrieve("nonexistent_key") + + def test_unique_references(self, storage): + ref1 = storage.store("key_1", b"content a") + ref2 = storage.store("key_1", b"content b") + assert ref1 != ref2 + assert storage.retrieve(ref1)[0] == b"content a" + assert storage.retrieve(ref2)[0] == b"content b" + + def test_reference_includes_prefix(self, storage): + ref = storage.store("tool_abc", b"content") + assert ref.startswith("artifacts/") + + def test_empty_prefix(self, mock_s3_client): + with patch("boto3.Session") as mock_session_cls: + mock_session = MagicMock() + mock_session.client.return_value = mock_s3_client + mock_session_cls.return_value = mock_session + storage = S3Storage(bucket="test-bucket", prefix="") + + ref = storage.store("tool_abc", b"content") + assert not ref.startswith("/") + assert storage.retrieve(ref)[0] == b"content" + + def test_put_object_called_with_correct_params(self, storage, mock_s3_client): + storage.store("key_1", b"test content", "application/json") + + 
mock_s3_client.put_object.assert_called_once() + call_kwargs = mock_s3_client.put_object.call_args[1] + assert call_kwargs["Bucket"] == "test-bucket" + assert call_kwargs["Key"].startswith("artifacts/") + assert call_kwargs["Body"] == b"test content" + assert call_kwargs["ContentType"] == "application/json" + + def test_non_nosuchkey_error_propagates(self, storage, mock_s3_client): + error_response = {"Error": {"Code": "AccessDenied", "Message": "Forbidden"}} + mock_s3_client.get_object.side_effect = ClientError(error_response, "GetObject") + + with pytest.raises(ClientError, match="Forbidden"): + storage.retrieve("some_key") From a49dc330e4ad5358cb77a24bbe3b1adf995290ce Mon Sep 17 00:00:00 2001 From: opieter-aws Date: Fri, 24 Apr 2026 14:19:32 -0400 Subject: [PATCH 249/279] feat: override count_tokens with native token counting for supported providers (#2189) --- src/strands/models/anthropic.py | 50 +++++++- src/strands/models/bedrock.py | 52 ++++++++ src/strands/models/gemini.py | 61 +++++++++- src/strands/models/llamacpp.py | 56 ++++++++- src/strands/models/openai_responses.py | 51 +++++++- tests/strands/models/test_anthropic.py | 98 +++++++++++++++ tests/strands/models/test_bedrock.py | 112 ++++++++++++++++++ tests/strands/models/test_gemini.py | 99 ++++++++++++++++ tests/strands/models/test_llamacpp.py | 97 +++++++++++++++ tests/strands/models/test_openai_responses.py | 108 +++++++++++++++++ tests_integ/models/test_model_anthropic.py | 39 ++++++ tests_integ/models/test_model_bedrock.py | 35 ++++++ tests_integ/models/test_model_gemini.py | 38 ++++++ tests_integ/models/test_model_openai.py | 38 ++++++ 14 files changed, 930 insertions(+), 4 deletions(-) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 526460184..7bb38e1d4 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -16,7 +16,7 @@ from ..event_loop.streaming import process_stream from ..tools.structured_output.structured_output_utils 
import convert_pydantic_to_tool_spec -from ..types.content import ContentBlock, Messages +from ..types.content import ContentBlock, Messages, SystemContentBlock from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec @@ -371,6 +371,54 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: case _: raise RuntimeError(f"event_type=<{event['type']} | unknown type") + @override + async def count_tokens( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + ) -> int: + """Count tokens using Anthropic's native count_tokens API. + + Uses the same message format as the Messages API to get accurate token counts + directly from the Anthropic service. + + Args: + messages: List of message objects to count tokens for. + tool_specs: List of tool specifications to include in the count. + system_prompt: Plain string system prompt. Ignored if system_prompt_content is provided. + system_prompt_content: Structured system prompt content blocks. + + Returns: + Total input token count. + """ + try: + # system_prompt_content is not used; this provider only accepts system_prompt as a plain string, + # matching the behavior of stream(). The caller always provides system_prompt alongside + # system_prompt_content, so the plain string is always available. + request = self.format_request(messages, tool_specs, system_prompt) + # Keep only fields accepted by count_tokens; strip inference params (max_tokens, temperature, etc.) 
+ count_tokens_fields = {"model", "messages", "tools", "tool_choice", "system"} + request = {k: request[k] for k in request.keys() & count_tokens_fields} + + response = await self.client.messages.count_tokens(**request) + total_tokens: int = response.input_tokens + + logger.debug( + "model_id=<%s>, total_tokens=<%d> | native token count", + self.config["model_id"], + total_tokens, + ) + return total_tokens + except Exception as e: + logger.warning( + "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", + self.config["model_id"], + e, + ) + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + @override async def stream( self, diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index e781b952e..2a468e450 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -749,6 +749,58 @@ def _generate_redaction_events(self) -> list[StreamEvent]: return events + @override + async def count_tokens( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + ) -> int: + """Count tokens using Bedrock's native CountTokens API. + + Uses the same message format as the Converse API to get accurate token counts + directly from the Bedrock service. + + Args: + messages: List of message objects to count tokens for. + tool_specs: List of tool specifications to include in the count. + system_prompt: Plain string system prompt. Ignored if system_prompt_content is provided. + system_prompt_content: Structured system prompt content blocks. + + Returns: + Total input token count. 
+ """ + try: + if system_prompt and system_prompt_content is None: + system_prompt_content = [{"text": system_prompt}] + + request = self._format_request(messages, tool_specs, system_prompt_content) + converse_input: dict[str, Any] = {} + if "messages" in request: + converse_input["messages"] = request["messages"] + if "system" in request: + converse_input["system"] = request["system"] + if "toolConfig" in request: + converse_input["toolConfig"] = request["toolConfig"] + + response = await asyncio.to_thread( + self.client.count_tokens, + modelId=self.config["model_id"], + input={"converse": converse_input}, + ) + total_tokens: int = response["inputTokens"] + + logger.debug("model_id=<%s>, total_tokens=<%d> | native token count", self.config["model_id"], total_tokens) + return total_tokens + except Exception as e: + logger.warning( + "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", + self.config["model_id"], + e, + ) + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + @override async def stream( self, diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 81c8bd76f..04e98f359 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -15,7 +15,7 @@ from google import genai from typing_extensions import Required, Unpack, override -from ..types.content import ContentBlock, ContentBlockStartToolUse, Messages +from ..types.content import ContentBlock, ContentBlockStartToolUse, Messages, SystemContentBlock from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec @@ -434,6 +434,65 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: case _: # pragma: no cover raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") + @override + async def count_tokens( + self, + messages: Messages, + tool_specs: 
list[ToolSpec] | None = None, + system_prompt: str | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + ) -> int: + """Count tokens using Gemini's native count_tokens API. + + Uses the Gemini count_tokens API for message contents. The Gemini API does not support + counting system_instruction or tools, so those are estimated via the base class heuristic. + + Args: + messages: List of message objects to count tokens for. + tool_specs: List of tool specifications to include in the count. + system_prompt: Plain string system prompt. + system_prompt_content: Structured system prompt content blocks. + + Returns: + Total input token count. + """ + try: + contents = list(self._format_request_content(messages)) + + client = self._get_client().aio + response = await client.models.count_tokens( + model=self.config["model_id"], + contents=contents, + ) + if response.total_tokens is None: + raise ValueError("Gemini count_tokens returned None for total_tokens") + total_tokens: int = response.total_tokens + + # The google-genai SDK explicitly raises ValueError for system_instruction, tools, and + # generation_config in CountTokensConfig on the non-Vertex (mldev) backend. + # Use heuristic for these. 
+ extra = await super().count_tokens( + messages=[], + tool_specs=tool_specs, + system_prompt=system_prompt, + system_prompt_content=system_prompt_content, + ) + total_tokens += extra + + logger.debug( + "model_id=<%s>, total_tokens=<%d> | native token count", + self.config["model_id"], + total_tokens, + ) + return total_tokens + except Exception as e: + logger.warning( + "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", + self.config["model_id"], + e, + ) + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + async def stream( self, messages: Messages, diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index d689e65ea..2e6a83306 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -25,7 +25,7 @@ from pydantic import BaseModel from typing_extensions import Unpack, override -from ..types.content import ContentBlock, Messages +from ..types.content import ContentBlock, Messages, SystemContentBlock from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec @@ -508,6 +508,60 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: case _: raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type") + @override + async def count_tokens( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + ) -> int: + """Count tokens using llama.cpp's native /tokenize endpoint. + + Sends the formatted prompt to the llama.cpp server's tokenization endpoint + to get an accurate token count. Requires a llama.cpp server version that supports + chat-template-aware tokenization via the ``messages`` field in /tokenize requests. 
+ Older server versions that only accept ``{"content": "string"}`` are not supported + and will fall back to estimation. + + Args: + messages: List of message objects to count tokens for. + tool_specs: List of tool specifications to include in the count. + system_prompt: Plain string system prompt. Ignored if system_prompt_content is provided. + system_prompt_content: Structured system prompt content blocks. + + Returns: + Total input token count. + """ + try: + # system_prompt_content is not used; this provider only accepts system_prompt as a plain string, + # matching the behavior of stream(). The caller always provides system_prompt alongside + # system_prompt_content, so the plain string is always available. + request = self._format_request(messages, tool_specs, system_prompt) + payload = { + "messages": request["messages"], + **({"tools": request["tools"]} if request.get("tools") else {}), + } + + response = await self.client.post("/tokenize", json=payload) + response.raise_for_status() + data = response.json() + total_tokens: int = len(data.get("tokens", [])) + + logger.debug( + "model_id=<%s>, total_tokens=<%d> | native token count", + self.config.get("model_id", "default"), + total_tokens, + ) + return total_tokens + except Exception as e: + logger.warning( + "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", + self.config.get("model_id", "default"), + e, + ) + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + @override async def stream( self, diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index f845c2688..0cb5cb43c 100644 --- a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -54,7 +54,7 @@ import openai # noqa: E402 - must import after version check from ..types.citations import WebLocationDict # noqa: E402 -from ..types.content import ContentBlock, Messages, Role # noqa: E402 +from 
..types.content import ContentBlock, Messages, Role, SystemContentBlock # noqa: E402 from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException # noqa: E402 from ..types.streaming import StreamEvent # noqa: E402 from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse # noqa: E402 @@ -184,6 +184,55 @@ def get_config(self) -> OpenAIResponsesConfig: """ return cast(OpenAIResponsesModel.OpenAIResponsesConfig, self.config) + @override + async def count_tokens( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + ) -> int: + """Count tokens using the OpenAI Responses API input_tokens.count endpoint. + + Uses the same message format as the Responses API to get accurate token counts + directly from the OpenAI service. + + Args: + messages: List of message objects to count tokens for. + tool_specs: List of tool specifications to include in the count. + system_prompt: Plain string system prompt. Ignored if system_prompt_content is provided. + system_prompt_content: Structured system prompt content blocks. + + Returns: + Total input token count. + """ + try: + # system_prompt_content is not used; this provider only accepts system_prompt as a plain string, + # matching the behavior of stream(). The caller always provides system_prompt alongside + # system_prompt_content, so the plain string is always available. 
+ request = self._format_request(messages, tool_specs, system_prompt) + # Keep only fields accepted by input_tokens.count + count_tokens_fields = {"model", "input", "instructions", "tools"} + request = {k: request[k] for k in request.keys() & count_tokens_fields} + + async with openai.AsyncOpenAI(**self.client_args) as client: + response = await client.responses.input_tokens.count(**request) + total_tokens: int = response.input_tokens + + logger.debug( + "model_id=<%s>, total_tokens=<%d> | native token count", + self.config["model_id"], + total_tokens, + ) + return total_tokens + except Exception as e: + logger.warning( + "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", + self.config["model_id"], + e, + ) + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + @override async def stream( self, diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 74037fc00..eae08254e 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -1040,3 +1040,101 @@ def model_dump_with_warning(): # Verify the message_stop event was still processed correctly assert {"messageStop": {"stopReason": mock_message_stop.message.stop_reason}} in events + + +class TestCountTokens: + """Tests for AnthropicModel.count_tokens native token counting.""" + + @pytest.fixture + def model_with_client(self, anthropic_client, model_id, max_tokens): + _ = anthropic_client + return AnthropicModel(model_id=model_id, max_tokens=max_tokens) + + @pytest.fixture + def messages(self): + return [{"role": "user", "content": [{"text": "hello"}]}] + + @pytest.fixture + def tool_specs(self): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + } + ] + + @pytest.mark.asyncio + async def test_native_count_tokens_success(self, model_with_client, anthropic_client, messages): + 
mock_response = unittest.mock.MagicMock() + mock_response.input_tokens = 42 + anthropic_client.messages.count_tokens = unittest.mock.AsyncMock(return_value=mock_response) + + result = await model_with_client.count_tokens(messages=messages) + + assert result == 42 + anthropic_client.messages.count_tokens.assert_called_once() + + @pytest.mark.asyncio + async def test_native_count_tokens_with_system_prompt(self, model_with_client, anthropic_client, messages): + mock_response = unittest.mock.MagicMock() + mock_response.input_tokens = 55 + anthropic_client.messages.count_tokens = unittest.mock.AsyncMock(return_value=mock_response) + + result = await model_with_client.count_tokens(messages=messages, system_prompt="Be helpful.") + + assert result == 55 + call_kwargs = anthropic_client.messages.count_tokens.call_args[1] + assert call_kwargs["system"] == "Be helpful." + + @pytest.mark.asyncio + async def test_native_count_tokens_with_tool_specs(self, model_with_client, anthropic_client, messages, tool_specs): + mock_response = unittest.mock.MagicMock() + mock_response.input_tokens = 100 + anthropic_client.messages.count_tokens = unittest.mock.AsyncMock(return_value=mock_response) + + result = await model_with_client.count_tokens(messages=messages, tool_specs=tool_specs) + + assert result == 100 + call_kwargs = anthropic_client.messages.count_tokens.call_args[1] + assert "tools" in call_kwargs + + @pytest.mark.asyncio + async def test_max_tokens_stripped_from_request(self, model_with_client, anthropic_client, messages): + mock_response = unittest.mock.MagicMock() + mock_response.input_tokens = 10 + anthropic_client.messages.count_tokens = unittest.mock.AsyncMock(return_value=mock_response) + + await model_with_client.count_tokens(messages=messages) + + call_kwargs = anthropic_client.messages.count_tokens.call_args[1] + assert "max_tokens" not in call_kwargs + + @pytest.mark.asyncio + async def test_fallback_on_api_error(self, model_with_client, anthropic_client, messages): + 
anthropic_client.messages.count_tokens = unittest.mock.AsyncMock( + side_effect=anthropic.APIError(message="Unsupported", request=unittest.mock.MagicMock(), body=None) + ) + + result = await model_with_client.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_on_generic_exception(self, model_with_client, anthropic_client, messages): + anthropic_client.messages.count_tokens = unittest.mock.AsyncMock(side_effect=RuntimeError("Connection failed")) + + result = await model_with_client.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_logs_warning(self, model_with_client, anthropic_client, messages, caplog): + anthropic_client.messages.count_tokens = unittest.mock.AsyncMock(side_effect=RuntimeError("API down")) + + with caplog.at_level(logging.WARNING): + await model_with_client.count_tokens(messages=messages) + + assert any("native token counting failed" in record.message for record in caplog.records) diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 99a745b07..470f11b6b 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -3103,3 +3103,115 @@ async def test_non_streaming_citations_with_only_location(bedrock_client, model, assert citation["location"] == {"web": {"url": "https://example.com", "domain": "example.com"}} assert "title" not in citation assert "sourceContent" not in citation + + +class TestCountTokens: + """Tests for BedrockModel.count_tokens native token counting.""" + + @pytest.fixture + def model_with_client(self, bedrock_client, model_id): + _ = bedrock_client + return BedrockModel(model_id=model_id) + + @pytest.fixture + def messages(self): + return [{"role": "user", "content": [{"text": "hello"}]}] + + @pytest.fixture + def tool_specs(self): + return [ + { + "name": "test_tool", + 
"description": "A test tool", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + } + ] + + @pytest.mark.asyncio + async def test_native_count_tokens_success(self, model_with_client, bedrock_client, messages): + bedrock_client.count_tokens.return_value = {"inputTokens": 42} + + result = await model_with_client.count_tokens(messages=messages) + + assert result == 42 + bedrock_client.count_tokens.assert_called_once() + call_kwargs = bedrock_client.count_tokens.call_args[1] + assert "input" in call_kwargs + assert "converse" in call_kwargs["input"] + + @pytest.mark.asyncio + async def test_native_count_tokens_with_system_prompt(self, model_with_client, bedrock_client, messages): + bedrock_client.count_tokens.return_value = {"inputTokens": 55} + + result = await model_with_client.count_tokens(messages=messages, system_prompt="Be helpful.") + + assert result == 55 + call_kwargs = bedrock_client.count_tokens.call_args[1] + assert call_kwargs["input"]["converse"]["system"] == [{"text": "Be helpful."}] + assert "toolConfig" not in call_kwargs["input"]["converse"] + + @pytest.mark.asyncio + async def test_native_count_tokens_with_tool_specs(self, model_with_client, bedrock_client, messages, tool_specs): + bedrock_client.count_tokens.return_value = {"inputTokens": 100} + + result = await model_with_client.count_tokens(messages=messages, tool_specs=tool_specs) + + assert result == 100 + call_kwargs = bedrock_client.count_tokens.call_args[1] + assert "toolConfig" in call_kwargs["input"]["converse"] + + @pytest.mark.asyncio + async def test_native_count_tokens_with_system_prompt_content(self, model_with_client, bedrock_client, messages): + bedrock_client.count_tokens.return_value = {"inputTokens": 60} + + result = await model_with_client.count_tokens( + messages=messages, + system_prompt_content=[{"text": "Be helpful."}, {"text": "Be concise."}], + ) + + assert result == 60 + call_kwargs = bedrock_client.count_tokens.call_args[1] + assert 
call_kwargs["input"]["converse"]["system"] == [{"text": "Be helpful."}, {"text": "Be concise."}] + + @pytest.mark.asyncio + async def test_native_count_tokens_strips_inference_config(self, model_with_client, bedrock_client, messages): + bedrock_client.count_tokens.return_value = {"inputTokens": 10} + model_with_client.update_config(max_tokens=100) + + await model_with_client.count_tokens(messages=messages) + + call_kwargs = bedrock_client.count_tokens.call_args[1] + converse = call_kwargs["input"]["converse"] + assert "inferenceConfig" not in converse + assert "additionalModelRequestFields" not in converse + assert "guardrailConfig" not in converse + + @pytest.mark.asyncio + async def test_fallback_on_api_error(self, model_with_client, bedrock_client, messages): + bedrock_client.count_tokens.side_effect = ClientError( + {"Error": {"Code": "ValidationException", "Message": "Unsupported"}}, + "CountTokens", + ) + + result = await model_with_client.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_on_generic_exception(self, model_with_client, bedrock_client, messages): + bedrock_client.count_tokens.side_effect = RuntimeError("Connection failed") + + result = await model_with_client.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_logs_warning(self, model_with_client, bedrock_client, messages, caplog): + bedrock_client.count_tokens.side_effect = RuntimeError("API down") + + with caplog.at_level(logging.WARNING): + await model_with_client.count_tokens(messages=messages) + + assert any("native token counting failed" in record.message for record in caplog.records) diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index 361508327..e039861c6 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -1105,3 +1105,102 @@ def 
test_format_request_filters_location_source_document(model, caplog): assert len(formatted_content) == 1 assert "text" in formatted_content[0] assert "Location sources are not supported by Gemini" in caplog.text + + +class TestCountTokens: + """Tests for GeminiModel.count_tokens native token counting.""" + + @pytest.fixture + def gemini_client(self): + with unittest.mock.patch.object(strands.models.gemini.genai, "Client") as mock_client_cls: + mock_client = mock_client_cls.return_value + mock_client.aio = unittest.mock.AsyncMock() + yield mock_client + + @pytest.fixture + def model(self, gemini_client): + _ = gemini_client + return GeminiModel(model_id="m1") + + @pytest.fixture + def messages(self): + return [{"role": "user", "content": [{"text": "hello"}]}] + + @pytest.fixture + def tool_specs(self): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + } + ] + + @pytest.mark.asyncio + async def test_native_count_tokens_success(self, model, gemini_client, messages): + mock_response = unittest.mock.AsyncMock() + mock_response.total_tokens = 42 + gemini_client.aio.models.count_tokens.return_value = mock_response + + result = await model.count_tokens(messages=messages) + + assert result == 42 + gemini_client.aio.models.count_tokens.assert_called_once() + + @pytest.mark.asyncio + async def test_native_count_tokens_with_system_prompt(self, model, gemini_client, messages): + mock_response = unittest.mock.AsyncMock() + mock_response.total_tokens = 55 + gemini_client.aio.models.count_tokens.return_value = mock_response + + result = await model.count_tokens(messages=messages, system_prompt="Be helpful.") + + assert result > 55 # native (55) + heuristic estimate for system_prompt + + @pytest.mark.asyncio + async def test_native_count_tokens_with_tool_specs(self, model, gemini_client, messages, tool_specs): + mock_response = unittest.mock.AsyncMock() + mock_response.total_tokens = 100 + 
gemini_client.aio.models.count_tokens.return_value = mock_response + + result = await model.count_tokens(messages=messages, tool_specs=tool_specs) + + assert result > 100 # native (100) + heuristic estimate for tool_specs + + @pytest.mark.asyncio + async def test_fallback_on_none_total_tokens(self, model, gemini_client, messages): + mock_response = unittest.mock.AsyncMock() + mock_response.total_tokens = None + gemini_client.aio.models.count_tokens.return_value = mock_response + + result = await model.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_on_api_error(self, model, gemini_client, messages): + gemini_client.aio.models.count_tokens.side_effect = genai.errors.ClientError("Unsupported", response_json={}) + + result = await model.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_on_generic_exception(self, model, gemini_client, messages): + gemini_client.aio.models.count_tokens.side_effect = RuntimeError("Connection failed") + + result = await model.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_logs_warning(self, model, gemini_client, messages, caplog): + gemini_client.aio.models.count_tokens.side_effect = RuntimeError("API down") + + with caplog.at_level(logging.WARNING): + await model.count_tokens(messages=messages) + + assert any("native token counting failed" in record.message for record in caplog.records) diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py index 3e023dfce..6a5be8060 100644 --- a/tests/strands/models/test_llamacpp.py +++ b/tests/strands/models/test_llamacpp.py @@ -706,3 +706,100 @@ def test_format_request_filters_location_source_document(caplog) -> None: assert len(user_content) == 1 assert user_content[0]["type"] == "text" assert 
"Location sources are not supported by llama.cpp" in caplog.text + + +class TestCountTokens: + """Tests for LlamaCppModel.count_tokens native token counting.""" + + @pytest.fixture + def model(self): + return LlamaCppModel(base_url="http://localhost:8080") + + @pytest.fixture + def messages(self): + return [{"role": "user", "content": [{"text": "hello"}]}] + + @pytest.fixture + def tool_specs(self): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + } + ] + + @pytest.mark.asyncio + async def test_native_count_tokens_success(self, model, messages): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"tokens": [1, 2, 3, 4, 5]} + mock_response.raise_for_status = MagicMock() + model.client.post = AsyncMock(return_value=mock_response) + + result = await model.count_tokens(messages=messages) + + assert result == 5 + model.client.post.assert_called_once() + call_args = model.client.post.call_args + assert call_args[0][0] == "/tokenize" + + @pytest.mark.asyncio + async def test_native_count_tokens_with_system_prompt(self, model, messages): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"tokens": list(range(10))} + mock_response.raise_for_status = MagicMock() + model.client.post = AsyncMock(return_value=mock_response) + + result = await model.count_tokens(messages=messages, system_prompt="Be helpful.") + + assert result == 10 + call_kwargs = model.client.post.call_args[1] + payload = call_kwargs["json"] + assert payload["messages"][0]["role"] == "system" + assert payload["messages"][0]["content"] == "Be helpful." 
+ + @pytest.mark.asyncio + async def test_native_count_tokens_with_tool_specs(self, model, messages, tool_specs): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"tokens": list(range(20))} + mock_response.raise_for_status = MagicMock() + model.client.post = AsyncMock(return_value=mock_response) + + result = await model.count_tokens(messages=messages, tool_specs=tool_specs) + + assert result == 20 + call_kwargs = model.client.post.call_args[1] + payload = call_kwargs["json"] + assert "tools" in payload + + @pytest.mark.asyncio + async def test_fallback_on_http_error(self, model, messages): + model.client.post = AsyncMock( + side_effect=httpx.HTTPStatusError("Server error", request=MagicMock(), response=MagicMock(status_code=500)) + ) + + result = await model.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_on_connection_error(self, model, messages): + model.client.post = AsyncMock(side_effect=httpx.ConnectError("Connection refused")) + + result = await model.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_logs_warning(self, model, messages, caplog): + model.client.post = AsyncMock(side_effect=RuntimeError("Server down")) + + with caplog.at_level(logging.WARNING): + await model.count_tokens(messages=messages) + + assert any("native token counting failed" in record.message for record in caplog.records) diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py index ef31cc1e6..7964d64b7 100644 --- a/tests/strands/models/test_openai_responses.py +++ b/tests/strands/models/test_openai_responses.py @@ -1190,3 +1190,111 @@ def test_format_request_messages_excludes_reasoning_content(caplog): {"role": "user", "content": [{"type": "input_text", "text": "Thanks"}]}, ] assert "reasoningContent is not 
yet supported" in caplog.text + + +class TestCountTokens: + """Tests for OpenAIResponsesModel.count_tokens native token counting.""" + + @pytest.fixture + def openai_client(self): + with unittest.mock.patch.object(strands.models.openai_responses.openai, "AsyncOpenAI") as mock_client_cls: + mock_client = unittest.mock.AsyncMock() + mock_client_cls.return_value.__aenter__.return_value = mock_client + yield mock_client + + @pytest.fixture + def model(self, openai_client): + _ = openai_client + return OpenAIResponsesModel(model_id="gpt-4o") + + @pytest.fixture + def messages(self): + return [{"role": "user", "content": [{"text": "hello"}]}] + + @pytest.fixture + def tool_specs(self): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + } + ] + + @pytest.mark.asyncio + async def test_native_count_tokens_success(self, model, openai_client, messages): + mock_response = unittest.mock.AsyncMock() + mock_response.input_tokens = 42 + openai_client.responses.input_tokens.count.return_value = mock_response + + result = await model.count_tokens(messages=messages) + + assert result == 42 + openai_client.responses.input_tokens.count.assert_called_once() + + @pytest.mark.asyncio + async def test_native_count_tokens_with_system_prompt(self, model, openai_client, messages): + mock_response = unittest.mock.AsyncMock() + mock_response.input_tokens = 55 + openai_client.responses.input_tokens.count.return_value = mock_response + + result = await model.count_tokens(messages=messages, system_prompt="Be helpful.") + + assert result == 55 + call_kwargs = openai_client.responses.input_tokens.count.call_args[1] + assert call_kwargs["instructions"] == "Be helpful." 
+ + @pytest.mark.asyncio + async def test_native_count_tokens_with_tool_specs(self, model, openai_client, messages, tool_specs): + mock_response = unittest.mock.AsyncMock() + mock_response.input_tokens = 100 + openai_client.responses.input_tokens.count.return_value = mock_response + + result = await model.count_tokens(messages=messages, tool_specs=tool_specs) + + assert result == 100 + call_kwargs = openai_client.responses.input_tokens.count.call_args[1] + assert "tools" in call_kwargs + + @pytest.mark.asyncio + async def test_stream_and_store_stripped(self, model, openai_client, messages): + mock_response = unittest.mock.AsyncMock() + mock_response.input_tokens = 10 + openai_client.responses.input_tokens.count.return_value = mock_response + + await model.count_tokens(messages=messages) + + call_kwargs = openai_client.responses.input_tokens.count.call_args[1] + assert "stream" not in call_kwargs + assert "store" not in call_kwargs + + @pytest.mark.asyncio + async def test_fallback_on_api_error(self, model, openai_client, messages): + openai_client.responses.input_tokens.count.side_effect = openai.APIError( + message="Unsupported", request=unittest.mock.MagicMock(), body=None + ) + + result = await model.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_on_generic_exception(self, model, openai_client, messages): + openai_client.responses.input_tokens.count.side_effect = RuntimeError("Connection failed") + + result = await model.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_fallback_logs_warning(self, model, openai_client, messages, caplog): + import logging + + openai_client.responses.input_tokens.count.side_effect = RuntimeError("API down") + + with caplog.at_level(logging.WARNING): + await model.count_tokens(messages=messages) + + assert any("native token counting failed" in record.message for 
record in caplog.records) diff --git a/tests_integ/models/test_model_anthropic.py b/tests_integ/models/test_model_anthropic.py index 864360139..a5eba45b9 100644 --- a/tests_integ/models/test_model_anthropic.py +++ b/tests_integ/models/test_model_anthropic.py @@ -182,3 +182,42 @@ def test_input_and_max_tokens_exceed_context_limit(): with pytest.raises(ContextWindowOverflowException): agent(messages) + + +class TestCountTokens: + @pytest.fixture + def model(self): + return AnthropicModel( + model_id="claude-sonnet-4-20250514", + max_tokens=1024, + client_args={"api_key": os.environ["ANTHROPIC_API_KEY"]}, + ) + + @pytest.fixture + def messages(self): + return [{"role": "user", "content": [{"text": "What is the capital of France? Explain in detail."}]}] + + @pytest.fixture + def tool_specs(self): + return [ + { + "name": "get_weather", + "description": "Get the current weather for a location", + "inputSchema": {"json": {"type": "object", "properties": {"location": {"type": "string"}}}}, + } + ] + + @pytest.mark.asyncio + async def test_count_tokens_messages_only(self, model, messages, caplog): + with caplog.at_level("DEBUG"): + result = await model.count_tokens(messages=messages) + assert isinstance(result, int) + assert result > 0 + assert "native token count" in caplog.text + assert "falling back" not in caplog.text + + @pytest.mark.asyncio + async def test_count_tokens_with_tools_greater_than_without(self, model, messages, tool_specs): + without = await model.count_tokens(messages=messages) + with_tools = await model.count_tokens(messages=messages, tool_specs=tool_specs, system_prompt="Be helpful.") + assert with_tools > without diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index 4020ce35e..d9e28e589 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -517,3 +517,38 @@ def test_prompt_caching_backward_compatibility_no_ttl(non_streaming_model): assert 
result.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, ( "Expected cacheWriteInputTokens > 0 even without TTL specified" ) + + +class TestCountTokens: + @pytest.fixture + def model(self): + return BedrockModel(model_id="anthropic.claude-sonnet-4-20250514-v1:0") + + @pytest.fixture + def messages(self): + return [{"role": "user", "content": [{"text": "What is the capital of France? Explain in detail."}]}] + + @pytest.fixture + def tool_specs(self): + return [ + { + "name": "get_weather", + "description": "Get the current weather for a location", + "inputSchema": {"json": {"type": "object", "properties": {"location": {"type": "string"}}}}, + } + ] + + @pytest.mark.asyncio + async def test_count_tokens_messages_only(self, model, messages, caplog): + with caplog.at_level("DEBUG"): + result = await model.count_tokens(messages=messages) + assert isinstance(result, int) + assert result > 0 + assert "native token count" in caplog.text + assert "falling back" not in caplog.text + + @pytest.mark.asyncio + async def test_count_tokens_with_tools_greater_than_without(self, model, messages, tool_specs): + without = await model.count_tokens(messages=messages) + with_tools = await model.count_tokens(messages=messages, tool_specs=tool_specs, system_prompt="Be helpful.") + assert with_tools > without diff --git a/tests_integ/models/test_model_gemini.py b/tests_integ/models/test_model_gemini.py index 4c01c0b71..ac1943382 100644 --- a/tests_integ/models/test_model_gemini.py +++ b/tests_integ/models/test_model_gemini.py @@ -219,3 +219,41 @@ def test_agent_with_reasoning_content(model, assistant_agent): result = assistant_agent("Think about what 2+2 is") assert "reasoningContent" in result.message["content"][0] assert result.message["content"][0]["reasoningContent"]["reasoningText"]["text"] + + +class TestCountTokens: + @pytest.fixture + def model(self): + return GeminiModel( + model_id="gemini-2.0-flash", + client_args={"api_key": os.environ["GOOGLE_API_KEY"]}, + ) + + 
@pytest.fixture + def messages(self): + return [{"role": "user", "content": [{"text": "What is the capital of France? Explain in detail."}]}] + + @pytest.fixture + def tool_specs(self): + return [ + { + "name": "get_weather", + "description": "Get the current weather for a location", + "inputSchema": {"json": {"type": "object", "properties": {"location": {"type": "string"}}}}, + } + ] + + @pytest.mark.asyncio + async def test_count_tokens_messages_only(self, model, messages, caplog): + with caplog.at_level("DEBUG"): + result = await model.count_tokens(messages=messages) + assert isinstance(result, int) + assert result > 0 + assert "native token count" in caplog.text + assert "falling back" not in caplog.text + + @pytest.mark.asyncio + async def test_count_tokens_with_tools_greater_than_without(self, model, messages, tool_specs): + without = await model.count_tokens(messages=messages) + with_tools = await model.count_tokens(messages=messages, tool_specs=tool_specs, system_prompt="Be helpful.") + assert with_tools > without diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py index 5a2d21570..bef526427 100644 --- a/tests_integ/models/test_model_openai.py +++ b/tests_integ/models/test_model_openai.py @@ -400,3 +400,41 @@ def test_responses_builtin_tool_shell(): result = agent("Use the shell to compute the md5sum of the string 'strands-test'. Return only the hash.") text = result.message["content"][0]["text"] assert "d82f373f079b00a1db7ef1eec7f15c68" in text + + +class TestOpenAIResponsesCountTokens: + @pytest.fixture + def model(self): + return OpenAIResponsesModel( + model_id="gpt-4o", + client_args={"api_key": os.environ["OPENAI_API_KEY"]}, + ) + + @pytest.fixture + def messages(self): + return [{"role": "user", "content": [{"text": "What is the capital of France? 
Explain in detail."}]}] + + @pytest.fixture + def tool_specs(self): + return [ + { + "name": "get_weather", + "description": "Get the current weather for a location", + "inputSchema": {"json": {"type": "object", "properties": {"location": {"type": "string"}}}}, + } + ] + + @pytest.mark.asyncio + async def test_count_tokens_messages_only(self, model, messages, caplog): + with caplog.at_level("DEBUG"): + result = await model.count_tokens(messages=messages) + assert isinstance(result, int) + assert result > 0 + assert "native token count" in caplog.text + assert "falling back" not in caplog.text + + @pytest.mark.asyncio + async def test_count_tokens_with_tools_greater_than_without(self, model, messages, tool_specs): + without = await model.count_tokens(messages=messages) + with_tools = await model.count_tokens(messages=messages, tool_specs=tool_specs, system_prompt="Be helpful.") + assert with_tools > without From ce64c3a9405251c847a7330cdf463bd4c599f8a6 Mon Sep 17 00:00:00 2001 From: afarntrog <47332252+afarntrog@users.noreply.github.com> Date: Fri, 24 Apr 2026 14:41:57 -0400 Subject: [PATCH 250/279] fix(bedrock): upgrade default model to Claude Sonnet 4.5 (#2193) --- src/strands/models/bedrock.py | 12 ++-- tests/strands/agent/test_agent.py | 2 +- tests/strands/models/test_bedrock.py | 75 +++++++++++----------- tests_integ/conftest.py | 2 +- tests_integ/models/test_conformance.py | 2 +- tests_integ/steering/test_tool_steering.py | 29 +++++---- tests_integ/test_a2a_executor.py | 8 ++- tests_integ/test_bedrock_guardrails.py | 1 + tests_integ/test_context_overflow.py | 2 +- tests_integ/test_tool_context_injection.py | 4 +- 10 files changed, 77 insertions(+), 60 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 2a468e450..7f7113e83 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -36,8 +36,8 @@ logger = logging.getLogger(__name__) # See: `BedrockModel._get_default_model_with_warning` for why we 
need both -DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0" -_DEFAULT_BEDROCK_MODEL_ID = "{}.anthropic.claude-sonnet-4-20250514-v1:0" +DEFAULT_BEDROCK_MODEL_ID = "global.anthropic.claude-sonnet-4-6" +_DEFAULT_BEDROCK_MODEL_ID = "{}.anthropic.claude-sonnet-4-6" DEFAULT_BEDROCK_REGION = "us-west-2" BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [ @@ -90,7 +90,7 @@ class BedrockConfig(BaseModelConfig, total=False): guardrail_latest_message: Flag to send only the lastest user message to guardrails. Defaults to False. max_tokens: Maximum number of tokens to generate in the response - model_id: The Bedrock model ID (e.g., "us.anthropic.claude-sonnet-4-20250514-v1:0") + model_id: The Bedrock model ID (e.g., "global.anthropic.claude-sonnet-4-6") include_tool_result_status: Flag to include status field in tool results. True includes status, False removes status, "auto" determines based on model_id. Defaults to "auto". service_tier: Service tier for the request, controlling the trade-off between latency and cost. 
@@ -1151,13 +1151,13 @@ def _get_default_model_with_warning(region_name: str, model_config: BedrockConfi region_name (str): region for bedrock model model_config (Optional[dict[str, Any]]): Model Config that caller passes in on init """ - if DEFAULT_BEDROCK_MODEL_ID != _DEFAULT_BEDROCK_MODEL_ID.format("us"): - return DEFAULT_BEDROCK_MODEL_ID - model_config = model_config or {} if model_config.get("model_id"): return model_config["model_id"] + if DEFAULT_BEDROCK_MODEL_ID != _DEFAULT_BEDROCK_MODEL_ID.format("us"): + return DEFAULT_BEDROCK_MODEL_ID + prefix_inference_map = {"ap": "apac"} # some inference endpoints can be a bit different than the region prefix prefix = "-".join(region_name.split("-")[:-2]).lower() # handles `us-east-1` or `us-gov-east-1` diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 3b9258e0a..680a1d23c 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -35,7 +35,7 @@ from tests.fixtures.mocked_model_provider import MockedModelProvider # For unit testing we will use the the us inference -FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID.format("us") +FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID @pytest.fixture diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 470f11b6b..b8e41d20a 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -16,7 +16,6 @@ from strands import _exception_notes from strands.models import BedrockModel, CacheConfig from strands.models.bedrock import ( - _DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION, DEFAULT_READ_TIMEOUT, @@ -24,7 +23,7 @@ from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException from strands.types.tools import ToolSpec -FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID.format("us") +FORMATTED_DEFAULT_MODEL_ID = DEFAULT_BEDROCK_MODEL_ID @pytest.fixture @@ -2213,43 
+2212,24 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings): def test_get_default_model_with_warning_supported_regions_shows_no_warning(captured_warnings): - """Test get_model_prefix_with_warning doesn't warn for supported region prefixes.""" + """Test _get_default_model_with_warning doesn't warn for any region (global profile works everywhere).""" BedrockModel._get_default_model_with_warning("us-west-2") BedrockModel._get_default_model_with_warning("eu-west-2") assert all("does not support" not in str(w.message) for w in captured_warnings) -def test_get_default_model_for_supported_eu_region_returns_correct_model_id(captured_warnings): - model_id = BedrockModel._get_default_model_with_warning("eu-west-1") - assert model_id == "eu.anthropic.claude-sonnet-4-20250514-v1:0" +def test_get_default_model_returns_global_inference_profile(captured_warnings): + """Default model id is the global inference profile regardless of region.""" + for region in ("us-east-1", "eu-west-1", "us-gov-west-1", "ap-southeast-1", "ca-central-1"): + assert BedrockModel._get_default_model_with_warning(region) == DEFAULT_BEDROCK_MODEL_ID assert all("does not support" not in str(w.message) for w in captured_warnings) -def test_get_default_model_for_supported_us_region_returns_correct_model_id(captured_warnings): - model_id = BedrockModel._get_default_model_with_warning("us-east-1") - assert model_id == "us.anthropic.claude-sonnet-4-20250514-v1:0" - assert all("does not support" not in str(w.message) for w in captured_warnings) - - -def test_get_default_model_for_supported_gov_region_returns_correct_model_id(captured_warnings): - model_id = BedrockModel._get_default_model_with_warning("us-gov-west-1") - assert model_id == "us-gov.anthropic.claude-sonnet-4-20250514-v1:0" - assert all("does not support" not in str(w.message) for w in captured_warnings) - - -def test_get_model_prefix_for_ap_region_converts_to_apac_endpoint(captured_warnings): - """Test 
_get_default_model_with_warning warns for APAC regions since 'ap' is not in supported prefixes.""" - model_id = BedrockModel._get_default_model_with_warning("ap-southeast-1") - assert model_id == "apac.anthropic.claude-sonnet-4-20250514-v1:0" - - -def test_get_default_model_with_warning_unsupported_region_warns(captured_warnings): - """Test _get_default_model_with_warning warns for unsupported regions.""" +def test_get_default_model_with_warning_unsupported_region_does_not_warn(captured_warnings): + """Global inference profile works across all regions, so no region-support warning is emitted.""" BedrockModel._get_default_model_with_warning("ca-central-1") region_warnings = [w for w in captured_warnings if "does not support" in str(w.message)] - assert len(region_warnings) == 1 - assert "This region ca-central-1 does not support" in str(region_warnings[0].message) - assert "our default inference endpoint" in str(region_warnings[0].message) + assert len(region_warnings) == 0 def test_get_default_model_with_warning_no_warning_with_custom_model_id(captured_warnings): @@ -2261,13 +2241,12 @@ def test_get_default_model_with_warning_no_warning_with_custom_model_id(captured assert len(captured_warnings) == 0 -def test_init_with_unsupported_region_warns(session_cls, captured_warnings): - """Test BedrockModel initialization warns for unsupported regions.""" +def test_init_with_unsupported_region_does_not_warn(session_cls, captured_warnings): + """BedrockModel initialization does not warn for 'unsupported' regions when using the global profile.""" BedrockModel(region_name="ca-central-1") region_warnings = [w for w in captured_warnings if "does not support" in str(w.message)] - assert len(region_warnings) == 1 - assert "This region ca-central-1 does not support" in str(region_warnings[0].message) + assert len(region_warnings) == 0 def test_init_with_unsupported_region_custom_model_no_warning(session_cls, captured_warnings): @@ -2282,10 +2261,34 @@ def 
test_override_default_model_id_uses_the_overriden_value(captured_warnings): assert model_id == "custom-overridden-model" -def test_no_override_uses_formatted_default_model_id(captured_warnings): +def test_default_model_sentinel_triggers_region_prefix_fallback(captured_warnings): + """When DEFAULT_BEDROCK_MODEL_ID matches the sentinel template, the region-prefix fallback runs.""" + sentinel = "us.anthropic.claude-sonnet-4-6" + with unittest.mock.patch("strands.models.bedrock.DEFAULT_BEDROCK_MODEL_ID", sentinel): + model_id = BedrockModel._get_default_model_with_warning("eu-west-1") + assert model_id == "eu.anthropic.claude-sonnet-4-6" + + +def test_caller_supplied_model_id_wins_over_global_default(captured_warnings): + """Caller-supplied model_id in config takes precedence over the global default.""" + model_config = {"model_id": "caller-supplied-model"} + model_id = BedrockModel._get_default_model_with_warning("us-east-1", model_config) + assert model_id == "caller-supplied-model" + + +def test_default_model_sentinel_with_unsupported_region_warns(captured_warnings): + """When the sentinel matches and the region is unknown, the region-unsupported warning fires.""" + sentinel = "us.anthropic.claude-sonnet-4-6" + with unittest.mock.patch("strands.models.bedrock.DEFAULT_BEDROCK_MODEL_ID", sentinel): + BedrockModel._get_default_model_with_warning("ca-central-1") + region_warnings = [w for w in captured_warnings if "does not support" in str(w.message)] + assert len(region_warnings) == 1 + + +def test_default_model_id_is_global_inference_profile(captured_warnings): model_id = BedrockModel._get_default_model_with_warning("us-east-1") - assert model_id == "us.anthropic.claude-sonnet-4-20250514-v1:0" - assert model_id != _DEFAULT_BEDROCK_MODEL_ID + assert model_id == "global.anthropic.claude-sonnet-4-6" + assert model_id == DEFAULT_BEDROCK_MODEL_ID assert all("does not support" not in str(w.message) for w in captured_warnings) diff --git a/tests_integ/conftest.py 
b/tests_integ/conftest.py index b7ae78ec3..c696fb65d 100644 --- a/tests_integ/conftest.py +++ b/tests_integ/conftest.py @@ -203,7 +203,7 @@ def _load_api_keys_from_secrets_manager(): required_providers = { "ANTHROPIC_API_KEY", "GOOGLE_API_KEY", - "MISTRAL_API_KEY", + # "MISTRAL_API_KEY", # will add back once we get a card on file for this. "OPENAI_API_KEY", "WRITER_API_KEY", } diff --git a/tests_integ/models/test_conformance.py b/tests_integ/models/test_conformance.py index 36c21fb7f..994ecbf00 100644 --- a/tests_integ/models/test_conformance.py +++ b/tests_integ/models/test_conformance.py @@ -74,4 +74,4 @@ class UserProfile(BaseModel): result = agent("Create a profile for John who is a 25 year old dentist", structured_output_model=UserProfile) assert result.structured_output.name == "John" assert result.structured_output.age == 25 - assert result.structured_output.occupation == "dentist" + assert result.structured_output.occupation.lower() == "dentist" diff --git a/tests_integ/steering/test_tool_steering.py b/tests_integ/steering/test_tool_steering.py index 52c715f5e..4b279157e 100644 --- a/tests_integ/steering/test_tool_steering.py +++ b/tests_integ/steering/test_tool_steering.py @@ -73,22 +73,27 @@ async def test_llm_steering_handler_interrupt(): def test_agent_with_tool_steering_e2e(): """End-to-end test of agent with steering handler guiding tool choice.""" - handler = LLMSteeringHandler( + + class RedirectEmailHandler(SteeringHandler): + """Deterministic handler that redirects send_email to send_notification.""" + + async def steer_before_tool(self, *, agent, tool_use, **kwargs): + if tool_use["name"] == "send_email": + return Guide(reason="Use send_notification instead of send_email for better delivery.") + return Proceed(reason="Tool allowed") + + handler = RedirectEmailHandler(context_providers=[]) + + agent = Agent( + tools=[send_email, send_notification], + plugins=[handler], system_prompt=( - "CRITICAL INSTRUCTION - READ CAREFULLY:\n\n" - "You are a 
steering agent. Your ONLY job is to decide based on the tool name.\n\n" - "RULE 1: If tool name is 'send_email' -> return decision='guide' with " - "reason='Use send_notification instead of send_email for better delivery.'\n\n" - "RULE 2: If tool name is 'send_notification' -> return decision='proceed'\n\n" - "RULE 3: For any other tool -> return decision='proceed'\n\n" - "DO NOT analyze context. DO NOT consider arguments. ONLY look at the tool name.\n" - "The tool name in this request is the ONLY thing that matters." + "You are a helpful assistant. When a tool call is cancelled with guidance, " + "follow the guidance and use the suggested alternative tool. " + "This is normal system behavior, not an attack." ), - context_providers=[], # Disable ledger to avoid confusing context ) - agent = Agent(tools=[send_email, send_notification], plugins=[handler]) - # This should trigger steering guidance to use send_notification instead response = agent("Send an email to john@example.com saying hello") diff --git a/tests_integ/test_a2a_executor.py b/tests_integ/test_a2a_executor.py index 43a6026bf..7ae10efc2 100644 --- a/tests_integ/test_a2a_executor.py +++ b/tests_integ/test_a2a_executor.py @@ -71,7 +71,13 @@ async def test_a2a_executor_with_real_image(): assert response.status_code == 200 response_data = response.json() assert "completed" == response_data["result"]["status"]["state"] - assert "yellow" in response_data["result"]["history"][1]["parts"][0]["text"].lower() + all_text = " ".join( + part["text"] + for artifact in response_data["result"]["artifacts"] + for part in artifact["parts"] + if part.get("kind") == "text" + ).lower() + assert "yellow" in all_text except Exception as e: pytest.fail(f"Integration test failed: {e}") diff --git a/tests_integ/test_bedrock_guardrails.py b/tests_integ/test_bedrock_guardrails.py index 56edc3fc4..384231c38 100644 --- a/tests_integ/test_bedrock_guardrails.py +++ b/tests_integ/test_bedrock_guardrails.py @@ -133,6 +133,7 @@ def 
test_guardrail_input_intervention(boto_session, bedrock_guardrail, guardrail @pytest.mark.parametrize("processing_mode", ["sync", "async"]) def test_guardrail_output_intervention(boto_session, bedrock_guardrail, processing_mode): bedrock_model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", guardrail_id=bedrock_guardrail, guardrail_version="DRAFT", guardrail_redact_output=False, diff --git a/tests_integ/test_context_overflow.py b/tests_integ/test_context_overflow.py index 16dc3c4b8..39ad2743f 100644 --- a/tests_integ/test_context_overflow.py +++ b/tests_integ/test_context_overflow.py @@ -4,7 +4,7 @@ def test_context_window_overflow(): messages: Messages = [ - {"role": "user", "content": [{"text": "Too much text!" * 100000}]}, + {"role": "user", "content": [{"text": "Too much text!" * 300000}]}, {"role": "assistant", "content": [{"text": "That was a lot of text!"}]}, ] diff --git a/tests_integ/test_tool_context_injection.py b/tests_integ/test_tool_context_injection.py index 215286a46..7d3525014 100644 --- a/tests_integ/test_tool_context_injection.py +++ b/tests_integ/test_tool_context_injection.py @@ -4,6 +4,7 @@ """ from strands import Agent, ToolContext, tool +from strands.models.bedrock import BedrockModel from strands.types.tools import ToolResult @@ -41,7 +42,8 @@ def _validate_tool_result_content(agent: Agent): def test_strands_context_integration_context_true(): """Test ToolContext functionality with real agent interactions.""" - agent = Agent(tools=[good_story]) + model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0") + agent = Agent(model=model, tools=[good_story]) agent("using a tool, write a good story") _validate_tool_result_content(agent) From b340dc4d390f7effbcf75db9311fdd20cf6ae67b Mon Sep 17 00:00:00 2001 From: Liz <91279165+lizradway@users.noreply.github.com> Date: Fri, 24 Apr 2026 15:02:41 -0400 Subject: [PATCH 251/279] chore: update style guide for tool spec navigation (#2203) --- docs/STYLE_GUIDE.md | 
13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/STYLE_GUIDE.md b/docs/STYLE_GUIDE.md index c17fb2b76..82ee51847 100644 --- a/docs/STYLE_GUIDE.md +++ b/docs/STYLE_GUIDE.md @@ -74,3 +74,16 @@ class EdgeCondition(Protocol): ``` Using `Protocol` with `**kwargs` allows the interface to evolve by adding new keyword arguments without breaking existing implementations that don't use them. + +### Tool Name References + +When comparing against tool names in hooks or plugins, use the tool instance's `tool_name` property instead of hardcoding strings. Tool specs can be modified at runtime via the `AgentTool.tool_spec` setter, so hardcoded names may not match the actual registered name. + +```python +# Good +if event.tool_use.get("name") == self.my_tool.tool_name: + ... + +# Bad — fragile if tool name is changed at runtime +if event.tool_use.get("name") == "my_tool": + ... \ No newline at end of file From 009374f4836e6ba54350b51979cdb4d08f4f06b9 Mon Sep 17 00:00:00 2001 From: opieter-aws Date: Tue, 28 Apr 2026 10:54:11 -0400 Subject: [PATCH 252/279] feat: add ProviderTokenCountError for native token counting failures (#2211) --- src/strands/models/bedrock.py | 6 +++++- src/strands/models/gemini.py | 4 ++-- src/strands/types/exceptions.py | 10 ++++++++++ tests/strands/models/test_bedrock.py | 9 +++++++++ 4 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 7f7113e83..94df5a84d 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -27,6 +27,7 @@ from ..types.exceptions import ( ContextWindowOverflowException, ModelThrottledException, + ProviderTokenCountError, ) from ..types.streaming import CitationsDelta, StreamEvent from ..types.tools import ToolChoice, ToolSpec @@ -789,7 +790,10 @@ async def count_tokens( modelId=self.config["model_id"], input={"converse": converse_input}, ) - total_tokens: int = response["inputTokens"] + input_tokens = 
response.get("inputTokens") + if input_tokens is None: + raise ProviderTokenCountError("Bedrock count_tokens returned None for inputTokens") + total_tokens: int = input_tokens logger.debug("model_id=<%s>, total_tokens=<%d> | native token count", self.config["model_id"], total_tokens) return total_tokens diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 04e98f359..2ce1c0b42 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -16,7 +16,7 @@ from typing_extensions import Required, Unpack, override from ..types.content import ContentBlock, ContentBlockStartToolUse, Messages, SystemContentBlock -from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException, ProviderTokenCountError from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec from ._validation import _has_location_source, validate_config_keys @@ -465,7 +465,7 @@ async def count_tokens( contents=contents, ) if response.total_tokens is None: - raise ValueError("Gemini count_tokens returned None for total_tokens") + raise ProviderTokenCountError("Gemini count_tokens returned None for total_tokens") total_tokens: int = response.total_tokens # The google-genai SDK explicitly raises ValueError for system_instruction, tools, and diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 5db80a26e..7ad49eb24 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -83,6 +83,16 @@ class SnapshotException(Exception): pass +class ProviderTokenCountError(Exception): + """Thrown when a model provider's native token counting API fails. + + This error is used as internal control flow within provider ``count_tokens()`` overrides. + When caught, the provider falls back to the base class heuristic estimation. 
+ """ + + pass + + class ToolProviderException(Exception): """Exception raised when a tool provider fails to load or cleanup tools.""" diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index b8e41d20a..d63838182 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -3210,6 +3210,15 @@ async def test_fallback_on_generic_exception(self, model_with_client, bedrock_cl assert isinstance(result, int) assert result >= 0 + @pytest.mark.asyncio + async def test_fallback_on_none_input_tokens(self, model_with_client, bedrock_client, messages): + bedrock_client.count_tokens.return_value = {} + + result = await model_with_client.count_tokens(messages=messages) + + assert isinstance(result, int) + assert result >= 0 + @pytest.mark.asyncio async def test_fallback_logs_warning(self, model_with_client, bedrock_client, messages, caplog): bedrock_client.count_tokens.side_effect = RuntimeError("API down") From bab08db8a036ec964ff26cfad5f58b0efc336ba6 Mon Sep 17 00:00:00 2001 From: Leoy Date: Wed, 29 Apr 2026 01:59:22 +0800 Subject: [PATCH 253/279] fix(conversation-manager): handle window_size=0 and reject negative values (#2208) --- .../sliding_window_conversation_manager.py | 12 ++++- .../agent/test_conversation_manager.py | 52 +++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index f91d7a538..1b45dd42c 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -42,7 +42,7 @@ def __init__( Args: window_size: Maximum number of messages to keep in the agent's history. - Defaults to 40 messages. + Use 0 to clear all messages on every reduction. Defaults to 40 messages. 
should_truncate_results: Truncate tool results when a message is too large for the model's context window per_turn: Controls when to apply message management during agent execution. - False (default): Only apply management at the end (default behavior) @@ -56,8 +56,10 @@ def __init__( for performance tuning. Raises: - ValueError: If per_turn is 0 or a negative integer. + ValueError: If window_size is negative, or if per_turn is 0 or a negative integer. """ + if not isinstance(window_size, bool) and window_size < 0: + raise ValueError(f"window_size must be a non-negative integer, got {window_size}") if isinstance(per_turn, int) and not isinstance(per_turn, bool) and per_turn <= 0: raise ValueError(f"per_turn must be a positive integer, True, or False, got {per_turn}") @@ -173,6 +175,12 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A """ messages = agent.messages + # window_size=0 means "remove all messages" (matches TypeScript SDK behaviour) + if self.window_size == 0: + self.removed_message_count += len(messages) + messages[:] = [] + return + # Try to truncate the tool result first oldest_message_idx_with_tool_results = self._find_oldest_message_with_tool_results(messages) if oldest_message_idx_with_tool_results is not None and self.should_truncate_results: diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index c8b9df1cf..8679e6fd7 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -703,3 +703,55 @@ def test_boundary_text_in_tool_result_not_truncated(): assert not changed assert messages[0]["content"][0]["toolResult"]["content"][0]["text"] == boundary_text + + +# ============================================================================== +# window_size=0 and negative window_size validation tests +# ============================================================================== + + +def 
test_window_size_negative_raises_value_error(): + with pytest.raises(ValueError, match="window_size"): + SlidingWindowConversationManager(window_size=-1) + + +def test_window_size_zero_clears_all_messages_on_apply_management(): + """window_size=0 should remove all messages, matching TypeScript SDK behaviour (issue #2205).""" + manager = SlidingWindowConversationManager(window_size=0, should_truncate_results=False) + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + test_agent = Agent(messages=messages) + manager.apply_management(test_agent) + + assert messages == [] + assert manager.removed_message_count == 2 + + +def test_window_size_zero_clears_all_messages_on_reduce_context(): + """reduce_context with window_size=0 should clear all messages even without overflow.""" + manager = SlidingWindowConversationManager(window_size=0, should_truncate_results=False) + messages = [ + {"role": "user", "content": [{"text": "First"}]}, + {"role": "assistant", "content": [{"text": "Second"}]}, + {"role": "user", "content": [{"text": "Third"}]}, + ] + test_agent = Agent(messages=messages) + manager.reduce_context(test_agent) + + assert messages == [] + assert manager.removed_message_count == 3 + + +def test_window_size_zero_clears_on_overflow(): + """reduce_context with window_size=0 should clear messages even when called with an overflow exception.""" + manager = SlidingWindowConversationManager(window_size=0, should_truncate_results=False) + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi"}]}, + ] + test_agent = Agent(messages=messages) + manager.reduce_context(test_agent, e=Exception("overflow")) + + assert messages == [] From 52cdb9da0b140546352252168b5049a1b0c42302 Mon Sep 17 00:00:00 2001 From: opieter-aws Date: Tue, 28 Apr 2026 14:22:22 -0400 Subject: [PATCH 254/279] fix: change token counting fallback log from warning to debug 
(#2220) --- src/strands/models/anthropic.py | 2 +- src/strands/models/bedrock.py | 2 +- src/strands/models/gemini.py | 2 +- src/strands/models/llamacpp.py | 2 +- src/strands/models/openai_responses.py | 2 +- tests/strands/models/test_anthropic.py | 4 ++-- tests/strands/models/test_bedrock.py | 4 ++-- tests/strands/models/test_gemini.py | 4 ++-- tests/strands/models/test_llamacpp.py | 4 ++-- tests/strands/models/test_openai_responses.py | 4 ++-- 10 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 7bb38e1d4..54fdaaf00 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -412,7 +412,7 @@ async def count_tokens( ) return total_tokens except Exception as e: - logger.warning( + logger.debug( "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", self.config["model_id"], e, diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 94df5a84d..1482d72e0 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -798,7 +798,7 @@ async def count_tokens( logger.debug("model_id=<%s>, total_tokens=<%d> | native token count", self.config["model_id"], total_tokens) return total_tokens except Exception as e: - logger.warning( + logger.debug( "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", self.config["model_id"], e, diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 2ce1c0b42..892dce52d 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -486,7 +486,7 @@ async def count_tokens( ) return total_tokens except Exception as e: - logger.warning( + logger.debug( "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", self.config["model_id"], e, diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index 2e6a83306..c31ba11bc 100644 --- 
a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -555,7 +555,7 @@ async def count_tokens( ) return total_tokens except Exception as e: - logger.warning( + logger.debug( "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", self.config.get("model_id", "default"), e, diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index 0cb5cb43c..73a889aad 100644 --- a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -226,7 +226,7 @@ async def count_tokens( ) return total_tokens except Exception as e: - logger.warning( + logger.debug( "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", self.config["model_id"], e, diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index eae08254e..8e004dbb7 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -1131,10 +1131,10 @@ async def test_fallback_on_generic_exception(self, model_with_client, anthropic_ assert result >= 0 @pytest.mark.asyncio - async def test_fallback_logs_warning(self, model_with_client, anthropic_client, messages, caplog): + async def test_fallback_logs_debug(self, model_with_client, anthropic_client, messages, caplog): anthropic_client.messages.count_tokens = unittest.mock.AsyncMock(side_effect=RuntimeError("API down")) - with caplog.at_level(logging.WARNING): + with caplog.at_level(logging.DEBUG, logger="strands.models.anthropic"): await model_with_client.count_tokens(messages=messages) assert any("native token counting failed" in record.message for record in caplog.records) diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index d63838182..3b158abbc 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -3220,10 +3220,10 @@ async def test_fallback_on_none_input_tokens(self, 
model_with_client, bedrock_cl assert result >= 0 @pytest.mark.asyncio - async def test_fallback_logs_warning(self, model_with_client, bedrock_client, messages, caplog): + async def test_fallback_logs_debug(self, model_with_client, bedrock_client, messages, caplog): bedrock_client.count_tokens.side_effect = RuntimeError("API down") - with caplog.at_level(logging.WARNING): + with caplog.at_level(logging.DEBUG, logger="strands.models.bedrock"): await model_with_client.count_tokens(messages=messages) assert any("native token counting failed" in record.message for record in caplog.records) diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index e039861c6..fe6936ccc 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -1197,10 +1197,10 @@ async def test_fallback_on_generic_exception(self, model, gemini_client, message assert result >= 0 @pytest.mark.asyncio - async def test_fallback_logs_warning(self, model, gemini_client, messages, caplog): + async def test_fallback_logs_debug(self, model, gemini_client, messages, caplog): gemini_client.aio.models.count_tokens.side_effect = RuntimeError("API down") - with caplog.at_level(logging.WARNING): + with caplog.at_level(logging.DEBUG, logger="strands.models.gemini"): await model.count_tokens(messages=messages) assert any("native token counting failed" in record.message for record in caplog.records) diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py index 6a5be8060..a891ec929 100644 --- a/tests/strands/models/test_llamacpp.py +++ b/tests/strands/models/test_llamacpp.py @@ -796,10 +796,10 @@ async def test_fallback_on_connection_error(self, model, messages): assert result >= 0 @pytest.mark.asyncio - async def test_fallback_logs_warning(self, model, messages, caplog): + async def test_fallback_logs_debug(self, model, messages, caplog): model.client.post = AsyncMock(side_effect=RuntimeError("Server down")) - with 
caplog.at_level(logging.WARNING): + with caplog.at_level(logging.DEBUG, logger="strands.models.llamacpp"): await model.count_tokens(messages=messages) assert any("native token counting failed" in record.message for record in caplog.records) diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py index 7964d64b7..88cbee326 100644 --- a/tests/strands/models/test_openai_responses.py +++ b/tests/strands/models/test_openai_responses.py @@ -1289,12 +1289,12 @@ async def test_fallback_on_generic_exception(self, model, openai_client, message assert result >= 0 @pytest.mark.asyncio - async def test_fallback_logs_warning(self, model, openai_client, messages, caplog): + async def test_fallback_logs_debug(self, model, openai_client, messages, caplog): import logging openai_client.responses.input_tokens.count.side_effect = RuntimeError("API down") - with caplog.at_level(logging.WARNING): + with caplog.at_level(logging.DEBUG, logger="strands.models.openai_responses"): await model.count_tokens(messages=messages) assert any("native token counting failed" in record.message for record in caplog.records) From e12ac9d9477c55197fea175ab6ee11b5d62162d2 Mon Sep 17 00:00:00 2001 From: Gastly Date: Wed, 29 Apr 2026 10:29:26 -0700 Subject: [PATCH 255/279] fix: do not synthesize exception for cancelled tools (#2106) --- src/strands/telemetry/tracer.py | 24 +++++++++------- src/strands/tools/executors/_executor.py | 3 +- tests/strands/telemetry/test_tracer.py | 13 +++++++++ .../strands/tools/executors/test_executor.py | 28 ++++++++++++++++--- 4 files changed, 52 insertions(+), 16 deletions(-) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index a422d3cbf..648a65d27 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -212,6 +212,8 @@ def _end_span( status_description = error_message or str(error) or type(error).__name__ span.set_status(StatusCode.ERROR, status_description) 
span.record_exception(error) + elif error_message: + span.set_status(StatusCode.ERROR, error_message) else: span.set_status(StatusCode.OK) except Exception as e: @@ -454,15 +456,13 @@ def end_tool_call_span(self, span: Span, tool_result: ToolResult | None, error: error: Optional exception if the tool call failed. """ attributes: dict[str, AttributeValue] = {} + status: str | None = None + content: list[Any] = [] + if tool_result is not None: status = tool_result.get("status") - status_str = str(status) if status is not None else "" - - attributes.update( - { - "gen_ai.tool.status": status_str, - } - ) + content = tool_result.get("content", []) + attributes["gen_ai.tool.status"] = str(status) if status is not None else "" if self.use_latest_genai_conventions: self._add_event( @@ -477,7 +477,7 @@ def end_tool_call_span(self, span: Span, tool_result: ToolResult | None, error: { "type": "tool_call_response", "id": tool_result.get("toolUseId", ""), - "response": tool_result.get("content"), + "response": content, } ], } @@ -491,12 +491,16 @@ def end_tool_call_span(self, span: Span, tool_result: ToolResult | None, error: span, "gen_ai.choice", event_attributes={ - "message": serialize(tool_result.get("content")), + "message": serialize(content), "id": tool_result.get("toolUseId", ""), }, ) - self._end_span(span, attributes, error) + if error is None and status == "error": + error_message = next((b["text"] for b in content if "text" in b), "tool returned error status") + self._end_span(span, attributes, error_message=error_message) + else: + self._end_span(span, attributes, error) def start_event_loop_cycle_span( self, diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 2c602a560..3993f332b 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -176,10 +176,9 @@ async def _stream( tool_use, invocation_state, cancel_result, - exception=Exception(cancel_message), 
cancel_message=cancel_message, ) - yield ToolResultEvent(after_event.result, exception=after_event.exception) + yield ToolResultEvent(after_event.result) tool_results.append(after_event.result) return diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 8af7b782e..c7b096a5a 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -722,6 +722,19 @@ def test_end_tool_call_span_with_error(mock_span): mock_span.end.assert_called_once() +def test_end_tool_call_span_error_result_no_exception(mock_span): + """Test that an error result without an exception still sets StatusCode.ERROR.""" + tracer = Tracer() + tool_result = {"status": "error", "content": [{"text": "tool cancelled by user"}]} + + tracer.end_tool_call_span(mock_span, tool_result) + + mock_span.set_attributes.assert_called_once_with({"gen_ai.tool.status": "error"}) + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "tool cancelled by user") + mock_span.record_exception.assert_not_called() + mock_span.end.assert_called_once() + + def test_start_event_loop_cycle_span(mock_tracer): """Test starting an event loop cycle span.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 34b37dab0..9c38340b9 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -954,8 +954,8 @@ async def test_executor_stream_unknown_tool_has_exception(executor, agent, tool_ @pytest.mark.asyncio -async def test_executor_stream_cancel_has_exception(executor, agent, tool_results, invocation_state, alist): - """Test that _stream yields a ToolResultEvent with exception for cancelled tools.""" +async def test_executor_stream_cancel_no_exception(executor, agent, tool_results, invocation_state, alist): + """Test that _stream yields a 
ToolResultEvent with no exception for cancelled tools.""" def cancel_callback(event): event.cancel_tool = True @@ -969,5 +969,25 @@ def cancel_callback(event): result_event = events[-1] assert isinstance(result_event, ToolResultEvent) assert result_event.tool_result["status"] == "error" - assert result_event.exception is not None - assert "cancelled" in str(result_event.exception) + assert result_event.exception is None + + +@pytest.mark.asyncio +async def test_executor_stream_cancel_after_hook_sees_no_exception( + executor, agent, tool_results, invocation_state, hook_events, alist +): + """Test that AfterToolCallEvent.exception is None when a tool is cancelled.""" + + def cancel_callback(event): + event.cancel_tool = "user denied permission" + return event + + agent.hooks.add_callback(BeforeToolCallEvent, cancel_callback) + tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + await alist(stream) + + after_event = hook_events[-1] + assert isinstance(after_event, AfterToolCallEvent) + assert after_event.exception is None + assert after_event.cancel_message == "user denied permission" From 888c98c4307bb20351321044375403b0a419e22c Mon Sep 17 00:00:00 2001 From: opieter-aws Date: Wed, 29 Apr 2026 14:19:14 -0400 Subject: [PATCH 256/279] feat: estimate input tokens before model calls (#2221) --- src/strands/agent/agent_result.py | 9 +++ src/strands/event_loop/event_loop.py | 50 +++++++++++++ src/strands/hooks/events.py | 5 ++ src/strands/telemetry/metrics.py | 19 +++++ tests/strands/agent/hooks/test_events.py | 19 +++++ tests/strands/agent/test_agent_hooks.py | 8 +- tests/strands/agent/test_agent_result.py | 14 ++++ tests/strands/event_loop/test_event_loop.py | 81 +++++++++++++++++++++ tests/strands/telemetry/test_metrics.py | 41 +++++++++++ 9 files changed, 242 insertions(+), 4 deletions(-) diff --git a/src/strands/agent/agent_result.py 
b/src/strands/agent/agent_result.py index f0a399f81..80e483088 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -44,6 +44,15 @@ def context_size(self) -> int | None: """ return self.metrics.latest_context_size + @property + def projected_context_size(self) -> int | None: + """Projected context size for the next model call. + + Returns: + The projected token count (inputTokens + outputTokens), or None if no data is available. + """ + return self.metrics.projected_context_size + def __str__(self) -> str: """Return a string representation of the agent result. diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index bf1cc7a84..128ef9ca3 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -75,6 +75,48 @@ def _has_tool_use_in_latest_message(messages: "Messages") -> bool: return False +async def _estimate_input_tokens(agent: "Agent") -> int: + """Estimate the input token count for the next model call. + + Reads inputTokens + outputTokens from the last assistant message's metadata as a known + baseline, then estimates only new messages added after it. Falls back to full estimation + when no metadata is available (cold start or first call). On cold start, tool specs are + resolved lazily so that the caller does not need to resolve them before BeforeModelCallEvent. + + Args: + agent: The agent instance with messages and model. + + Returns: + Estimated input token count. 
+ """ + messages = agent.messages + + # Find the last assistant message with usage metadata + last_assistant_idx = -1 + for i, msg in reversed(list(enumerate(messages))): + if msg.get("role") == "assistant" and msg.get("metadata", {}).get("usage"): + last_assistant_idx = i + break + + if last_assistant_idx >= 0: + usage = messages[last_assistant_idx]["metadata"]["usage"] + known_baseline = usage["inputTokens"] + usage["outputTokens"] + new_messages = messages[last_assistant_idx + 1 :] + if not new_messages: + return known_baseline + # System prompt and tool spec tokens are already included in the baseline + return known_baseline + await agent.model.count_tokens(new_messages) + + # Cold start: resolve tool specs lazily for estimation only + tool_specs = agent.tool_registry.get_all_tool_specs() + return await agent.model.count_tokens( + messages, + tool_specs=tool_specs, + system_prompt=agent.system_prompt, + system_prompt_content=agent._system_prompt_content, + ) + + async def event_loop_cycle( agent: "Agent", invocation_state: dict[str, Any], @@ -325,10 +367,18 @@ async def _handle_model_execution( ) with trace_api.use_span(model_invoke_span, end_on_exit=False): try: + # Estimate input tokens for the upcoming model call (non-fatal) + projected_input_tokens: int | None = None + try: + projected_input_tokens = await _estimate_input_tokens(agent) + except Exception as e: + logger.debug("error=<%s> | token estimation failed, proceeding without estimate", e) + await agent.hooks.invoke_callbacks_async( BeforeModelCallEvent( agent=agent, invocation_state=invocation_state, + projected_input_tokens=projected_input_tokens, ) ) diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 9186e0e70..80b50770a 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -236,9 +236,14 @@ class BeforeModelCallEvent(HookEvent): invocation_state: State and configuration passed through the agent invocation. 
This can include shared context for multi-agent coordination, request tracking, and dynamic configuration. + projected_input_tokens: Projected input token count for the upcoming model call. + Computed by the agent loop from message metadata and token estimation. + Available for hooks and plugins (e.g. conversation managers) to make + proactive decisions about context management. None if estimation failed. """ invocation_state: dict[str, Any] = field(default_factory=dict) + projected_input_tokens: int | None = None @dataclass diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py index dae05965e..11690dd44 100644 --- a/src/strands/telemetry/metrics.py +++ b/src/strands/telemetry/metrics.py @@ -215,6 +215,25 @@ def latest_context_size(self) -> int | None: return self.agent_invocations[-1].cycles[-1].usage.get("inputTokens") return None + @property + def projected_context_size(self) -> int | None: + """Projected context size for the next model call. + + Computed as inputTokens + outputTokens from the most recent cycle's usage, + representing the approximate input token count for the next model call + (prior input + generated output that is now part of the conversation). + + Returns: + The projected token count, or None if no data is available. 
+ """ + if self.agent_invocations and self.agent_invocations[-1].cycles: + usage = self.agent_invocations[-1].cycles[-1].usage + input_tokens = usage.get("inputTokens") + output_tokens = usage.get("outputTokens") + if input_tokens is not None and output_tokens is not None: + return input_tokens + output_tokens + return None + @property def _metrics_client(self) -> "MetricsClient": """Get the singleton MetricsClient instance.""" diff --git a/tests/strands/agent/hooks/test_events.py b/tests/strands/agent/hooks/test_events.py index 0e03fbbcd..6771774d3 100644 --- a/tests/strands/agent/hooks/test_events.py +++ b/tests/strands/agent/hooks/test_events.py @@ -260,3 +260,22 @@ def test_after_invocation_event_resume_accepts_various_input_types(agent): # None to stop event.resume = None assert event.resume is None + + +def test_before_model_call_event_projected_input_tokens_default(agent): + """Test that projected_input_tokens defaults to None.""" + event = BeforeModelCallEvent(agent=agent) + assert event.projected_input_tokens is None + + +def test_before_model_call_event_projected_input_tokens_set(agent): + """Test that projected_input_tokens can be set at construction.""" + event = BeforeModelCallEvent(agent=agent, projected_input_tokens=500) + assert event.projected_input_tokens == 500 + + +def test_before_model_call_event_projected_input_tokens_not_writable(agent): + """Test that projected_input_tokens is not writable after construction.""" + event = BeforeModelCallEvent(agent=agent, projected_input_tokens=500) + with pytest.raises(AttributeError, match="Property projected_input_tokens is not writable"): + event.projected_input_tokens = 1000 diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 2c61ee966..bc2c376c2 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -165,7 +165,7 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u agent=agent, 
message=agent.messages[0], ) - assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY, projected_input_tokens=ANY) assert next(events) == AfterModelCallEvent( agent=agent, invocation_state=ANY, @@ -195,7 +195,7 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"}, ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2]) - assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY, projected_input_tokens=ANY) assert next(events) == AfterModelCallEvent( agent=agent, invocation_state=ANY, @@ -239,7 +239,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m agent=agent, message=agent.messages[0], ) - assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY, projected_input_tokens=ANY) assert next(events) == AfterModelCallEvent( agent=agent, invocation_state=ANY, @@ -269,7 +269,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"}, ) assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2]) - assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY) + assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY, projected_input_tokens=ANY) assert next(events) == AfterModelCallEvent( agent=agent, invocation_state=ANY, diff --git a/tests/strands/agent/test_agent_result.py b/tests/strands/agent/test_agent_result.py index 64391f299..7cb106182 100644 --- a/tests/strands/agent/test_agent_result.py +++ 
b/tests/strands/agent/test_agent_result.py @@ -384,3 +384,17 @@ def test_context_size_none_when_no_data(mock_metrics, simple_message: Message): mock_metrics.latest_context_size = None result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={}) assert result.context_size is None + + +def test_projected_context_size_delegates_to_metrics(mock_metrics, simple_message: Message): + """Test that projected_context_size delegates to metrics.projected_context_size.""" + mock_metrics.projected_context_size = 15000 + result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={}) + assert result.projected_context_size == 15000 + + +def test_projected_context_size_none_when_no_data(mock_metrics, simple_message: Message): + """Test that projected_context_size returns None when metrics has no data.""" + mock_metrics.projected_context_size = None + result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={}) + assert result.projected_context_size is None diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 871371f5f..f025a81ef 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -1198,3 +1198,84 @@ async def test_event_loop_metrics_recorded_before_recursion( # Verify the event loop completed successfully tru_stop_reason, _, _, _, _, _ = events[-1]["stop"] assert tru_stop_reason == "end_turn" + + +class TestEstimateInputTokens: + """Tests for _estimate_input_tokens helper.""" + + @pytest.mark.asyncio + async def test_cold_start_estimates_all_messages(self): + """On cold start (no prior usage metadata), estimates all messages with lazily resolved tool specs.""" + agent = unittest.mock.AsyncMock() + agent.messages = [{"role": "user", "content": [{"text": "Hi"}]}] + agent.system_prompt = "You are helpful" + agent._system_prompt_content = None + 
agent.tool_registry = unittest.mock.MagicMock() + agent.tool_registry.get_all_tool_specs.return_value = [{"name": "tool1"}] + agent.model.count_tokens = AsyncMock(return_value=42) + + result = await strands.event_loop.event_loop._estimate_input_tokens(agent) + + assert result == 42 + agent.tool_registry.get_all_tool_specs.assert_called_once() + agent.model.count_tokens.assert_called_once_with( + agent.messages, + tool_specs=[{"name": "tool1"}], + system_prompt="You are helpful", + system_prompt_content=None, + ) + + @pytest.mark.asyncio + async def test_baseline_only_no_new_messages(self): + """When last message is assistant with usage and no new messages after, returns baseline.""" + agent = unittest.mock.AsyncMock() + agent.messages = [ + {"role": "user", "content": [{"text": "Hi"}]}, + { + "role": "assistant", + "content": [{"text": "Hello"}], + "metadata": {"usage": {"inputTokens": 100, "outputTokens": 20, "totalTokens": 120}}, + }, + ] + agent.system_prompt = "You are helpful" + + result = await strands.event_loop.event_loop._estimate_input_tokens(agent) + + assert result == 120 + agent.model.count_tokens.assert_not_called() + + @pytest.mark.asyncio + async def test_baseline_plus_delta(self): + """When new messages exist after last assistant, adds estimated delta to baseline.""" + agent = unittest.mock.AsyncMock() + agent.messages = [ + {"role": "user", "content": [{"text": "Hi"}]}, + { + "role": "assistant", + "content": [{"text": "Hello"}], + "metadata": {"usage": {"inputTokens": 100, "outputTokens": 30, "totalTokens": 130}}, + }, + {"role": "user", "content": [{"text": "tool result"}]}, + ] + agent.system_prompt = "You are helpful" + agent.model.count_tokens = AsyncMock(return_value=50) + + result = await strands.event_loop.event_loop._estimate_input_tokens(agent) + + # baseline (100+30) + delta (50) = 180 + assert result == 180 + agent.model.count_tokens.assert_called_once() + + @pytest.mark.asyncio + async def 
test_error_fallback_returns_none_at_call_site(self): + """When count_tokens raises, the caller catches and sets projected_input_tokens to None.""" + agent = unittest.mock.AsyncMock() + agent.messages = [{"role": "user", "content": [{"text": "Hi"}]}] + agent.system_prompt = "You are helpful" + agent._system_prompt_content = None + agent.tool_registry = unittest.mock.MagicMock() + agent.tool_registry.get_all_tool_specs.return_value = [] + agent.model.count_tokens = AsyncMock(side_effect=Exception("API unavailable")) + + with pytest.raises(Exception, match="API unavailable"): + await strands.event_loop.event_loop._estimate_input_tokens(agent) diff --git a/tests/strands/telemetry/test_metrics.py b/tests/strands/telemetry/test_metrics.py index c38fa6a18..7d54c0cc6 100644 --- a/tests/strands/telemetry/test_metrics.py +++ b/tests/strands/telemetry/test_metrics.py @@ -613,3 +613,44 @@ def test_latest_context_size_missing_input_tokens_key(event_loop_metrics): ) ) assert event_loop_metrics.latest_context_size is None + + +def test_projected_context_size_no_invocations(event_loop_metrics): + assert event_loop_metrics.projected_context_size is None + + +def test_projected_context_size_invocation_with_no_cycles(event_loop_metrics): + event_loop_metrics.reset_usage_metrics() + assert event_loop_metrics.projected_context_size is None + + +def test_projected_context_size_returns_input_plus_output(event_loop_metrics, mock_get_meter_provider): + event_loop_metrics.reset_usage_metrics() + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "c1"}) + event_loop_metrics.update_usage(Usage(inputTokens=100, outputTokens=50, totalTokens=150)) + + assert event_loop_metrics.projected_context_size == 150 + + +def test_projected_context_size_updates_across_cycles(event_loop_metrics, mock_get_meter_provider): + event_loop_metrics.reset_usage_metrics() + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "c1"}) + event_loop_metrics.update_usage(Usage(inputTokens=100, 
outputTokens=50, totalTokens=150)) + + event_loop_metrics.start_cycle(attributes={"event_loop_cycle_id": "c2"}) + event_loop_metrics.update_usage(Usage(inputTokens=200, outputTokens=80, totalTokens=280)) + + assert event_loop_metrics.projected_context_size == 280 + + +def test_projected_context_size_missing_tokens_key(event_loop_metrics): + """Returns None when usage dict is missing inputTokens or outputTokens.""" + event_loop_metrics.reset_usage_metrics() + invocation = event_loop_metrics.agent_invocations[-1] + invocation.cycles.append( + strands.telemetry.metrics.EventLoopCycleMetric( + event_loop_cycle_id="c1", + usage={"outputTokens": 50, "totalTokens": 50}, + ) + ) + assert event_loop_metrics.projected_context_size is None From e88b2767472ae4d001c2c0c721ef4b138cb06a0a Mon Sep 17 00:00:00 2001 From: Liz <91279165+lizradway@users.noreply.github.com> Date: Wed, 29 Apr 2026 14:24:23 -0400 Subject: [PATCH 257/279] feat(offloader): return explicit paths in preview and auto-enable retrieval (#2222) --- .../context_offloader/plugin.py | 9 +-- .../context_offloader/storage.py | 45 +++++++++++---- .../context_offloader/test_plugin.py | 56 ++++++++++++++++++- .../context_offloader/test_storage.py | 43 ++++++++++++-- 4 files changed, 131 insertions(+), 22 deletions(-) diff --git a/src/strands/vended_plugins/context_offloader/plugin.py b/src/strands/vended_plugins/context_offloader/plugin.py index 0072d3934..929ba3ca6 100644 --- a/src/strands/vended_plugins/context_offloader/plugin.py +++ b/src/strands/vended_plugins/context_offloader/plugin.py @@ -88,7 +88,7 @@ class ContextOffloader(Plugin): max_result_tokens: Offload results whose estimated token count exceeds this threshold. preview_tokens: Number of tokens to keep as a text preview in context. include_retrieval_tool: Whether to register the ``retrieve_offloaded_content`` tool. - Defaults to False. + Defaults to True. 
Example: ```python @@ -109,7 +109,7 @@ def __init__( max_result_tokens: int = _DEFAULT_MAX_RESULT_TOKENS, preview_tokens: int = _DEFAULT_PREVIEW_TOKENS, *, - include_retrieval_tool: bool = False, + include_retrieval_tool: bool = True, ) -> None: """Initialize the ContextOffloader plugin. @@ -121,7 +121,7 @@ def __init__( Uses tiktoken for exact slicing when available, falls back to chars/4 heuristic. Defaults to ``_DEFAULT_PREVIEW_TOKENS`` (1,000). include_retrieval_tool: Whether to register the ``retrieve_offloaded_content`` - tool so the agent can fetch offloaded content. Defaults to False. + tool so the agent can fetch offloaded content. Defaults to True. Raises: ValueError: If max_result_tokens is not positive, preview_tokens is negative, @@ -155,7 +155,8 @@ def retrieve_offloaded_content( """Retrieve offloaded content by reference. Use this tool when you see a placeholder with a reference (ref: ...) - and need the full content. + and need the full content. Only use this as a fallback if the data + cannot be accessed using your existing tools. Args: reference: The reference string from the offload placeholder. diff --git a/src/strands/vended_plugins/context_offloader/storage.py b/src/strands/vended_plugins/context_offloader/storage.py index a12055a2e..645d2cb09 100644 --- a/src/strands/vended_plugins/context_offloader/storage.py +++ b/src/strands/vended_plugins/context_offloader/storage.py @@ -131,7 +131,11 @@ def _extension_for(content_type: str) -> str: return f".{content_type.split('/')[-1]}" def store(self, key: str, content: bytes, content_type: str = "text/plain") -> str: - """Store content as a file and return the filename as reference. + """Store content as a file and return the path as reference. + + The returned path preserves the form of ``artifact_dir`` passed to + the constructor: a relative ``artifact_dir`` yields a relative + reference, an absolute one yields an absolute reference. Args: key: A unique key for this content block. 
@@ -139,7 +143,7 @@ def store(self, key: str, content: bytes, content_type: str = "text/plain") -> s content_type: MIME type of the content. Returns: - The filename (not full path) used as the reference. + The file path (e.g., ``./artifacts/1234_1_key.txt``). """ self._artifact_dir.mkdir(parents=True, exist_ok=True) @@ -156,13 +160,16 @@ def store(self, key: str, content: bytes, content_type: str = "text/plain") -> s file_path = self._artifact_dir / filename file_path.write_bytes(content) - return filename + return str(file_path) def retrieve(self, reference: str) -> tuple[bytes, str]: """Retrieve content from a stored file. + Accepts both full paths (as returned by ``store()``) and bare + filenames for backward compatibility. + Args: - reference: The filename reference returned by store(). + reference: The file path or filename returned by store(). Returns: A tuple of (content bytes, content type). @@ -170,12 +177,17 @@ def retrieve(self, reference: str) -> tuple[bytes, str]: Raises: KeyError: If the file does not exist. 
""" - file_path = (self._artifact_dir / reference).resolve() - if not file_path.is_relative_to(self._artifact_dir.resolve()): + resolved_dir = self._artifact_dir.resolve() + ref_path = Path(reference) + file_path = ref_path.resolve() if len(ref_path.parts) > 1 else (self._artifact_dir / reference).resolve() + if not file_path.is_relative_to(resolved_dir): + file_path = (self._artifact_dir / reference).resolve() + if not file_path.is_relative_to(resolved_dir): raise KeyError(f"Reference not found: {reference}") if not file_path.is_file(): raise KeyError(f"Reference not found: {reference}") - content_type = self._content_types.get(reference, "application/octet-stream") + filename = file_path.name + content_type = self._content_types.get(filename, "application/octet-stream") return file_path.read_bytes(), content_type def _load_metadata(self) -> dict[str, str]: @@ -320,7 +332,7 @@ def __init__( self._lock = threading.Lock() def store(self, key: str, content: bytes, content_type: str = "text/plain") -> str: - """Store content as an S3 object and return the object key as reference. + """Store content as an S3 object and return an ``s3://`` URI as reference. Args: key: A unique key for this content block. @@ -328,7 +340,7 @@ def store(self, key: str, content: bytes, content_type: str = "text/plain") -> s content_type: MIME type of the content. Returns: - The S3 object key used as the reference. + An S3 URI (e.g., ``s3://bucket/prefix/1234_1_key``). Raises: botocore.exceptions.ClientError: If the S3 operation fails (e.g., bucket @@ -348,13 +360,16 @@ def store(self, key: str, content: bytes, content_type: str = "text/plain") -> s ContentType=content_type, ) - return s3_key + return f"s3://{self._bucket}/{s3_key}" def retrieve(self, reference: str) -> tuple[bytes, str]: """Retrieve content from an S3 object. + Accepts both ``s3://`` URIs (as returned by ``store()``) and raw + S3 keys for backward compatibility. + Args: - reference: The S3 object key returned by store(). 
+ reference: The S3 URI or object key returned by store(). Returns: A tuple of (content bytes, content type). @@ -362,8 +377,14 @@ def retrieve(self, reference: str) -> tuple[bytes, str]: Raises: KeyError: If the object does not exist. """ + s3_key = reference + if reference.startswith("s3://"): + expected_prefix = f"s3://{self._bucket}/" + if not reference.startswith(expected_prefix): + raise KeyError(f"Reference not found: {reference}") + s3_key = reference[len(expected_prefix) :] try: - response = self._client.get_object(Bucket=self._bucket, Key=reference) + response = self._client.get_object(Bucket=self._bucket, Key=s3_key) content: bytes = response["Body"].read() content_type: str = response.get("ContentType", "application/octet-stream") return content, content_type diff --git a/tests/strands/vended_plugins/context_offloader/test_plugin.py b/tests/strands/vended_plugins/context_offloader/test_plugin.py index 528d1f006..fb9471dbf 100644 --- a/tests/strands/vended_plugins/context_offloader/test_plugin.py +++ b/tests/strands/vended_plugins/context_offloader/test_plugin.py @@ -11,6 +11,7 @@ from strands.types.tools import ToolContext, ToolUse from strands.vended_plugins.context_offloader import ( ContextOffloader, + FileStorage, InMemoryStorage, ) @@ -26,6 +27,7 @@ def plugin(storage): storage=storage, max_result_tokens=25, preview_tokens=10, + include_retrieval_tool=False, ) @@ -466,10 +468,16 @@ def test_retrieval_tool_registered_when_enabled(self, plugin): tool_names = [t.tool_name for t in plugin.tools] assert "retrieve_offloaded_content" in tool_names - def test_retrieval_tool_not_registered_by_default(self): + def test_retrieval_tool_registered_by_default(self): plugin = ContextOffloader(storage=InMemoryStorage()) plugin.init_agent(MagicMock()) tool_names = [t.tool_name for t in plugin.tools] + assert "retrieve_offloaded_content" in tool_names + + def test_retrieval_tool_not_registered_when_disabled(self): + plugin = 
ContextOffloader(storage=InMemoryStorage(), include_retrieval_tool=False) + plugin.init_agent(MagicMock()) + tool_names = [t.tool_name for t in plugin.tools] assert "retrieve_offloaded_content" not in tool_names def test_retrieve_text_content(self, plugin, storage, tool_context): @@ -531,9 +539,53 @@ async def test_guidance_mentions_retrieval_tool_when_enabled(self, storage, mock @pytest.mark.asyncio async def test_guidance_does_not_mention_retrieval_tool_when_disabled(self, storage, mock_agent): - plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10) + plugin = ContextOffloader( + storage=storage, max_result_tokens=25, preview_tokens=10, include_retrieval_tool=False + ) event = _make_event(mock_agent, "x" * 200) await plugin._handle_tool_result(event) result_text = event.result["content"][0]["text"] assert "retrieve_offloaded_content" not in result_text assert "available tools" in result_text + + +class TestActionableReferences: + """Tests that storage-specific references appear in the offloaded preview.""" + + @pytest.mark.asyncio + async def test_file_storage_path_in_preview(self, tmp_path, mock_agent): + storage = FileStorage(artifact_dir=str(tmp_path / "artifacts")) + plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10) + event = _make_event(mock_agent, "a" * 200) + + await plugin._handle_tool_result(event) + + result_text = event.result["content"][0]["text"] + assert str(tmp_path / "artifacts") in result_text + + @pytest.mark.asyncio + async def test_file_storage_image_placeholder_has_path(self, tmp_path, mock_agent): + storage = FileStorage(artifact_dir=str(tmp_path / "artifacts")) + plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10) + img_bytes = b"\x89PNG" + b"\x00" * 100 + content = [ + {"text": "x" * 200}, + {"image": {"format": "png", "source": {"bytes": img_bytes}}}, + ] + event = _make_event(mock_agent, content) + + await plugin._handle_tool_result(event) + + 
placeholder = event.result["content"][1]["text"] + assert str(tmp_path / "artifacts") in placeholder + + @pytest.mark.asyncio + async def test_inmemory_storage_opaque_reference_in_preview(self, mock_agent): + storage = InMemoryStorage() + plugin = ContextOffloader(storage=storage, max_result_tokens=25, preview_tokens=10) + event = _make_event(mock_agent, "a" * 200) + + await plugin._handle_tool_result(event) + + result_text = event.result["content"][0]["text"] + assert "mem_" in result_text diff --git a/tests/strands/vended_plugins/context_offloader/test_storage.py b/tests/strands/vended_plugins/context_offloader/test_storage.py index 6b9b9e962..898dd5f86 100644 --- a/tests/strands/vended_plugins/context_offloader/test_storage.py +++ b/tests/strands/vended_plugins/context_offloader/test_storage.py @@ -1,6 +1,7 @@ """Tests for offload storage backends.""" import threading +from pathlib import Path from unittest.mock import MagicMock, patch import pytest @@ -147,7 +148,30 @@ def test_sanitizes_path_traversal(self, tmp_path): storage = FileStorage(artifact_dir=str(tmp_path)) ref = storage.store("../../etc/passwd", b"content") assert ".." 
not in ref - assert "/" not in ref + assert "/" not in Path(ref).name + + def test_reference_includes_artifact_dir(self, tmp_path): + artifact_dir = str(tmp_path / "artifacts") + storage = FileStorage(artifact_dir=artifact_dir) + ref = storage.store("key_1", b"content") + assert Path(ref).parent == Path(artifact_dir) + + def test_relative_artifact_dir_gives_relative_reference(self, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + storage = FileStorage(artifact_dir="./artifacts") + ref = storage.store("key_1", b"content") + assert Path(ref).parent == Path("artifacts") + content, content_type = storage.retrieve(ref) + assert content == b"content" + assert content_type == "text/plain" + + def test_retrieve_accepts_bare_filename(self, tmp_path): + storage = FileStorage(artifact_dir=str(tmp_path)) + ref = storage.store("key_1", b"hello world") + filename = Path(ref).name + content, content_type = storage.retrieve(filename) + assert content == b"hello world" + assert content_type == "text/plain" def test_metadata_survives_across_instances(self, tmp_path): artifact_dir = str(tmp_path / "artifacts") @@ -233,9 +257,9 @@ def test_unique_references(self, storage): assert storage.retrieve(ref1)[0] == b"content a" assert storage.retrieve(ref2)[0] == b"content b" - def test_reference_includes_prefix(self, storage): + def test_reference_is_s3_uri(self, storage): ref = storage.store("tool_abc", b"content") - assert ref.startswith("artifacts/") + assert ref.startswith("s3://test-bucket/artifacts/") def test_empty_prefix(self, mock_s3_client): with patch("boto3.Session") as mock_session_cls: @@ -245,9 +269,20 @@ def test_empty_prefix(self, mock_s3_client): storage = S3Storage(bucket="test-bucket", prefix="") ref = storage.store("tool_abc", b"content") - assert not ref.startswith("/") + assert ref.startswith("s3://test-bucket/") assert storage.retrieve(ref)[0] == b"content" + def test_retrieve_accepts_raw_key(self, storage, mock_s3_client): + ref = storage.store("key_1", 
b"hello world") + raw_key = ref.removeprefix("s3://test-bucket/") + content, content_type = storage.retrieve(raw_key) + assert content == b"hello world" + assert content_type == "text/plain" + + def test_retrieve_rejects_wrong_bucket_uri(self, storage): + with pytest.raises(KeyError, match="Reference not found"): + storage.retrieve("s3://wrong-bucket/artifacts/some_key") + def test_put_object_called_with_correct_params(self, storage, mock_s3_client): storage.store("key_1", b"test content", "application/json") From 771a86ac13e8fa324cf08394345044e5e34cf8f5 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Wed, 29 Apr 2026 15:33:40 -0400 Subject: [PATCH 258/279] fix: update tests to use non-EOL'd model (#2226) Co-authored-by: Mackenzie Zastrow --- tests_integ/models/providers.py | 2 +- tests_integ/models/test_model_bedrock.py | 2 +- tests_integ/models/test_model_litellm.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index 15161b9cb..db85d496d 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -91,7 +91,7 @@ def __init__(self): ), ) litellm = ProviderInfo( - id="litellm", factory=lambda: LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0") + id="litellm", factory=lambda: LiteLLMModel(model_id="bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0") ) llama = ProviderInfo( id="llama", diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index d9e28e589..d9d44317d 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -279,7 +279,7 @@ def test_structured_output_multi_modal_input(streaming_agent, yellow_img, yellow def test_redacted_content_handling(): """Test redactedContent handling with thinking mode.""" bedrock_model = BedrockModel( - 
model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0", + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", additional_request_fields={ "thinking": { "type": "enabled", diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py index b09983d73..b606771d0 100644 --- a/tests_integ/models/test_model_litellm.py +++ b/tests_integ/models/test_model_litellm.py @@ -12,17 +12,17 @@ @pytest.fixture def model(): - return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0") + return LiteLLMModel(model_id="bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0") @pytest.fixture def streaming_model(): - return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0", params={"stream": True}) + return LiteLLMModel(model_id="bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0", params={"stream": True}) @pytest.fixture def non_streaming_model(): - return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0", params={"stream": False}) + return LiteLLMModel(model_id="bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0", params={"stream": False}) @pytest.fixture From 6e208a8c546a6a9e27351cf289e730a13b07fc55 Mon Sep 17 00:00:00 2001 From: Hatim Kagalwala Date: Wed, 29 Apr 2026 14:17:23 -0700 Subject: [PATCH 259/279] =?UTF-8?q?feat(bedrock):=20add=20strict=5Ftools?= =?UTF-8?q?=20config=20with=20auto-inject=20of=20additional=E2=80=A6=20(#2?= =?UTF-8?q?213)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Venkatesh Bhukya --- src/strands/models/_strict_schema.py | 144 ++++++++++ src/strands/models/bedrock.py | 14 +- tests/strands/models/test_bedrock.py | 182 +++++++++++++ tests/strands/models/test_strict_schema.py | 302 +++++++++++++++++++++ tests_integ/models/test_model_bedrock.py | 24 ++ 5 files changed, 665 insertions(+), 1 deletion(-) create mode 100644 src/strands/models/_strict_schema.py create mode 100644 
tests/strands/models/test_strict_schema.py diff --git a/src/strands/models/_strict_schema.py b/src/strands/models/_strict_schema.py new file mode 100644 index 000000000..e7f13e244 --- /dev/null +++ b/src/strands/models/_strict_schema.py @@ -0,0 +1,144 @@ +"""Strict JSON schema transformation for tool definitions. + +When model providers require `strict: true` on tool definitions, they also require +`"additionalProperties": false` on every `object` type in the input schema. This module +provides a utility to recursively apply that constraint. + +Modeled after OpenAI's `_ensure_strict_json_schema`: +https://github.com/openai/openai-python/blob/main/src/openai/lib/_pydantic.py +""" + +import copy +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +def ensure_strict_json_schema( + schema: dict[str, Any], + *, + require_all_properties: bool = False, +) -> dict[str, Any]: + """Ensure a JSON schema conforms to strict tool use requirements. + + Creates a deep copy of the schema and recursively: + 1. Adds ``"additionalProperties": false`` to all ``object`` types that do not already define it + 2. Optionally adds all properties to the ``required`` array (needed for OpenAI) + 3. Handles ``$defs``, ``definitions``, ``anyOf``, ``allOf``, ``items``, and ``$ref`` + + Args: + schema: The JSON schema to process. A deep copy is made internally so the original is not mutated. + require_all_properties: If True, set ``required`` to include all property keys. OpenAI strict mode + requires this; Bedrock and Anthropic do not. + + Returns: + A new schema dict with strict-mode constraints applied. + """ + schema_copy = copy.deepcopy(schema) + _apply_strict(schema_copy, root=schema_copy, require_all_properties=require_all_properties) + return schema_copy + + +def _apply_strict( + schema: dict[str, Any], + *, + root: dict[str, Any], + require_all_properties: bool, +) -> None: + """Recursively apply strict-mode constraints to a JSON schema in place. 
+ + Args: + schema: The schema node to process (modified in place). + root: The root schema, used for resolving ``$ref`` pointers. + require_all_properties: If True, add all properties to ``required``. + """ + # Process $defs / definitions blocks + for defs_key in ("$defs", "definitions"): + defs = schema.get(defs_key) + if isinstance(defs, dict): + for def_schema in defs.values(): + if isinstance(def_schema, dict): + _apply_strict(def_schema, root=root, require_all_properties=require_all_properties) + + # Add additionalProperties: false to object types that lack it + if schema.get("type") == "object" and "additionalProperties" not in schema: + schema["additionalProperties"] = False + + # Process properties and optionally enforce required + properties = schema.get("properties") + if isinstance(properties, dict): + if require_all_properties: + schema["required"] = list(properties.keys()) + + for prop_schema in properties.values(): + if isinstance(prop_schema, dict): + _apply_strict(prop_schema, root=root, require_all_properties=require_all_properties) + + # Process array items + items = schema.get("items") + if isinstance(items, dict): + _apply_strict(items, root=root, require_all_properties=require_all_properties) + + # Process anyOf variants + any_of = schema.get("anyOf") + if isinstance(any_of, list): + for variant in any_of: + if isinstance(variant, dict): + _apply_strict(variant, root=root, require_all_properties=require_all_properties) + + # Process allOf variants + all_of = schema.get("allOf") + if isinstance(all_of, list): + for entry in all_of: + if isinstance(entry, dict): + _apply_strict(entry, root=root, require_all_properties=require_all_properties) + + # Process oneOf variants + one_of = schema.get("oneOf") + if isinstance(one_of, list): + for variant in one_of: + if isinstance(variant, dict): + _apply_strict(variant, root=root, require_all_properties=require_all_properties) + + # Resolve $ref combined with other keys by inlining the referenced schema 
+ ref = schema.get("$ref") + if isinstance(ref, str) and len(schema) > 1: + resolved = _resolve_ref(root, ref) + if isinstance(resolved, dict): + # Inline the resolved schema, giving priority to existing keys + merged = {**copy.deepcopy(resolved), **schema} + merged.pop("$ref", None) + schema.clear() + schema.update(merged) + # Re-apply strict to the inlined schema + _apply_strict(schema, root=root, require_all_properties=require_all_properties) + + +def _resolve_ref(root: dict[str, Any], ref: str) -> dict[str, Any] | None: + """Resolve a JSON Schema ``$ref`` pointer against the root schema. + + Args: + root: The root schema containing definitions. + ref: A JSON pointer string (e.g., ``#/$defs/MyModel``). + + Returns: + The resolved schema dict, or None if resolution fails. + """ + if not ref.startswith("#/"): + logger.warning("ref=<%s> | unexpected $ref format, skipping resolution", ref) + return None + + path = ref[2:].split("/") + current: Any = root + for key in path: + if not isinstance(current, dict) or key not in current: + logger.warning("ref=<%s> | failed to resolve $ref path", ref) + return None + current = current[key] + + if not isinstance(current, dict): + logger.warning("ref=<%s> | resolved to non-dict value", ref) + return None + + return current diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 1482d72e0..d535bbc51 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -31,6 +31,7 @@ ) from ..types.streaming import CitationsDelta, StreamEvent from ..types.tools import ToolChoice, ToolSpec +from ._strict_schema import ensure_strict_json_schema from ._validation import validate_config_keys from .model import BaseModelConfig, CacheConfig, Model @@ -100,6 +101,10 @@ class BedrockConfig(BaseModelConfig, total=False): supported service tiers, models, and regions stop_sequences: List of sequences that will stop generation when encountered streaming: Flag to enable/disable streaming. 
Defaults to True. + strict_tools: Flag to enable structured output enforcement on tool definitions. + When True, adds strict: true to each tool spec and automatically injects + "additionalProperties": false into all object types in tool input schemas. + See https://docs.aws.amazon.com/bedrock/latest/userguide/structured-output.html temperature: Controls randomness in generation (higher = more random) top_p: Controls diversity via nucleus sampling (alternative to temperature) """ @@ -125,6 +130,7 @@ class BedrockConfig(BaseModelConfig, total=False): service_tier: str | None stop_sequences: list[str] | None streaming: bool | None + strict_tools: bool | None temperature: float | None top_p: float | None @@ -240,6 +246,7 @@ def _format_request( # Use system_prompt_content directly (copy for mutability) system_blocks: list[SystemContentBlock] = system_prompt_content.copy() if system_prompt_content else [] + # Add cache point if configured (backwards compatibility) if cache_prompt := self.config.get("cache_prompt"): warnings.warn( @@ -261,7 +268,12 @@ def _format_request( "toolSpec": { "name": tool_spec["name"], "description": tool_spec["description"], - "inputSchema": tool_spec["inputSchema"], + "inputSchema": ( + {"json": ensure_strict_json_schema(tool_spec["inputSchema"]["json"])} + if self.config.get("strict_tools") + else tool_spec["inputSchema"] + ), + **({"strict": True} if self.config.get("strict_tools") else {}), } } for tool_spec in tool_specs diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 3b158abbc..a80ca091e 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -493,6 +493,188 @@ def test_format_request_tool_specs(model, messages, model_id, tool_spec): assert tru_request == exp_request +def test_format_request_strict_tools_injects_strict_and_closes_schema(bedrock_client, model_id, messages): + tool_specs = [ + { + "name": "my_tool", + "description": "A tool", + 
"inputSchema": { + "json": { + "type": "object", + "properties": {"param": {"type": "string"}}, + "required": ["param"], + } + }, + } + ] + model = BedrockModel(model_id=model_id, strict_tools=True) + request = model._format_request(messages, tool_specs=tool_specs) + tool_spec_result = request["toolConfig"]["tools"][0]["toolSpec"] + + assert tool_spec_result == { + "name": "my_tool", + "description": "A tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"param": {"type": "string"}}, + "required": ["param"], + "additionalProperties": False, + } + }, + "strict": True, + } + + +def test_format_request_strict_tools_does_not_mutate_original(bedrock_client, model_id, messages): + tool_specs = [ + { + "name": "my_tool", + "description": "A tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"param": {"type": "string"}}, + "required": ["param"], + } + }, + } + ] + model = BedrockModel(model_id=model_id, strict_tools=True) + model._format_request(messages, tool_specs=tool_specs) + + assert "additionalProperties" not in tool_specs[0]["inputSchema"]["json"] + + +def test_format_request_strict_tools_preserves_additional_properties_true(bedrock_client, model_id, messages): + tool_specs = [ + { + "name": "my_tool", + "description": "A tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"param": {"type": "string"}}, + "required": ["param"], + "additionalProperties": True, + } + }, + } + ] + model = BedrockModel(model_id=model_id, strict_tools=True) + request = model._format_request(messages, tool_specs=tool_specs) + schema = request["toolConfig"]["tools"][0]["toolSpec"]["inputSchema"]["json"] + + assert schema["additionalProperties"] is True + + +def test_format_request_strict_tools_nested_objects(bedrock_client, model_id, messages): + tool_specs = [ + { + "name": "my_tool", + "description": "A tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "config": { + "type": "object", + 
"properties": {"value": {"type": "integer"}}, + } + }, + "required": ["config"], + } + }, + } + ] + model = BedrockModel(model_id=model_id, strict_tools=True) + request = model._format_request(messages, tool_specs=tool_specs) + schema = request["toolConfig"]["tools"][0]["toolSpec"]["inputSchema"]["json"] + + assert schema == { + "type": "object", + "properties": { + "config": { + "type": "object", + "properties": {"value": {"type": "integer"}}, + "additionalProperties": False, + } + }, + "required": ["config"], + "additionalProperties": False, + } + + +def test_format_request_strict_tools_default_no_strict(bedrock_client, model_id, messages): + tool_specs = [ + { + "name": "my_tool", + "description": "A tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"param": {"type": "string"}}, + "required": ["param"], + } + }, + } + ] + model = BedrockModel(model_id=model_id) + request = model._format_request(messages, tool_specs=tool_specs) + tool_spec_result = request["toolConfig"]["tools"][0]["toolSpec"] + + assert "strict" not in tool_spec_result + assert tool_spec_result["inputSchema"]["json"] == { + "type": "object", + "properties": {"param": {"type": "string"}}, + "required": ["param"], + } + + +def test_format_request_strict_tools_false_no_strict(bedrock_client, model_id, messages): + tool_specs = [ + { + "name": "my_tool", + "description": "A tool", + "inputSchema": {"json": {"type": "object", "properties": {"x": {"type": "string"}}}}, + } + ] + model = BedrockModel(model_id=model_id, strict_tools=False) + request = model._format_request(messages, tool_specs=tool_specs) + tool_spec_result = request["toolConfig"]["tools"][0]["toolSpec"] + + assert "strict" not in tool_spec_result + + +def test_format_request_strict_tools_none_no_strict(bedrock_client, model_id, messages): + tool_specs = [ + { + "name": "my_tool", + "description": "A tool", + "inputSchema": {"json": {"type": "object", "properties": {"x": {"type": "string"}}}}, + } + ] + model = 
BedrockModel(model_id=model_id, strict_tools=None) + request = model._format_request(messages, tool_specs=tool_specs) + tool_spec_result = request["toolConfig"]["tools"][0]["toolSpec"] + + assert "strict" not in tool_spec_result + + +def test_format_request_strict_tools_applies_to_all_tools(bedrock_client, model_id, messages): + tool_specs = [ + {"name": "tool_a", "description": "Tool A", "inputSchema": {"json": {"type": "object", "properties": {}}}}, + {"name": "tool_b", "description": "Tool B", "inputSchema": {"json": {"type": "object", "properties": {}}}}, + ] + model = BedrockModel(model_id=model_id, strict_tools=True) + request = model._format_request(messages, tool_specs=tool_specs) + + for tool in request["toolConfig"]["tools"]: + if "toolSpec" in tool: + assert tool["toolSpec"]["strict"] is True + assert tool["toolSpec"]["inputSchema"]["json"]["additionalProperties"] is False + + def test_format_request_tool_choice_auto(model, messages, model_id, tool_spec): tool_choice = {"auto": {}} tru_request = model._format_request(messages, [tool_spec], tool_choice=tool_choice) diff --git a/tests/strands/models/test_strict_schema.py b/tests/strands/models/test_strict_schema.py new file mode 100644 index 000000000..4e69f767d --- /dev/null +++ b/tests/strands/models/test_strict_schema.py @@ -0,0 +1,302 @@ +from strands.models._strict_schema import ensure_strict_json_schema + + +def test_basic_object(): + schema = { + "type": "object", + "properties": {"x": {"type": "string"}}, + } + result = ensure_strict_json_schema(schema) + + assert result == { + "type": "object", + "properties": {"x": {"type": "string"}}, + "additionalProperties": False, + } + assert "additionalProperties" not in schema + + +def test_nested_objects(): + schema = { + "type": "object", + "properties": { + "outer": { + "type": "object", + "properties": {"inner": {"type": "integer"}}, + } + }, + } + result = ensure_strict_json_schema(schema) + + assert result == { + "type": "object", + "properties": { + 
"outer": { + "type": "object", + "properties": {"inner": {"type": "integer"}}, + "additionalProperties": False, + } + }, + "additionalProperties": False, + } + + +def test_defs(): + schema = { + "type": "object", + "properties": {"item": {"$ref": "#/$defs/MyItem"}}, + "$defs": { + "MyItem": { + "type": "object", + "properties": {"name": {"type": "string"}}, + } + }, + } + result = ensure_strict_json_schema(schema) + + assert result["additionalProperties"] is False + assert result["$defs"]["MyItem"] == { + "type": "object", + "properties": {"name": {"type": "string"}}, + "additionalProperties": False, + } + + +def test_definitions(): + schema = { + "type": "object", + "properties": {"item": {"$ref": "#/definitions/MyItem"}}, + "definitions": { + "MyItem": { + "type": "object", + "properties": {"name": {"type": "string"}}, + } + }, + } + result = ensure_strict_json_schema(schema) + + assert result["additionalProperties"] is False + assert result["definitions"]["MyItem"] == { + "type": "object", + "properties": {"name": {"type": "string"}}, + "additionalProperties": False, + } + + +def test_ref_inline(): + schema = { + "type": "object", + "properties": { + "item": { + "$ref": "#/$defs/MyItem", + "description": "An item", + } + }, + "$defs": { + "MyItem": { + "type": "object", + "properties": {"name": {"type": "string"}}, + } + }, + } + result = ensure_strict_json_schema(schema) + + assert result["properties"]["item"] == { + "type": "object", + "properties": {"name": {"type": "string"}}, + "description": "An item", + "additionalProperties": False, + } + + +def test_ref_inline_uses_deep_copy(): + """Two properties referencing the same $def get independent copies.""" + schema = { + "type": "object", + "properties": { + "a": {"$ref": "#/$defs/Shared", "description": "first"}, + "b": {"$ref": "#/$defs/Shared", "description": "second"}, + }, + "$defs": { + "Shared": { + "type": "object", + "properties": {"val": {"type": "string"}}, + } + }, + } + result = 
ensure_strict_json_schema(schema) + + assert result["properties"]["a"]["description"] == "first" + assert result["properties"]["b"]["description"] == "second" + assert result["properties"]["a"] is not result["properties"]["b"] + + +def test_arrays_anyof_allof(): + schema = { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": {"type": "object", "properties": {"a": {"type": "string"}}}, + }, + "union": { + "anyOf": [ + {"type": "object", "properties": {"b": {"type": "string"}}}, + {"type": "null"}, + ] + }, + "intersection": { + "allOf": [ + {"type": "object", "properties": {"c": {"type": "string"}}}, + ] + }, + }, + } + result = ensure_strict_json_schema(schema) + + assert result == { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "object", + "properties": {"a": {"type": "string"}}, + "additionalProperties": False, + }, + }, + "union": { + "anyOf": [ + { + "type": "object", + "properties": {"b": {"type": "string"}}, + "additionalProperties": False, + }, + {"type": "null"}, + ] + }, + "intersection": { + "allOf": [ + { + "type": "object", + "properties": {"c": {"type": "string"}}, + "additionalProperties": False, + }, + ] + }, + }, + "additionalProperties": False, + } + + +def test_oneof(): + schema = { + "type": "object", + "properties": { + "value": { + "oneOf": [ + {"type": "object", "properties": {"a": {"type": "string"}}}, + {"type": "object", "properties": {"b": {"type": "integer"}}}, + ] + } + }, + } + result = ensure_strict_json_schema(schema) + + assert result == { + "type": "object", + "properties": { + "value": { + "oneOf": [ + {"type": "object", "properties": {"a": {"type": "string"}}, "additionalProperties": False}, + {"type": "object", "properties": {"b": {"type": "integer"}}, "additionalProperties": False}, + ] + } + }, + "additionalProperties": False, + } + + +def test_require_all_properties(): + schema = { + "type": "object", + "properties": { + "required_field": {"type": 
"string"}, + "optional_field": {"type": "string"}, + }, + "required": ["required_field"], + } + + without = ensure_strict_json_schema(schema) + assert without["required"] == ["required_field"] + + with_all = ensure_strict_json_schema(schema, require_all_properties=True) + assert set(with_all["required"]) == {"required_field", "optional_field"} + + +def test_preserves_additional_properties_true(): + schema = { + "type": "object", + "properties": {"x": {"type": "string"}}, + "additionalProperties": True, + } + result = ensure_strict_json_schema(schema) + + assert result == { + "type": "object", + "properties": {"x": {"type": "string"}}, + "additionalProperties": True, + } + + +def test_preserves_additional_properties_false(): + schema = { + "type": "object", + "properties": {"x": {"type": "string"}}, + "additionalProperties": False, + } + result = ensure_strict_json_schema(schema) + + assert result == { + "type": "object", + "properties": {"x": {"type": "string"}}, + "additionalProperties": False, + } + + +def test_non_object_type_unchanged(): + schema = {"type": "string"} + result = ensure_strict_json_schema(schema) + + assert result == {"type": "string"} + + +def test_ref_with_invalid_format_is_ignored(): + """A $ref that doesn't start with #/ is silently skipped.""" + schema = { + "type": "object", + "properties": { + "item": {"$ref": "external.json#/Foo", "description": "ext"}, + }, + } + result = ensure_strict_json_schema(schema) + + # $ref is not resolved, but additionalProperties is still added to root + assert result["additionalProperties"] is False + assert result["properties"]["item"]["$ref"] == "external.json#/Foo" + + +def test_ref_with_missing_path_is_ignored(): + """A $ref pointing to a non-existent path is silently skipped.""" + schema = { + "type": "object", + "properties": { + "item": {"$ref": "#/$defs/Missing", "description": "gone"}, + }, + "$defs": {}, + } + result = ensure_strict_json_schema(schema) + + assert result["additionalProperties"] is 
False + # $ref stays because resolution failed + assert "$ref" in result["properties"]["item"] diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index d9d44317d..73d67f414 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -552,3 +552,27 @@ async def test_count_tokens_with_tools_greater_than_without(self, model, message without = await model.count_tokens(messages=messages) with_tools = await model.count_tokens(messages=messages, tool_specs=tool_specs, system_prompt="Be helpful.") assert with_tools > without + + +def test_strict_tools_with_complex_schema(): + """Test strict_tools=True with tools that have complex schemas including arrays and optional params.""" + + tools_called = set() + + @strands.tool + def search(query: str, tags: list[str], max_results: int = 5) -> str: + """Search for items matching query and tags.""" + tools_called.add("search") + return f"Found results for '{query}' with tags {tags} (limit {max_results})" + + @strands.tool + def calculator(expression: str) -> float: + """Calculate the result of a mathematical expression.""" + tools_called.add("calculator") + return eval(expression) + + model = BedrockModel(strict_tools=True) + agent = Agent(model=model, tools=[search, calculator], load_tools_from_directory=False) + agent('Search for "python" with tags ["programming", "language"] using the search tool.') + + assert "search" in tools_called From a245e6d3e4ee6da7bfb1da4a86d782d566424121 Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Fri, 1 May 2026 17:14:26 -0400 Subject: [PATCH 260/279] feat: enable openai provider use aws profile (#2230) --- pyproject.toml | 2 +- src/strands/models/_openai_bedrock.py | 126 ++++++++++++++ src/strands/models/openai.py | 38 ++++- src/strands/models/openai_responses.py | 36 +++- tests/strands/models/test_openai.py | 157 ++++++++++++++++++ 
tests/strands/models/test_openai_responses.py | 123 ++++++++++++++ tests_integ/models/test_model_mantle.py | 71 ++++---- 7 files changed, 500 insertions(+), 53 deletions(-) create mode 100644 src/strands/models/_openai_bedrock.py diff --git a/pyproject.toml b/pyproject.toml index 83a7bbf4d..8a017cd07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ litellm = ["litellm>=1.75.9,<=1.83.13", "openai>=1.68.0,<3.0.0"] llamaapi = ["llama-api-client>=0.1.0,<1.0.0"] mistral = ["mistralai>=1.8.2,<2.0.0"] ollama = ["ollama>=0.4.8,<1.0.0"] -openai = ["openai>=1.68.0,<3.0.0"] +openai = ["openai>=1.68.0,<3.0.0", "aws-bedrock-token-generator>=1.1.0,<2.0.0"] writer = ["writer-sdk>=2.2.0,<3.0.0"] sagemaker = [ "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", diff --git a/src/strands/models/_openai_bedrock.py b/src/strands/models/_openai_bedrock.py new file mode 100644 index 000000000..149a47ec5 --- /dev/null +++ b/src/strands/models/_openai_bedrock.py @@ -0,0 +1,126 @@ +"""Internal helpers for routing OpenAI-compatible clients to Bedrock Mantle. + +Converts a ``bedrock_mantle_config`` dict into the ``base_url`` and ``api_key`` that the +OpenAI Python SDK consumes. Tokens are minted on demand via +``aws_bedrock_token_generator.provide_token`` so long-running agents survive the +bearer token's maximum lifetime. + +``aws_bedrock_token_generator`` is part of the ``openai`` extras group +(``pip install strands-agents[openai]``) but is *not* included in the ``litellm`` +or ``sagemaker`` extras, which also pull in the ``openai`` package. The import is +therefore lazy — it happens inside :func:`resolve_bedrock_client_args` so that +those other extras never trigger an ``ImportError`` at module load. 
+""" + +from __future__ import annotations + +from datetime import timedelta +from typing import Any, TypedDict + +import boto3 +from botocore.credentials import CredentialProvider + +_MANTLE_BASE_URL_TEMPLATE = "https://bedrock-mantle.{region}.api.aws/v1" +_MANTLE_DOCS_URL = "https://docs.aws.amazon.com/bedrock/latest/userguide/inference-openai.html" + + +class BedrockMantleConfig(TypedDict, total=False): + """Config for routing an OpenAI-compatible client through Bedrock Mantle. + + Attributes: + region: AWS region hosting the Bedrock Mantle endpoint. If omitted, resolved + from ``boto_session`` (if provided) or the standard boto3 chain + (``AWS_REGION`` / ``AWS_DEFAULT_REGION`` / active profile / EC2 metadata). + A :class:`ValueError` is raised if none resolve. + boto_session: Optional :class:`boto3.Session` used to resolve the region when + ``region`` is not provided. Useful for picking up a non-default profile + without exporting env vars. + credentials_provider: Optional botocore :class:`~botocore.credentials.CredentialProvider` + forwarded to ``provide_token``. Omit to let the token generator use the + standard AWS credential chain. + expiry: Optional ``timedelta`` for the bearer token's lifetime, forwarded to + ``provide_token``. Defaults to the generator's built-in lifetime when + omitted. + """ + + region: str + boto_session: boto3.Session + credentials_provider: CredentialProvider + expiry: timedelta + + +def _resolve_region(config: BedrockMantleConfig) -> str: + """Resolve the AWS region, preferring explicit config then falling back to boto3. + + Raises: + ValueError: If no region can be resolved from the config, an attached session, + or the standard boto3 credential chain. 
+ """ + region = config.get("region") + if region: + return region + + session = config.get("boto_session") + if session is not None and session.region_name: + return str(session.region_name) + + # ``boto3.Session()`` with no args reads ``AWS_REGION`` / ``AWS_DEFAULT_REGION``, + # the active profile, and falls back to EC2 instance metadata — the same chain + # :class:`BedrockModel` uses. + default_region = boto3.Session().region_name + if default_region: + return str(default_region) + + raise ValueError( + "Could not resolve an AWS region for Bedrock Mantle. Pass 'region' in " + "bedrock_mantle_config, attach a boto_session with a configured region, or set " + f"AWS_REGION in the environment. See {_MANTLE_DOCS_URL} for supported regions." + ) + + +def resolve_bedrock_client_args( + config: BedrockMantleConfig, client_args: dict[str, Any] | None = None +) -> dict[str, Any]: + """Resolve a ``BedrockMantleConfig`` (plus optional ``client_args``) into OpenAI client kwargs. + + Mints a fresh bearer token on every call. Callers are expected to validate that + ``client_args`` does not contain ``base_url`` or ``api_key`` before calling this + function (typically at ``__init__`` time for fail-fast behavior). + + Raises: + ValueError: If no region can be resolved. + ImportError: If ``aws-bedrock-token-generator`` is not installed. + RuntimeError: If token minting fails (e.g. missing AWS credentials). + """ + region = _resolve_region(config) + + # ``aws-bedrock-token-generator`` is included in the ``openai`` extras group but not in + # ``litellm`` or ``sagemaker`` (which also depend on the ``openai`` package). The lazy + # import keeps those extras from hitting an ImportError at module load. + try: + from aws_bedrock_token_generator import provide_token + except ImportError as e: + raise ImportError( + "bedrock_mantle_config requires the 'aws-bedrock-token-generator' package. 
" + "Install it with: pip install strands-agents[openai]" + ) from e + + # Only forward kwargs the user set; provide_token rejects expiry=None. + token_kwargs: dict[str, Any] = {"region": region} + if "credentials_provider" in config: + token_kwargs["aws_credentials_provider"] = config["credentials_provider"] + if "expiry" in config: + token_kwargs["expiry"] = config["expiry"] + + try: + token = provide_token(**token_kwargs) + except Exception as e: + raise RuntimeError( + f"Failed to mint Bedrock Mantle bearer token for region '{region}'. " + "Verify your AWS credentials and network connectivity." + ) from e + + resolved: dict[str, Any] = dict(client_args or {}) + resolved["base_url"] = _MANTLE_BASE_URL_TEMPLATE.format(region=region) + resolved["api_key"] = token + return resolved diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index c4be7d360..ea16c7713 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -21,6 +21,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse +from ._openai_bedrock import BedrockMantleConfig, resolve_bedrock_client_args from ._validation import _has_location_source, validate_config_keys from .model import BaseModelConfig, Model @@ -71,6 +72,7 @@ def __init__( self, client: Client | None = None, client_args: dict[str, Any] | None = None, + bedrock_mantle_config: BedrockMantleConfig | None = None, **model_config: Unpack[OpenAIConfig], ) -> None: """Initialize provider instance. @@ -87,23 +89,50 @@ def __init__( Note: The client should not be shared across different asyncio event loops. client_args: Arguments for the OpenAI client (legacy approach). For a complete list of supported arguments, see https://pypi.org/project/openai/. 
+ May be combined with ``bedrock_mantle_config``; when both are set, + ``bedrock_mantle_config`` derives ``base_url`` and ``api_key`` (which must not + appear in ``client_args``). + bedrock_mantle_config: Route requests through Amazon Bedrock's Mantle + (OpenAI-compatible) endpoint. See :class:`BedrockMantleConfig` for accepted + keys. When set, a fresh bearer token is minted on every request. Cannot be + combined with a pre-built ``client``. **model_config: Configuration options for the OpenAI model. Raises: - ValueError: If both `client` and `client_args` are provided. + ValueError: If ``client`` is combined with ``client_args`` or ``bedrock_mantle_config``. """ validate_config_keys(model_config, self.OpenAIConfig) self.config = dict(model_config) - # Validate that only one client configuration method is provided - if client is not None and client_args is not None and len(client_args) > 0: + # client_args + bedrock_mantle_config is allowed; the config derives base_url / api_key. + client_args_provided = client_args is not None and len(client_args) > 0 + if client is not None and client_args_provided: raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.") + if bedrock_mantle_config is not None and client is not None: + raise ValueError("'bedrock_mantle_config' cannot be combined with a pre-built 'client'.") + if bedrock_mantle_config is not None and client_args: + conflicting = [k for k in ("api_key", "base_url") if k in client_args] + if conflicting: + raise ValueError( + f"client_args must not contain {conflicting} when bedrock_mantle_config is set; " + "these are derived from the Mantle config automatically." + ) self._custom_client = client self.client_args = client_args or {} + self._bedrock_mantle_config = bedrock_mantle_config logger.debug("config=<%s> | initializing", self.config) + def _resolve_client_args(self) -> dict[str, Any]: + """Return the kwargs to pass to ``openai.AsyncOpenAI`` for the current request. 
+ + Delegates to :func:`resolve_bedrock_client_args` when ``bedrock_mantle_config`` is set. + """ + if self._bedrock_mantle_config is not None: + return resolve_bedrock_client_args(self._bedrock_mantle_config, self.client_args) + return self.client_args + @override def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None: # type: ignore[override] """Update the OpenAI model configuration with the provided arguments. @@ -590,11 +619,10 @@ async def _get_client(self) -> AsyncIterator[Any]: # Use the injected client (caller manages lifecycle) yield self._custom_client else: - # Create a new client from client_args # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying # httpx client. The asyncio event loop does not allow connections to be shared. For more details, please # refer to https://github.com/encode/httpx/discussions/2959. - async with openai.AsyncOpenAI(**self.client_args) as client: + async with openai.AsyncOpenAI(**self._resolve_client_args()) as client: yield client @override diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index 73a889aad..4aff07ccd 100644 --- a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -58,6 +58,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException # noqa: E402 from ..types.streaming import StreamEvent # noqa: E402 from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse # noqa: E402 +from ._openai_bedrock import BedrockMantleConfig, resolve_bedrock_client_args # noqa: E402 from ._validation import validate_config_keys # noqa: E402 from .model import BaseModelConfig, Model # noqa: E402 @@ -141,21 +142,48 @@ class OpenAIResponsesConfig(BaseModelConfig, total=False): stateful: bool def __init__( - self, client_args: dict[str, Any] | None = None, **model_config: Unpack[OpenAIResponsesConfig] + self, + client_args: dict[str, Any] | None = None, + 
bedrock_mantle_config: BedrockMantleConfig | None = None, + **model_config: Unpack[OpenAIResponsesConfig], ) -> None: """Initialize provider instance. Args: client_args: Arguments for the OpenAI client. For a complete list of supported arguments, see https://pypi.org/project/openai/. + May be combined with ``bedrock_mantle_config``; when both are set, the config + derives ``base_url`` and ``api_key`` (which must not appear in ``client_args``). + bedrock_mantle_config: Route requests through Amazon Bedrock's Mantle + (OpenAI-compatible) endpoint. See :class:`BedrockMantleConfig` for accepted + keys. When set, a fresh bearer token is minted on every request. **model_config: Configuration options for the OpenAI Responses API model. """ validate_config_keys(model_config, self.OpenAIResponsesConfig) self.config = dict(model_config) + self.client_args = client_args or {} + self._bedrock_mantle_config = bedrock_mantle_config + + if bedrock_mantle_config is not None and client_args: + conflicting = [k for k in ("api_key", "base_url") if k in client_args] + if conflicting: + raise ValueError( + f"client_args must not contain {conflicting} when bedrock_mantle_config is set; " + "these are derived from the Mantle config automatically." + ) logger.debug("config=<%s> | initializing", self.config) + def _resolve_client_args(self) -> dict[str, Any]: + """Return the kwargs to pass to ``openai.AsyncOpenAI`` for the current request. + + Delegates to :func:`resolve_bedrock_client_args` when ``bedrock_mantle_config`` is set. 
+ """ + if self._bedrock_mantle_config is not None: + return resolve_bedrock_client_args(self._bedrock_mantle_config, self.client_args) + return self.client_args + @property @override def stateful(self) -> bool: @@ -215,7 +243,7 @@ async def count_tokens( count_tokens_fields = {"model", "input", "instructions", "tools"} request = {k: request[k] for k in request.keys() & count_tokens_fields} - async with openai.AsyncOpenAI(**self.client_args) as client: + async with openai.AsyncOpenAI(**self._resolve_client_args()) as client: response = await client.responses.input_tokens.count(**request) total_tokens: int = response.input_tokens @@ -267,7 +295,7 @@ async def stream( logger.debug("invoking OpenAI Responses API model") - async with openai.AsyncOpenAI(**self.client_args) as client: + async with openai.AsyncOpenAI(**self._resolve_client_args()) as client: try: response = await client.responses.create(**request) @@ -447,7 +475,7 @@ async def structured_output( ContextWindowOverflowException: If the input exceeds the model's context window. ModelThrottledException: If the request is throttled by OpenAI (rate limits). 
""" - async with openai.AsyncOpenAI(**self.client_args) as client: + async with openai.AsyncOpenAI(**self._resolve_client_args()) as client: try: response = await client.responses.parse( model=self.get_config()["model_id"], diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 94e4caa3f..b43915b07 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -1,4 +1,5 @@ import logging +import os import unittest.mock import openai @@ -1710,3 +1711,159 @@ def test_format_request_messages_multiple_tool_calls_with_images(): }, ] assert tru_result == exp_result + + +# ============================================================================= +# Bedrock Mantle (bedrock_mantle_config) integration with OpenAIModel +# ============================================================================= + + +class TestOpenAIModelBedrockMantleConfig: + @pytest.fixture + def mock_provide_token(self): + with unittest.mock.patch("aws_bedrock_token_generator.provide_token") as mock: + mock.return_value = "bedrock-api-key-deadbeef&Version=1" + yield mock + + def test_bedrock_mantle_config_sets_base_url_and_api_key(self, openai_client, mock_provide_token): + """bedrock_mantle_config produces the Mantle base_url and a minted bearer token as api_key.""" + _ = openai_client + model = OpenAIModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={"region": "us-east-1"}) + + # Token is minted lazily per request, so inspect the resolved kwargs. + resolved = model._resolve_client_args() + assert resolved["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1" + assert resolved["api_key"] == "bedrock-api-key-deadbeef&Version=1" + # Optional kwargs aren't forwarded so provide_token's own defaults apply. 
+ mock_provide_token.assert_called_once_with(region="us-east-1") + + def test_bedrock_mantle_config_forwards_credentials_provider_and_expiry(self, openai_client, mock_provide_token): + """Optional credentials_provider and expiry are forwarded to provide_token.""" + _ = openai_client + from datetime import timedelta + + provider = unittest.mock.Mock() + model = OpenAIModel( + model_id="openai.gpt-oss-120b", + bedrock_mantle_config={ + "region": "us-west-2", + "credentials_provider": provider, + "expiry": timedelta(minutes=15), + }, + ) + model._resolve_client_args() + mock_provide_token.assert_called_once_with( + region="us-west-2", + aws_credentials_provider=provider, + expiry=timedelta(minutes=15), + ) + + def test_bedrock_mantle_config_mints_token_per_request(self, openai_client, mock_provide_token): + """Each call to _resolve_client_args mints a fresh token (long-lived processes).""" + _ = openai_client + model = OpenAIModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={"region": "us-east-1"}) + model._resolve_client_args() + model._resolve_client_args() + model._resolve_client_args() + assert mock_provide_token.call_count == 3 + + def test_bedrock_mantle_config_conflicts_with_custom_client(self, openai_client): + """Cannot pass both bedrock_mantle_config and a pre-built client.""" + _ = openai_client + custom_client = unittest.mock.Mock() + with pytest.raises(ValueError, match="bedrock_mantle_config"): + OpenAIModel( + model_id="openai.gpt-oss-120b", + client=custom_client, + bedrock_mantle_config={"region": "us-east-1"}, + ) + + def test_bedrock_mantle_config_merges_with_client_args(self, openai_client, mock_provide_token): + """bedrock_mantle_config composes with client_args; transport options are preserved.""" + _ = openai_client + sentinel_http_client = unittest.mock.Mock() + model = OpenAIModel( + model_id="openai.gpt-oss-120b", + client_args={ + "timeout": 42, + "http_client": sentinel_http_client, + "default_headers": {"X-Trace-Id": "abc"}, + }, 
+ bedrock_mantle_config={"region": "us-east-1"}, + ) + resolved = model._resolve_client_args() + assert resolved["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1" + assert resolved["api_key"] == "bedrock-api-key-deadbeef&Version=1" + assert resolved["timeout"] == 42 + assert resolved["http_client"] is sentinel_http_client + assert resolved["default_headers"] == {"X-Trace-Id": "abc"} + + def test_bedrock_mantle_config_rejects_base_url_in_client_args(self, openai_client): + """client_args must not contain base_url or api_key when bedrock_mantle_config is set.""" + _ = openai_client + with pytest.raises(ValueError, match="client_args must not contain"): + OpenAIModel( + model_id="openai.gpt-oss-120b", + client_args={"base_url": "https://custom.example.com"}, + bedrock_mantle_config={"region": "us-east-1"}, + ) + + def test_bedrock_mantle_config_requires_region(self, openai_client): + """bedrock_mantle_config raises when no region can be resolved from config, session, or env.""" + _ = openai_client + with ( + unittest.mock.patch("boto3.Session") as mock_session_cls, + unittest.mock.patch.dict(os.environ, {}, clear=True), + ): + mock_session_cls.return_value.region_name = None + model = OpenAIModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={}) + with pytest.raises(ValueError, match="Could not resolve an AWS region"): + model._resolve_client_args() + + def test_bedrock_mantle_config_region_resolved_from_boto3_default(self, openai_client, mock_provide_token): + """When region is omitted, the default boto3 session chain resolves it.""" + _ = openai_client + with unittest.mock.patch("boto3.Session") as mock_session_cls: + mock_session_cls.return_value.region_name = "eu-west-1" + model = OpenAIModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={}) + resolved = model._resolve_client_args() + + assert resolved["base_url"] == "https://bedrock-mantle.eu-west-1.api.aws/v1" + mock_provide_token.assert_called_once_with(region="eu-west-1") + + def 
test_bedrock_mantle_config_region_resolved_from_boto_session(self, openai_client, mock_provide_token): + """An explicit ``boto_session`` supplies the region when ``region`` is omitted.""" + _ = openai_client + session = unittest.mock.Mock() + session.region_name = "ap-southeast-2" + model = OpenAIModel( + model_id="openai.gpt-oss-120b", + bedrock_mantle_config={"boto_session": session}, + ) + + resolved = model._resolve_client_args() + + assert resolved["base_url"] == "https://bedrock-mantle.ap-southeast-2.api.aws/v1" + mock_provide_token.assert_called_once_with(region="ap-southeast-2") + + def test_bedrock_mantle_config_explicit_region_wins_over_boto_session(self, openai_client, mock_provide_token): + """``region`` takes precedence over a session's region.""" + _ = openai_client + session = unittest.mock.Mock() + session.region_name = "ap-southeast-2" + model = OpenAIModel( + model_id="openai.gpt-oss-120b", + bedrock_mantle_config={"region": "us-east-1", "boto_session": session}, + ) + + model._resolve_client_args() + + mock_provide_token.assert_called_once_with(region="us-east-1") + + def test_bedrock_mantle_config_wraps_token_failures_with_context(self, openai_client, mock_provide_token): + """provide_token failures are wrapped in a RuntimeError with actionable context.""" + _ = openai_client + mock_provide_token.side_effect = RuntimeError("no credentials in chain") + model = OpenAIModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={"region": "us-east-1"}) + with pytest.raises(RuntimeError, match="Bedrock Mantle bearer token.*us-east-1"): + model._resolve_client_args() diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py index 88cbee326..b35d2d0de 100644 --- a/tests/strands/models/test_openai_responses.py +++ b/tests/strands/models/test_openai_responses.py @@ -1,3 +1,4 @@ +import os import unittest.mock import openai @@ -1298,3 +1299,125 @@ async def test_fallback_logs_debug(self, model, 
openai_client, messages, caplog) await model.count_tokens(messages=messages) assert any("native token counting failed" in record.message for record in caplog.records) + + +# ============================================================================= +# Bedrock Mantle (bedrock_mantle_config) integration with OpenAIResponsesModel +# ============================================================================= + + +class TestOpenAIResponsesModelBedrockMantleConfig: + @pytest.fixture + def mock_provide_token(self): + with unittest.mock.patch("aws_bedrock_token_generator.provide_token") as mock: + mock.return_value = "bedrock-api-key-deadbeef&Version=1" + yield mock + + def test_bedrock_mantle_config_sets_base_url_and_api_key(self, openai_client, mock_provide_token): + _ = openai_client + model = OpenAIResponsesModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={"region": "us-east-1"}) + resolved = model._resolve_client_args() + assert resolved["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1" + assert resolved["api_key"] == "bedrock-api-key-deadbeef&Version=1" + mock_provide_token.assert_called_once_with(region="us-east-1") + + def test_bedrock_mantle_config_forwards_credentials_provider_and_expiry(self, openai_client, mock_provide_token): + _ = openai_client + from datetime import timedelta + + provider = unittest.mock.Mock() + model = OpenAIResponsesModel( + model_id="openai.gpt-oss-120b", + bedrock_mantle_config={ + "region": "us-west-2", + "credentials_provider": provider, + "expiry": timedelta(minutes=15), + }, + ) + model._resolve_client_args() + mock_provide_token.assert_called_once_with( + region="us-west-2", + aws_credentials_provider=provider, + expiry=timedelta(minutes=15), + ) + + def test_bedrock_mantle_config_mints_token_per_request(self, openai_client, mock_provide_token): + _ = openai_client + model = OpenAIResponsesModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={"region": "us-east-1"}) + model._resolve_client_args() + 
model._resolve_client_args() + assert mock_provide_token.call_count == 2 + + def test_bedrock_mantle_config_merges_with_client_args(self, openai_client, mock_provide_token): + """bedrock_mantle_config composes with client_args; transport options are preserved.""" + _ = openai_client + sentinel_http_client = unittest.mock.Mock() + model = OpenAIResponsesModel( + model_id="openai.gpt-oss-120b", + client_args={ + "timeout": 42, + "http_client": sentinel_http_client, + }, + bedrock_mantle_config={"region": "us-east-1"}, + ) + resolved = model._resolve_client_args() + assert resolved["base_url"] == "https://bedrock-mantle.us-east-1.api.aws/v1" + assert resolved["api_key"] == "bedrock-api-key-deadbeef&Version=1" + assert resolved["timeout"] == 42 + assert resolved["http_client"] is sentinel_http_client + + def test_bedrock_mantle_config_rejects_base_url_in_client_args(self, openai_client): + """client_args must not contain base_url or api_key when bedrock_mantle_config is set.""" + _ = openai_client + with pytest.raises(ValueError, match="client_args must not contain"): + OpenAIResponsesModel( + model_id="openai.gpt-oss-120b", + client_args={"api_key": "should-not-be-here"}, + bedrock_mantle_config={"region": "us-east-1"}, + ) + + def test_bedrock_mantle_config_requires_region(self, openai_client): + """bedrock_mantle_config raises when no region can be resolved from config, session, or env.""" + _ = openai_client + with ( + unittest.mock.patch("boto3.Session") as mock_session_cls, + unittest.mock.patch.dict(os.environ, {}, clear=True), + ): + mock_session_cls.return_value.region_name = None + model = OpenAIResponsesModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={}) + with pytest.raises(ValueError, match="Could not resolve an AWS region"): + model._resolve_client_args() + + def test_bedrock_mantle_config_region_resolved_from_boto3_default(self, openai_client, mock_provide_token): + """When region is omitted, the default boto3 session chain resolves it.""" + 
_ = openai_client + with unittest.mock.patch("boto3.Session") as mock_session_cls: + mock_session_cls.return_value.region_name = "eu-west-1" + model = OpenAIResponsesModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={}) + resolved = model._resolve_client_args() + + assert resolved["base_url"] == "https://bedrock-mantle.eu-west-1.api.aws/v1" + mock_provide_token.assert_called_once_with(region="eu-west-1") + + def test_bedrock_mantle_config_region_resolved_from_boto_session(self, openai_client, mock_provide_token): + """An explicit ``boto_session`` supplies the region when ``region`` is omitted.""" + _ = openai_client + session = unittest.mock.Mock() + session.region_name = "ap-southeast-2" + model = OpenAIResponsesModel( + model_id="openai.gpt-oss-120b", + bedrock_mantle_config={"boto_session": session}, + ) + + resolved = model._resolve_client_args() + + assert resolved["base_url"] == "https://bedrock-mantle.ap-southeast-2.api.aws/v1" + mock_provide_token.assert_called_once_with(region="ap-southeast-2") + + def test_bedrock_mantle_config_wraps_token_failures_with_context(self, openai_client, mock_provide_token): + """provide_token failures are wrapped in a RuntimeError with actionable context.""" + _ = openai_client + mock_provide_token.side_effect = RuntimeError("no credentials in chain") + model = OpenAIResponsesModel(model_id="openai.gpt-oss-120b", bedrock_mantle_config={"region": "us-east-1"}) + with pytest.raises(RuntimeError, match="Bedrock Mantle bearer token.*us-east-1"): + model._resolve_client_args() diff --git a/tests_integ/models/test_model_mantle.py b/tests_integ/models/test_model_mantle.py index 1dc029344..7cc032146 100644 --- a/tests_integ/models/test_model_mantle.py +++ b/tests_integ/models/test_model_mantle.py @@ -1,61 +1,46 @@ -"""Integration tests for OpenAI Responses API on Bedrock Mantle with AWS credentials.""" +"""Integration tests for OpenAI-compatible APIs on Bedrock Mantle. 
+ +Exercises the ``bedrock_mantle_config`` pathway on ``OpenAIModel`` (Chat Completions) and +``OpenAIResponsesModel`` (Responses API) against the live +``bedrock-mantle.<region>.api.aws/v1`` endpoint. Credentials come from the +ambient AWS credential chain; no explicit API key is passed by the user. +""" -import httpx import pytest -from botocore.auth import SigV4Auth -from botocore.awsrequest import AWSRequest -from botocore.session import Session as BotocoreSession from strands import Agent +from strands.models.openai import OpenAIModel from strands.models.openai_responses import OpenAIResponsesModel +_REGION = "us-east-1" +_MODEL_ID = "openai.gpt-oss-120b" -class _SigV4Auth(httpx.Auth): - """httpx Auth handler that signs requests with AWS SigV4.""" - - def __init__(self, region: str): - session = BotocoreSession() - self.credentials = session.get_credentials().get_frozen_credentials() - self.signer = SigV4Auth(self.credentials, "bedrock", region) - - def auth_flow(self, request: httpx.Request): - aws_request = AWSRequest( - method=request.method, - url=str(request.url), - headers=dict(request.headers), - data=request.content, - ) - self.signer.add_auth(aws_request) - for key, value in aws_request.headers.items(): - request.headers[key] = value - yield request +@pytest.fixture +def bedrock_mantle_config(): + return {"region": _REGION} -class _NonClosingAsyncClient(httpx.AsyncClient): - """AsyncClient that survives the OpenAI SDK's context manager lifecycle.""" - async def aclose(self) -> None: - pass +@pytest.fixture +def chat_completions_model(bedrock_mantle_config): + return OpenAIModel(model_id=_MODEL_ID, bedrock_mantle_config=bedrock_mantle_config) @pytest.fixture -def client_args(): - region = "us-east-1" - return { - "api_key": "unused", - "base_url": f"https://bedrock-mantle.{region}.api.aws/v1", - "http_client": _NonClosingAsyncClient(auth=_SigV4Auth(region)), - } +def model(bedrock_mantle_config): + return OpenAIResponsesModel(model_id=_MODEL_ID, 
bedrock_mantle_config=bedrock_mantle_config) @pytest.fixture -def model(client_args): - return OpenAIResponsesModel(model_id="openai.gpt-oss-120b", client_args=client_args) +def stateful_model(bedrock_mantle_config): + return OpenAIResponsesModel(model_id=_MODEL_ID, stateful=True, bedrock_mantle_config=bedrock_mantle_config) -@pytest.fixture -def stateful_model(client_args): - return OpenAIResponsesModel(model_id="openai.gpt-oss-120b", stateful=True, client_args=client_args) +def test_chat_completions_agent_invoke(chat_completions_model): + """OpenAIModel (Chat Completions) reaches Mantle via bedrock_mantle_config.""" + agent = Agent(model=chat_completions_model, system_prompt="Reply in one short sentence.", callback_handler=None) + result = agent("What is 2+2?") + assert "4" in str(result) def test_agent_invoke(model): @@ -74,11 +59,11 @@ def test_responses_server_side_conversation(stateful_model): assert "alice" in str(result).lower() -def test_reasoning_content_multi_turn(client_args): +def test_reasoning_content_multi_turn(bedrock_mantle_config): """Test that reasoning content from gpt-oss models doesn't break multi-turn conversations.""" model = OpenAIResponsesModel( - model_id="openai.gpt-oss-120b", - client_args=client_args, + model_id=_MODEL_ID, + bedrock_mantle_config=bedrock_mantle_config, params={"reasoning": {"effort": "low"}}, ) agent = Agent(model=model, system_prompt="Reply in one short sentence.", callback_handler=None) From 8638fc2d629e32b7b5839f4c106d5aedcdf764c9 Mon Sep 17 00:00:00 2001 From: Aidan Daly <99039782+aidandaly24@users.noreply.github.com> Date: Mon, 4 May 2026 17:12:32 -0400 Subject: [PATCH 261/279] fix: include root cause in MCPClientInitializationError message (#2238) --- src/strands/tools/mcp/mcp_client.py | 2 +- tests/strands/tools/mcp/test_mcp_client.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 2ac632925..1884ce9bc 100644 
--- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -220,7 +220,7 @@ def start(self) -> "MCPClient": logger.exception("client failed to initialize") # Pass None for exc_type, exc_val, exc_tb since this isn't a context manager exit self.stop(None, None, None) - raise MCPClientInitializationError("the client initialization failed") from e + raise MCPClientInitializationError(f"the client initialization failed: {e}") from e return self # ToolProvider interface methods diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index fe439c5d9..f270fa6fc 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -386,7 +386,10 @@ def test_enter_with_initialization_exception(mock_transport): client = MCPClient(mock_transport["transport_callable"]) with patch.object(client, "stop") as mock_stop: - with pytest.raises(MCPClientInitializationError, match="the client initialization failed"): + with pytest.raises( + MCPClientInitializationError, + match="the client initialization failed: Transport initialization failed", + ): client.start() # Verify stop() was called for cleanup From 559b2a05722b2bea10e63befde8027a9b279e945 Mon Sep 17 00:00:00 2001 From: opieter-aws Date: Tue, 5 May 2026 11:44:22 -0400 Subject: [PATCH 262/279] feat: add context window limit lookup table (#2249) --- src/strands/models/_defaults.py | 177 ++++++++++++++++++ src/strands/models/anthropic.py | 3 +- src/strands/models/bedrock.py | 3 +- src/strands/models/gemini.py | 3 +- src/strands/models/mistral.py | 3 +- src/strands/models/openai.py | 5 +- src/strands/models/openai_responses.py | 6 +- tests/strands/models/test_anthropic.py | 24 +++ tests/strands/models/test_bedrock.py | 40 ++++ tests/strands/models/test_defaults.py | 76 ++++++++ tests/strands/models/test_gemini.py | 24 +++ tests/strands/models/test_mistral.py | 24 +++ tests/strands/models/test_openai.py | 24 +++ 
tests/strands/models/test_openai_responses.py | 20 +- 14 files changed, 425 insertions(+), 7 deletions(-) create mode 100644 src/strands/models/_defaults.py create mode 100644 tests/strands/models/test_defaults.py diff --git a/src/strands/models/_defaults.py b/src/strands/models/_defaults.py new file mode 100644 index 000000000..e463b8ef6 --- /dev/null +++ b/src/strands/models/_defaults.py @@ -0,0 +1,177 @@ +"""Default model metadata lookup tables. + +Provides context window limits for known model IDs across all providers. +Values sourced from provider documentation and +https://github.com/BerriAI/litellm/blob/litellm_internal_staging/model_prices_and_context_window.json + +Applied to providers with well-known, fixed model IDs: Bedrock, Anthropic, OpenAI, +OpenAI Responses, Gemini, and Mistral. Providers that use local/custom model IDs +(Ollama, LlamaCpp, SageMaker) or proxy to other providers with their own prefixed +ID format (LiteLLM) are excluded — their context windows depend on deployment config, +not a static table. +""" + +import logging +from collections.abc import Mapping +from typing import TypeVar + +logger = logging.getLogger(__name__) + +_C = TypeVar("_C", bound=Mapping[str, object]) + +# Context window limits (in tokens) for known model IDs. +# +# Best-effort lookup table — unknown models return None and callers +# fall back gracefully (e.g. proactive compression is disabled). +# Users can always override with an explicit context_window_limit in their model config. +# +# For Bedrock models with cross-region prefixes (e.g. us., eu., global.), +# get_context_window_limit strips the prefix before lookup so only the base model ID is needed here. 
+_CONTEXT_WINDOW_LIMITS: dict[str, int] = { + # Anthropic (direct API) + "claude-sonnet-4-6": 1_000_000, + "claude-sonnet-4-20250514": 1_000_000, + "claude-sonnet-4-5": 200_000, + "claude-sonnet-4-5-20250929": 200_000, + "claude-opus-4-6": 1_000_000, + "claude-opus-4-6-20260205": 1_000_000, + "claude-opus-4-7": 1_000_000, + "claude-opus-4-7-20260416": 1_000_000, + "claude-opus-4-5": 200_000, + "claude-opus-4-5-20251101": 200_000, + "claude-opus-4-20250514": 200_000, + "claude-opus-4-1": 200_000, + "claude-opus-4-1-20250805": 200_000, + "claude-haiku-4-5": 200_000, + "claude-haiku-4-5-20251001": 200_000, + "claude-3-7-sonnet-20250219": 200_000, + "claude-3-5-sonnet-20241022": 200_000, + "claude-3-5-sonnet-20240620": 200_000, + "claude-3-5-haiku-20241022": 200_000, + "claude-3-opus-20240229": 200_000, + "claude-3-haiku-20240307": 200_000, + # Bedrock Anthropic (base model IDs — cross-region prefixes stripped by get_context_window_limit) + "anthropic.claude-sonnet-4-6": 1_000_000, + "anthropic.claude-sonnet-4-20250514-v1:0": 1_000_000, + "anthropic.claude-sonnet-4-5-20250929-v1:0": 200_000, + "anthropic.claude-opus-4-6-v1": 1_000_000, + "anthropic.claude-opus-4-7": 1_000_000, + "anthropic.claude-opus-4-5-20251101-v1:0": 200_000, + "anthropic.claude-opus-4-20250514-v1:0": 200_000, + "anthropic.claude-opus-4-1-20250805-v1:0": 200_000, + "anthropic.claude-haiku-4-5-20251001-v1:0": 200_000, + "anthropic.claude-haiku-4-5@20251001": 200_000, + "anthropic.claude-3-7-sonnet-20250219-v1:0": 200_000, + "anthropic.claude-3-7-sonnet-20240620-v1:0": 200_000, + "anthropic.claude-3-5-sonnet-20241022-v2:0": 200_000, + "anthropic.claude-3-5-sonnet-20240620-v1:0": 200_000, + "anthropic.claude-3-5-haiku-20241022-v1:0": 200_000, + "anthropic.claude-3-opus-20240229-v1:0": 200_000, + "anthropic.claude-3-haiku-20240307-v1:0": 200_000, + "anthropic.claude-3-sonnet-20240229-v1:0": 200_000, + "anthropic.claude-mythos-preview": 1_000_000, + # Bedrock Amazon Nova + "amazon.nova-pro-v1:0": 
300_000, + "amazon.nova-lite-v1:0": 300_000, + "amazon.nova-micro-v1:0": 128_000, + "amazon.nova-premier-v1:0": 1_000_000, + "amazon.nova-2-lite-v1:0": 1_000_000, + "amazon.nova-2-pro-preview-20251202-v1:0": 1_000_000, + # OpenAI + "gpt-5.5": 1_050_000, + "gpt-5.5-pro": 1_050_000, + "gpt-5.4": 1_050_000, + "gpt-5.4-pro": 1_050_000, + "gpt-5.4-mini": 272_000, + "gpt-5.4-nano": 272_000, + "gpt-5.2": 272_000, + "gpt-5.2-pro": 272_000, + "gpt-5.1": 272_000, + "gpt-5": 272_000, + "gpt-5-mini": 272_000, + "gpt-5-nano": 272_000, + "gpt-5-pro": 128_000, + "gpt-4.1": 1_047_576, + "gpt-4.1-mini": 1_047_576, + "gpt-4.1-nano": 1_047_576, + "gpt-4o": 128_000, + "gpt-4o-mini": 128_000, + "gpt-4-turbo": 128_000, + "o3": 200_000, + "o3-mini": 200_000, + "o3-pro": 200_000, + "o4-mini": 200_000, + "o1": 200_000, + # Google Gemini + "gemini-2.5-flash": 1_048_576, + "gemini-2.5-flash-lite": 1_048_576, + "gemini-2.5-pro": 1_048_576, + "gemini-2.0-flash": 1_048_576, + "gemini-2.0-flash-lite": 1_048_576, + "gemini-3-pro-preview": 1_048_576, + "gemini-3-flash-preview": 1_048_576, + "gemini-3.1-pro-preview": 1_048_576, + "gemini-3.1-flash-lite-preview": 1_048_576, + # Mistral + "mistral-large-latest": 262_144, + "mistral-large-2512": 262_144, + "mistral-large-3": 262_144, + "mistral-medium-latest": 131_072, + "mistral-medium-2505": 131_072, + "mistral-small-latest": 131_072, + "mistral-small-3-2-2506": 131_072, +} + + +def get_context_window_limit(model_id: str) -> int | None: + """Look up the context window limit for a model ID. + + For Bedrock cross-region model IDs (e.g. ``us.anthropic.claude-sonnet-4-6``), + the region prefix is stripped as a fallback if the direct lookup fails. + + Args: + model_id: The model ID to look up. + + Returns: + The context window limit in tokens, or None if not found. 
+ """ + direct = _CONTEXT_WINDOW_LIMITS.get(model_id) + if direct is not None: + return direct + + # Fallback: strip prefix before first dot and retry (handles cross-region prefixes) + dot_index = model_id.find(".") + if dot_index != -1: + stripped = model_id[dot_index + 1 :] + result = _CONTEXT_WINDOW_LIMITS.get(stripped) + if result is not None: + logger.debug( + "model_id=<%s>, stripped_id=<%s> | resolved context window limit via prefix strip", model_id, stripped + ) + return result + + return None + + +def resolve_config_metadata(config: _C, model_id: str) -> _C: + """Resolve model metadata fields on a config dict from built-in lookup tables. + + When ``context_window_limit`` is not explicitly set, looks it up from the built-in table. + Explicit values pass through unchanged. Returns a new dict only when resolution adds a field; + otherwise returns the original config to avoid unnecessary allocation. + + Args: + config: The stored model config dict. + model_id: The model ID to look up. + + Returns: + The config with resolved metadata, or the original config if nothing to resolve. + """ + if "context_window_limit" in config: + return config + + limit = get_context_window_limit(model_id) + if limit is None: + return config + + return {**config, "context_window_limit": limit} # type: ignore[return-value] diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 54fdaaf00..ece7cd8d1 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -20,6 +20,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec +from ._defaults import resolve_config_metadata from ._validation import _has_location_source, validate_config_keys from .model import BaseModelConfig, Model @@ -95,7 +96,7 @@ def get_config(self) -> AnthropicConfig: Returns: The Anthropic model configuration. 
""" - return self.config + return resolve_config_metadata(self.config, self.config["model_id"]) def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]: """Format an Anthropic content block. diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index d535bbc51..baa2807c4 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -31,6 +31,7 @@ ) from ..types.streaming import CitationsDelta, StreamEvent from ..types.tools import ToolChoice, ToolSpec +from ._defaults import resolve_config_metadata from ._strict_schema import ensure_strict_json_schema from ._validation import validate_config_keys from .model import BaseModelConfig, CacheConfig, Model @@ -217,7 +218,7 @@ def get_config(self) -> BedrockConfig: Returns: The Bedrock model configuration. """ - return self.config + return resolve_config_metadata(self.config, self.config.get("model_id", "")) def _format_request( self, diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 892dce52d..65b925c6d 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -19,6 +19,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException, ProviderTokenCountError from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolSpec +from ._defaults import resolve_config_metadata from ._validation import _has_location_source, validate_config_keys from .model import BaseModelConfig, Model @@ -115,7 +116,7 @@ def get_config(self) -> GeminiConfig: Returns: The Gemini model configuration. """ - return self.config + return resolve_config_metadata(self.config, self.config["model_id"]) def _get_client(self) -> genai.Client: """Get a Gemini client for making requests. 
diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index c4a23b244..2ae00cef9 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -17,6 +17,7 @@ from ..types.exceptions import ModelThrottledException from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse +from ._defaults import resolve_config_metadata from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported from .model import BaseModelConfig, Model @@ -114,7 +115,7 @@ def get_config(self) -> MistralConfig: Returns: The Mistral model configuration. """ - return self.config + return resolve_config_metadata(self.config, self.config["model_id"]) def _format_request_message_content(self, content: ContentBlock) -> str | dict[str, Any]: """Format a Mistral content block. diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index ea16c7713..94d4b0b90 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -21,6 +21,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse +from ._defaults import resolve_config_metadata from ._openai_bedrock import BedrockMantleConfig, resolve_bedrock_client_args from ._validation import _has_location_source, validate_config_keys from .model import BaseModelConfig, Model @@ -150,7 +151,9 @@ def get_config(self) -> OpenAIConfig: Returns: The OpenAI model configuration. 
""" - return cast(OpenAIModel.OpenAIConfig, self.config) + return cast( + OpenAIModel.OpenAIConfig, resolve_config_metadata(self.config, str(self.config.get("model_id", ""))) + ) @classmethod def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]: diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index 4aff07ccd..a78cef73a 100644 --- a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -58,6 +58,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException # noqa: E402 from ..types.streaming import StreamEvent # noqa: E402 from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse # noqa: E402 +from ._defaults import resolve_config_metadata # noqa: E402 from ._openai_bedrock import BedrockMantleConfig, resolve_bedrock_client_args # noqa: E402 from ._validation import validate_config_keys # noqa: E402 from .model import BaseModelConfig, Model # noqa: E402 @@ -210,7 +211,10 @@ def get_config(self) -> OpenAIResponsesConfig: Returns: The OpenAI Responses API model configuration. 
""" - return cast(OpenAIResponsesModel.OpenAIResponsesConfig, self.config) + return cast( + OpenAIResponsesModel.OpenAIResponsesConfig, + resolve_config_metadata(self.config, str(self.config.get("model_id", ""))), + ) @override async def count_tokens( diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 8e004dbb7..abb56a441 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -82,6 +82,30 @@ def test__init__model_configs(anthropic_client, model_id, max_tokens): assert tru_temperature == exp_temperature +def test__init__auto_populates_context_window_limit(anthropic_client): + _ = anthropic_client + + model = AnthropicModel(model_id="claude-sonnet-4-20250514", max_tokens=1) + + assert model.get_config().get("context_window_limit") == 1_000_000 + + +def test__init__explicit_context_window_limit_not_overridden(anthropic_client): + _ = anthropic_client + + model = AnthropicModel(model_id="claude-sonnet-4-20250514", max_tokens=1, context_window_limit=100_000) + + assert model.get_config().get("context_window_limit") == 100_000 + + +def test__init__unknown_model_no_context_window_limit(anthropic_client): + _ = anthropic_client + + model = AnthropicModel(model_id="unknown-model", max_tokens=1) + + assert model.get_config().get("context_window_limit") is None + + def test_update_config(model, model_id): model.update_config(model_id=model_id) diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index a80ca091e..e42fc8e1f 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -296,6 +296,46 @@ def test__init__context_window_limit(bedrock_client): assert model.context_window_limit == 200_000 +def test__init__auto_populates_context_window_limit(bedrock_client): + _ = bedrock_client + + model = BedrockModel(model_id="anthropic.claude-sonnet-4-20250514-v1:0") + + assert 
model.get_config().get("context_window_limit") == 1_000_000 + + +def test__init__auto_populates_context_window_limit_cross_region(bedrock_client): + _ = bedrock_client + + model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-6") + + assert model.get_config().get("context_window_limit") == 1_000_000 + + +def test__init__auto_populates_context_window_limit_default_model(bedrock_client): + _ = bedrock_client + + model = BedrockModel() + + assert model.get_config().get("context_window_limit") == 1_000_000 + + +def test__init__explicit_context_window_limit_not_overridden(bedrock_client): + _ = bedrock_client + + model = BedrockModel(model_id="anthropic.claude-sonnet-4-20250514-v1:0", context_window_limit=100_000) + + assert model.get_config().get("context_window_limit") == 100_000 + + +def test__init__unknown_model_no_context_window_limit(bedrock_client): + _ = bedrock_client + + model = BedrockModel(model_id="unknown.model-v1:0") + + assert model.get_config().get("context_window_limit") is None + + def test_update_config(model, model_id): model.update_config(model_id=model_id) diff --git a/tests/strands/models/test_defaults.py b/tests/strands/models/test_defaults.py new file mode 100644 index 000000000..94c602fc1 --- /dev/null +++ b/tests/strands/models/test_defaults.py @@ -0,0 +1,76 @@ +"""Tests for model metadata lookup tables.""" + +from strands.models._defaults import get_context_window_limit, resolve_config_metadata + + +class TestGetContextWindowLimit: + """Tests for get_context_window_limit.""" + + def test_known_anthropic_direct_api(self): + assert get_context_window_limit("claude-sonnet-4-6") == 1_000_000 + assert get_context_window_limit("claude-opus-4-6") == 1_000_000 + assert get_context_window_limit("claude-opus-4-5") == 200_000 + assert get_context_window_limit("claude-haiku-4-5") == 200_000 + + def test_known_bedrock_anthropic(self): + assert get_context_window_limit("anthropic.claude-sonnet-4-6") == 1_000_000 + assert 
get_context_window_limit("anthropic.claude-haiku-4-5-20251001-v1:0") == 200_000 + + def test_known_bedrock_nova(self): + assert get_context_window_limit("amazon.nova-pro-v1:0") == 300_000 + assert get_context_window_limit("amazon.nova-micro-v1:0") == 128_000 + + def test_known_openai(self): + assert get_context_window_limit("gpt-5.4") == 1_050_000 + assert get_context_window_limit("gpt-4o") == 128_000 + assert get_context_window_limit("o3") == 200_000 + assert get_context_window_limit("o4-mini") == 200_000 + + def test_known_gemini(self): + assert get_context_window_limit("gemini-2.5-flash") == 1_048_576 + assert get_context_window_limit("gemini-2.5-pro") == 1_048_576 + + def test_strips_bedrock_cross_region_prefix(self): + assert get_context_window_limit("us.anthropic.claude-sonnet-4-6") == 1_000_000 + assert get_context_window_limit("global.anthropic.claude-sonnet-4-6") == 1_000_000 + assert get_context_window_limit("eu.anthropic.claude-sonnet-4-6") == 1_000_000 + assert get_context_window_limit("ap.anthropic.claude-sonnet-4-6") == 1_000_000 + + def test_strips_any_prefix_as_fallback(self): + # Any prefix before the first dot is stripped if direct lookup fails + assert get_context_window_limit("custom.anthropic.claude-sonnet-4-6") == 1_000_000 + + def test_unknown_model_returns_none(self): + assert get_context_window_limit("unknown-model-xyz") is None + assert get_context_window_limit("foo.unknown-model-xyz") is None + + +class TestResolveConfigMetadata: + """Tests for resolve_config_metadata.""" + + def test_resolves_context_window_limit(self): + config: dict = {"model_id": "claude-sonnet-4-6"} + result = resolve_config_metadata(config, "claude-sonnet-4-6") + assert result["context_window_limit"] == 1_000_000 + + def test_preserves_explicit_context_window_limit(self): + config: dict = {"model_id": "claude-sonnet-4-6", "context_window_limit": 100_000} + result = resolve_config_metadata(config, "claude-sonnet-4-6") + assert result["context_window_limit"] == 
100_000 + + def test_returns_original_config_when_explicit(self): + config: dict = {"model_id": "claude-sonnet-4-6", "context_window_limit": 100_000} + result = resolve_config_metadata(config, "claude-sonnet-4-6") + assert result is config + + def test_returns_original_config_when_unknown_model(self): + config: dict = {"model_id": "unknown-model"} + result = resolve_config_metadata(config, "unknown-model") + assert result is config + assert "context_window_limit" not in result + + def test_returns_new_dict_when_resolved(self): + config: dict = {"model_id": "claude-sonnet-4-6"} + result = resolve_config_metadata(config, "claude-sonnet-4-6") + assert result is not config + assert "context_window_limit" not in config diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index fe6936ccc..91a55d899 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -79,6 +79,30 @@ def test__init__context_window_limit(gemini_client): assert model.context_window_limit == 1_048_576 +def test__init__auto_populates_context_window_limit(gemini_client): + _ = gemini_client + + model = GeminiModel(model_id="gemini-2.5-flash") + + assert model.get_config().get("context_window_limit") == 1_048_576 + + +def test__init__explicit_context_window_limit_not_overridden(gemini_client): + _ = gemini_client + + model = GeminiModel(model_id="gemini-2.5-flash", context_window_limit=500_000) + + assert model.get_config().get("context_window_limit") == 500_000 + + +def test__init__unknown_model_no_context_window_limit(gemini_client): + _ = gemini_client + + model = GeminiModel(model_id="unknown-model") + + assert model.get_config().get("context_window_limit") is None + + def test_update_config(model, model_id): model.update_config(model_id=model_id) diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py index 57189748e..dd2728785 100644 --- a/tests/strands/models/test_mistral.py +++ 
b/tests/strands/models/test_mistral.py @@ -80,6 +80,30 @@ def test__init__model_configs(mistral_client, model_id, max_tokens): assert actual_temperature == exp_temperature +def test__init__auto_populates_context_window_limit(mistral_client): + _ = mistral_client + + model = MistralModel(model_id="mistral-large-latest", max_tokens=1) + + assert model.get_config().get("context_window_limit") == 262_144 + + +def test__init__explicit_context_window_limit_not_overridden(mistral_client): + _ = mistral_client + + model = MistralModel(model_id="mistral-large-latest", max_tokens=1, context_window_limit=100_000) + + assert model.get_config().get("context_window_limit") == 100_000 + + +def test__init__unknown_model_no_context_window_limit(mistral_client): + _ = mistral_client + + model = MistralModel(model_id="unknown-model", max_tokens=1) + + assert model.get_config().get("context_window_limit") is None + + def test_update_config(model, model_id): model.update_config(model_id=model_id) diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index b43915b07..613acd163 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -99,6 +99,30 @@ def test__init__context_window_limit(openai_client): assert model.context_window_limit == 128_000 +def test__init__auto_populates_context_window_limit(openai_client): + _ = openai_client + + model = OpenAIModel(model_id="gpt-4o") + + assert model.get_config().get("context_window_limit") == 128_000 + + +def test__init__explicit_context_window_limit_not_overridden(openai_client): + _ = openai_client + + model = OpenAIModel(model_id="gpt-4o", context_window_limit=50_000) + + assert model.get_config().get("context_window_limit") == 50_000 + + +def test__init__unknown_model_no_context_window_limit(openai_client): + _ = openai_client + + model = OpenAIModel(model_id="unknown-model") + + assert model.get_config().get("context_window_limit") is None + + @pytest.mark.parametrize( 
"content, exp_result", [ diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py index b35d2d0de..97ee9e305 100644 --- a/tests/strands/models/test_openai_responses.py +++ b/tests/strands/models/test_openai_responses.py @@ -71,11 +71,29 @@ def test__init__(model_id): model = OpenAIResponsesModel(model_id=model_id, params={"max_output_tokens": 100}) tru_config = model.get_config() - exp_config = {"model_id": "gpt-4o", "params": {"max_output_tokens": 100}} + exp_config = {"model_id": "gpt-4o", "params": {"max_output_tokens": 100}, "context_window_limit": 128_000} assert tru_config == exp_config +def test__init__auto_populates_context_window_limit(): + model = OpenAIResponsesModel(model_id="gpt-4o") + + assert model.get_config().get("context_window_limit") == 128_000 + + +def test__init__explicit_context_window_limit_not_overridden(): + model = OpenAIResponsesModel(model_id="gpt-4o", context_window_limit=50_000) + + assert model.get_config().get("context_window_limit") == 50_000 + + +def test__init__unknown_model_no_context_window_limit(): + model = OpenAIResponsesModel(model_id="unknown-model") + + assert model.get_config().get("context_window_limit") is None + + def test_update_config(model, model_id): model.update_config(model_id=model_id) From d94d5163daae2698a284cdcbe9a89c43d0b22e0a Mon Sep 17 00:00:00 2001 From: mehtarac Date: Wed, 6 May 2026 10:29:23 -0400 Subject: [PATCH 263/279] fix: fix count tokens for bedrock models (#2254) --- src/strands/models/model.py | 73 +------------- .../context_offloader/plugin.py | 11 +-- tests/strands/models/test_model.py | 94 +++++-------------- 3 files changed, 26 insertions(+), 152 deletions(-) diff --git a/src/strands/models/model.py b/src/strands/models/model.py index e5b15ebaa..3ded11a28 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -1,7 +1,6 @@ """Abstract base class for Agent model providers.""" import abc -import functools import json 
import logging import math @@ -24,9 +23,6 @@ T = TypeVar("T", bound=BaseModel) -_DEFAULT_ENCODING = "cl100k_base" - - def _heuristic_estimate_text(text: str) -> int: """Estimate token count from text using characters / 4 heuristic.""" return math.ceil(len(text) / 4) @@ -40,22 +36,6 @@ def _heuristic_estimate_json(obj: Any) -> int: return 0 -@functools.lru_cache(maxsize=1) -def _get_encoding() -> Any: - """Get the default tiktoken encoding, caching to avoid repeated lookups. - - Returns: - The tiktoken encoding, or None if tiktoken is not installed. - """ - try: - import tiktoken - - return tiktoken.get_encoding(_DEFAULT_ENCODING) - except ImportError: - logger.debug("tiktoken not available, falling back to heuristic token estimation") - return None - - def _count_content_block_tokens( block: ContentBlock, count_text: Callable[[str], int], count_json: Callable[[Any], int] ) -> int: @@ -104,54 +84,6 @@ def _count_content_block_tokens( return total -def _estimate_tokens_with_tiktoken( - messages: Messages, - tool_specs: list[ToolSpec] | None = None, - system_prompt: str | None = None, - system_prompt_content: list[SystemContentBlock] | None = None, -) -> int: - """Estimate tokens by serializing messages/tools to text and counting with tiktoken. - - This is a best-effort fallback for providers that don't expose native counting. - Accuracy varies by model but is sufficient for threshold-based decisions. - - Raises: - ImportError: If tiktoken is not installed. 
- """ - encoding = _get_encoding() - if encoding is None: - raise ImportError("tiktoken is not available") - - def count_text(text: str) -> int: - return len(encoding.encode(text)) - - def count_json(obj: Any) -> int: - try: - return len(encoding.encode(json.dumps(obj))) - except (TypeError, ValueError): - return 0 - - total = 0 - - # Prefer system_prompt_content (structured) over system_prompt (plain string) to avoid double-counting, - # since providers wrap system_prompt into system_prompt_content when both are provided. - if system_prompt_content: - for block in system_prompt_content: - if "text" in block: - total += count_text(block["text"]) - elif system_prompt: - total += count_text(system_prompt) - - for message in messages: - for block in message["content"]: - total += _count_content_block_tokens(block, count_text, count_json) - - if tool_specs: - for spec in tool_specs: - total += count_json(spec) - - return total - def _estimate_tokens_with_heuristic( messages: Messages, @@ -338,10 +270,7 @@ async def count_tokens( Returns: Estimated total input tokens. 
""" - try: - return _estimate_tokens_with_tiktoken(messages, tool_specs, system_prompt, system_prompt_content) - except ImportError: - return _estimate_tokens_with_heuristic(messages, tool_specs, system_prompt, system_prompt_content) + return _estimate_tokens_with_heuristic(messages, tool_specs, system_prompt, system_prompt_content) class _ModelPlugin(Plugin): diff --git a/src/strands/vended_plugins/context_offloader/plugin.py b/src/strands/vended_plugins/context_offloader/plugin.py index 929ba3ca6..6cb98b31b 100644 --- a/src/strands/vended_plugins/context_offloader/plugin.py +++ b/src/strands/vended_plugins/context_offloader/plugin.py @@ -37,7 +37,6 @@ from typing import TYPE_CHECKING from ...hooks.events import AfterToolCallEvent -from ...models.model import _get_encoding from ...plugins import Plugin, hook from ...tools.decorator import tool from ...types.content import Message @@ -318,10 +317,7 @@ async def _handle_tool_result(self, event: AfterToolCallEvent) -> None: ) def _slice_preview(self, text: str) -> str: - """Slice text to approximately preview_tokens. - - Uses tiktoken for exact token-level slicing when available, - falls back to characters (tokens * 4) otherwise. + """Slice text to approximately preview_tokens using character-based estimation. Args: text: The full text to slice. @@ -329,9 +325,4 @@ def _slice_preview(self, text: str) -> str: Returns: The preview text. 
""" - encoding = _get_encoding() - if encoding is not None: - tokens = encoding.encode(text) - preview: str = encoding.decode(tokens[: self._preview_tokens]) - return preview return text[: self._preview_tokens * _CHARS_PER_TOKEN] diff --git a/tests/strands/models/test_model.py b/tests/strands/models/test_model.py index 2c685b43b..b362740b5 100644 --- a/tests/strands/models/test_model.py +++ b/tests/strands/models/test_model.py @@ -244,35 +244,35 @@ async def test_count_tokens_empty_messages(model): @pytest.mark.asyncio async def test_count_tokens_system_prompt_only(model): result = await model.count_tokens(messages=[], system_prompt="You are a helpful assistant.") - assert result == 6 + assert result == 7 # ceil(28/4) @pytest.mark.asyncio async def test_count_tokens_text_messages(model, messages): result = await model.count_tokens(messages=messages) - assert result == 1 # "hello" + assert result == 2 # ceil(5/4) @pytest.mark.asyncio async def test_count_tokens_with_tool_specs(model, messages, tool_specs): without_tools = await model.count_tokens(messages=messages) with_tools = await model.count_tokens(messages=messages, tool_specs=tool_specs) - assert without_tools == 1 # "hello" - assert with_tools == 49 # "hello" (1) + tool_spec (48) + assert without_tools == 2 # ceil(5/4) + assert with_tools == 84 # ceil(5/4) + ceil(164/2) @pytest.mark.asyncio async def test_count_tokens_with_system_prompt(model, messages, system_prompt): without_prompt = await model.count_tokens(messages=messages) with_prompt = await model.count_tokens(messages=messages, system_prompt=system_prompt) - assert without_prompt == 1 # "hello" - assert with_prompt == 3 # "hello" (1) + "s1" (2) + assert without_prompt == 2 # ceil(5/4) + assert with_prompt == 3 # ceil(5/4) + ceil(2/4) @pytest.mark.asyncio async def test_count_tokens_combined(model, messages, tool_specs, system_prompt): result = await model.count_tokens(messages=messages, tool_specs=tool_specs, system_prompt=system_prompt) - assert 
result == 51 # "hello" (1) + tool_spec (48) + "s1" (2) + assert result == 85 # ceil(5/4) + ceil(164/2) + ceil(2/4) @pytest.mark.asyncio @@ -292,8 +292,8 @@ async def test_count_tokens_tool_use_block(model): } ] result = await model.count_tokens(messages=messages) - # name "my_tool" (2) + json.dumps(input) (6) = 8 - assert result == 8 + # name "my_tool" ceil(7/4)=2 + json.dumps(input) ceil(17/2)=9 = 11 + assert result == 11 @pytest.mark.asyncio @@ -313,7 +313,7 @@ async def test_count_tokens_tool_result_block(model): } ] result = await model.count_tokens(messages=messages) - assert result == 3 # "tool output here" + assert result == 4 # ceil(16/4) @pytest.mark.asyncio @@ -333,7 +333,7 @@ async def test_count_tokens_reasoning_block(model): } ] result = await model.count_tokens(messages=messages) - assert result == 9 # "Let me think about this step by step." + assert result == 10 # ceil(37/4) @pytest.mark.asyncio @@ -399,7 +399,7 @@ async def test_count_tokens_guard_content_block(model): } ] result = await model.count_tokens(messages=messages) - assert result == 8 # "This content was filtered by guardrails." 
+ assert result == 10 # ceil(40/4) @pytest.mark.asyncio @@ -420,7 +420,7 @@ async def test_count_tokens_tool_use_with_bytes(model): ] result = await model.count_tokens(messages=messages) # Should still count the tool name even though input has non-serializable bytes - assert result == 2 # "my_tool" name only + assert result == 2 # ceil(7/4) name only @pytest.mark.asyncio @@ -434,7 +434,7 @@ async def test_count_tokens_non_serializable_tool_spec(model, messages): ] result = await model.count_tokens(messages=messages, tool_specs=tool_specs) # Should still count the message tokens even though tool spec fails - assert result == 1 # "hello" only, tool spec skipped + assert result == 2 # ceil(5/4) only, tool spec skipped @pytest.mark.asyncio @@ -453,7 +453,7 @@ async def test_count_tokens_citations_block(model): } ] result = await model.count_tokens(messages=messages) - assert result == 11 # "According to the document, the answer is 42." + assert result == 11 # ceil(44/4) @pytest.mark.asyncio @@ -462,7 +462,7 @@ async def test_count_tokens_system_prompt_content(model): messages=[], system_prompt_content=[{"text": "You are a helpful assistant."}], ) - assert result == 6 # "You are a helpful assistant." + assert result == 7 # ceil(28/4) @pytest.mark.asyncio @@ -474,7 +474,7 @@ async def test_count_tokens_system_prompt_content_with_cache_point(model): {"cachePoint": {"type": "default"}}, ], ) - assert result == 6 # "You are a helpful assistant.", cachePoint adds 0 + assert result == 7 # ceil(28/4), cachePoint adds 0 @pytest.mark.asyncio @@ -489,7 +489,7 @@ async def test_count_tokens_system_prompt_content_takes_priority(model): system_prompt="This is a much longer system prompt that should have more tokens.", system_prompt_content=[{"text": "Short."}], ) - assert content_only == 2 # "Short." 
+ assert content_only == 2 # ceil(6/4) assert content_only == both @@ -505,41 +505,10 @@ async def test_count_tokens_all_inputs(model): system_prompt="Be helpful.", system_prompt_content=[{"text": "Additional system context."}], ) - # system_prompt_content (4) + "hello world" (2) + "hi there" (2) + tool_spec (23) = 31 - assert result == 31 + # system_prompt_content (7) + "hello world" (3) + "hi there" (2) + tool_spec (38) = 50 + assert result == 50 -def test__get_encoding_falls_back_without_tiktoken(monkeypatch): - """Test that _get_encoding returns None and count_tokens falls back to heuristic.""" - import strands.models.model as model_module - - model_module._get_encoding.cache_clear() - original_import = __builtins__["__import__"] if isinstance(__builtins__, dict) else __builtins__.__import__ - - def _block_tiktoken(name, *args, **kwargs): - if name == "tiktoken": - raise ImportError("No module named 'tiktoken'") - return original_import(name, *args, **kwargs) - - monkeypatch.setattr("builtins.__import__", _block_tiktoken) - - try: - assert model_module._get_encoding() is None - - # _estimate_tokens_with_tiktoken should raise when tiktoken is unavailable - with pytest.raises(ImportError): - model_module._estimate_tokens_with_tiktoken( - messages=[{"role": "user", "content": [{"text": "hello world!"}]}], - ) - - # _estimate_tokens_with_heuristic uses chars/4 for text - result = model_module._estimate_tokens_with_heuristic( - messages=[{"role": "user", "content": [{"text": "hello world!"}]}], - ) - assert result == 3 # ceil(12 / 4) - finally: - model_module._get_encoding.cache_clear() - class TestHeuristicEstimation: """Tests for _estimate_tokens_with_heuristic.""" @@ -592,22 +561,7 @@ def test_non_serializable_inputs(self): assert result == 2 # only tool name counted: ceil(len("my_tool") / 4) @pytest.mark.asyncio - async def test_model_falls_back_to_heuristic(self, monkeypatch, model): - """Model.count_tokens falls back to heuristic when tiktoken unavailable.""" 
- import strands.models.model as model_module - - model_module._get_encoding.cache_clear() - original_import = __builtins__["__import__"] if isinstance(__builtins__, dict) else __builtins__.__import__ - - def _block_tiktoken(name, *args, **kwargs): - if name == "tiktoken": - raise ImportError("No module named 'tiktoken'") - return original_import(name, *args, **kwargs) - - monkeypatch.setattr("builtins.__import__", _block_tiktoken) - - try: - result = await model.count_tokens(messages=[{"role": "user", "content": [{"text": "hello world!"}]}]) - assert result == 3 # ceil(12 / 4) - finally: - model_module._get_encoding.cache_clear() + async def test_model_uses_heuristic(self, model): + """Model.count_tokens uses heuristic estimation.""" + result = await model.count_tokens(messages=[{"role": "user", "content": [{"text": "hello world!"}]}]) + assert result == 3 # ceil(12 / 4) From 6b0df9add0000d75b22cbf6ecc5c097c935c0d72 Mon Sep 17 00:00:00 2001 From: opieter-aws Date: Wed, 6 May 2026 11:41:20 -0400 Subject: [PATCH 264/279] fix: cache unsupported models for bedrocks token counting (#2250) --- src/strands/models/bedrock.py | 36 ++++++++++++++++++++++++---- tests/strands/models/test_bedrock.py | 35 +++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 5 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index baa2807c4..c1cbfa265 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -55,6 +55,15 @@ "anthropic.claude", ] +# Cache of model IDs that do not support the CountTokens API. +_UNSUPPORTED_COUNT_TOKENS_MODELS: set[str] = set() + + +def _clear_unsupported_count_tokens_cache() -> None: + """Clear the cache of model IDs that do not support the CountTokens API.""" + _UNSUPPORTED_COUNT_TOKENS_MODELS.clear() + + T = TypeVar("T", bound=BaseModel) DEFAULT_READ_TIMEOUT = 120 @@ -785,6 +794,11 @@ async def count_tokens( Returns: Total input token count. 
""" + model_id: str = self.config["model_id"] + + if model_id in _UNSUPPORTED_COUNT_TOKENS_MODELS: + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + try: if system_prompt and system_prompt_content is None: system_prompt_content = [{"text": system_prompt}] @@ -811,11 +825,23 @@ async def count_tokens( logger.debug("model_id=<%s>, total_tokens=<%d> | native token count", self.config["model_id"], total_tokens) return total_tokens except Exception as e: - logger.debug( - "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", - self.config["model_id"], - e, - ) + if ( + isinstance(e, ClientError) + and e.response.get("Error", {}).get("Code") == "ValidationException" + and "doesn't support counting tokens" in str(e) + ): + logger.debug( + "model_id=<%s> | model does not support CountTokens, caching for future calls," + " falling back to estimation", + model_id, + ) + _UNSUPPORTED_COUNT_TOKENS_MODELS.add(model_id) + else: + logger.debug( + "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", + model_id, + e, + ) return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) @override diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index e42fc8e1f..f177a8a17 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -19,6 +19,7 @@ DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION, DEFAULT_READ_TIMEOUT, + _clear_unsupported_count_tokens_cache, ) from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException from strands.types.tools import ToolSpec @@ -3333,6 +3334,12 @@ async def test_non_streaming_citations_with_only_location(bedrock_client, model, class TestCountTokens: """Tests for BedrockModel.count_tokens native token counting.""" + @pytest.fixture(autouse=True) + def clean_cache(self): + 
_clear_unsupported_count_tokens_cache() + yield + _clear_unsupported_count_tokens_cache() + @pytest.fixture def model_with_client(self, bedrock_client, model_id): _ = bedrock_client @@ -3449,3 +3456,31 @@ async def test_fallback_logs_debug(self, model_with_client, bedrock_client, mess await model_with_client.count_tokens(messages=messages) assert any("native token counting failed" in record.message for record in caplog.records) + + @pytest.mark.asyncio + async def test_caches_model_id_when_count_tokens_unsupported(self, bedrock_client, messages): + model = BedrockModel(model_id="unsupported-cache-test-model") + bedrock_client.count_tokens.side_effect = ClientError( + {"Error": {"Code": "ValidationException", "Message": "The provided model doesn't support counting tokens"}}, + "CountTokens", + ) + + # First call: hits API, gets error, caches + await model.count_tokens(messages=messages) + assert bedrock_client.count_tokens.call_count == 1 + + # Second call: skips API entirely + await model.count_tokens(messages=messages) + assert bedrock_client.count_tokens.call_count == 1 + + @pytest.mark.asyncio + async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, messages): + model = BedrockModel(model_id="transient-error-test-model") + bedrock_client.count_tokens.side_effect = RuntimeError("Transient network error") + + await model.count_tokens(messages=messages) + assert bedrock_client.count_tokens.call_count == 1 + + # Second call should still attempt the API + await model.count_tokens(messages=messages) + assert bedrock_client.count_tokens.call_count == 2 From 800e7c46614a097d1bcacb4c0128e0d0a3618de7 Mon Sep 17 00:00:00 2001 From: opieter-aws Date: Wed, 6 May 2026 16:51:04 -0400 Subject: [PATCH 265/279] feat: add useNativeTokenCount flag to skip token counting API calls (#2255) --- src/strands/models/anthropic.py | 7 +++++++ src/strands/models/bedrock.py | 7 +++++++ src/strands/models/gemini.py | 7 +++++++ src/strands/models/llamacpp.py | 7 +++++++ 
src/strands/models/model.py | 2 +- src/strands/models/openai_responses.py | 7 +++++++ tests/strands/models/test_anthropic.py | 13 +++++++++++++ tests/strands/models/test_bedrock.py | 11 +++++++++++ tests/strands/models/test_gemini.py | 11 +++++++++++ tests/strands/models/test_llamacpp.py | 11 +++++++++++ tests/strands/models/test_model.py | 1 - tests/strands/models/test_openai_responses.py | 11 +++++++++++ 12 files changed, 93 insertions(+), 2 deletions(-) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index ece7cd8d1..04fae220d 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -57,11 +57,15 @@ class AnthropicConfig(BaseModelConfig, total=False): https://docs.anthropic.com/en/docs/about-claude/models/all-models. params: Additional model parameters (e.g., temperature). For a complete list of supported parameters, see https://docs.anthropic.com/en/api/messages. + use_native_token_count: Whether to use the native Anthropic count_tokens API. + When True (default), count_tokens() calls the Anthropic API for accurate counts. + When False, skips the API call and uses the local estimator. """ max_tokens: Required[int] model_id: Required[str] params: dict[str, Any] | None + use_native_token_count: bool def __init__(self, *, client_args: dict[str, Any] | None = None, **model_config: Unpack[AnthropicConfig]): """Initialize provider instance. @@ -394,6 +398,9 @@ async def count_tokens( Returns: Total input token count. """ + if self.config.get("use_native_token_count") is False: + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + try: # system_prompt_content is not used; this provider only accepts system_prompt as a plain string, # matching the behavior of stream(). 
The caller always provides system_prompt alongside diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index c1cbfa265..c74a63a3b 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -117,6 +117,9 @@ class BedrockConfig(BaseModelConfig, total=False): See https://docs.aws.amazon.com/bedrock/latest/userguide/structured-output.html temperature: Controls randomness in generation (higher = more random) top_p: Controls diversity via nucleus sampling (alternative to temperature) + use_native_token_count: Whether to use the native Bedrock CountTokens API. + When True (default), count_tokens() calls the Bedrock API for accurate counts. + When False, skips the API call and uses the local estimator. """ additional_args: dict[str, Any] | None @@ -143,6 +146,7 @@ class BedrockConfig(BaseModelConfig, total=False): strict_tools: bool | None temperature: float | None top_p: float | None + use_native_token_count: bool def __init__( self, @@ -794,6 +798,9 @@ async def count_tokens( Returns: Total input token count. """ + if self.config.get("use_native_token_count") is False: + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + model_id: str = self.config["model_id"] if model_id in _UNSUPPORTED_COUNT_TOKENS_MODELS: diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 65b925c6d..8ed579d38 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -49,11 +49,15 @@ class GeminiConfig(BaseModelConfig, total=False): Use the standard tools interface for function calling tools. For a complete list of supported tools, see https://ai.google.dev/api/caching#Tool + use_native_token_count: Whether to use the native Gemini count_tokens API. + When True (default), count_tokens() calls the Gemini API for accurate counts. + When False, skips the API call and uses the local estimator. 
""" model_id: Required[str] params: dict[str, Any] gemini_tools: list[genai.types.Tool] + use_native_token_count: bool def __init__( self, @@ -457,6 +461,9 @@ async def count_tokens( Returns: Total input token count. """ + if self.config.get("use_native_token_count") is False: + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + try: contents = list(self._format_request_content(messages)) diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index c31ba11bc..531cf6b50 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -125,10 +125,14 @@ class LlamaCppConfig(BaseModelConfig, total=False): - cache_prompt: Cache the prompt for faster generation - slot_id: Slot ID for parallel inference - samplers: Custom sampler order + use_native_token_count: Whether to use the native llama.cpp /tokenize endpoint. + When True (default), count_tokens() calls the server's tokenize endpoint for accurate counts. + When False, skips the API call and uses the local estimator. """ model_id: str params: dict[str, Any] | None + use_native_token_count: bool def __init__( self, @@ -533,6 +537,9 @@ async def count_tokens( Returns: Total input token count. """ + if self.config.get("use_native_token_count") is False: + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + try: # system_prompt_content is not used; this provider only accepts system_prompt as a plain string, # matching the behavior of stream(). 
The caller always provides system_prompt alongside diff --git a/src/strands/models/model.py b/src/strands/models/model.py index 3ded11a28..dd2f9eed2 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -23,6 +23,7 @@ T = TypeVar("T", bound=BaseModel) + def _heuristic_estimate_text(text: str) -> int: """Estimate token count from text using characters / 4 heuristic.""" return math.ceil(len(text) / 4) @@ -84,7 +85,6 @@ def _count_content_block_tokens( return total - def _estimate_tokens_with_heuristic( messages: Messages, tool_specs: list[ToolSpec] | None = None, diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index a78cef73a..c6ddbb9d6 100644 --- a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -136,11 +136,15 @@ class OpenAIResponsesConfig(BaseModelConfig, total=False): stateful: Whether to enable server-side conversation state management. When True, the server stores conversation history and the client does not need to send the full message history with each request. Defaults to False. + use_native_token_count: Whether to use the native OpenAI input_tokens.count API. + When True (default), count_tokens() calls the OpenAI API for accurate counts. + When False, skips the API call and uses the local estimator. """ model_id: str params: dict[str, Any] | None stateful: bool + use_native_token_count: bool def __init__( self, @@ -238,6 +242,9 @@ async def count_tokens( Returns: Total input token count. """ + if self.config.get("use_native_token_count") is False: + return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) + try: # system_prompt_content is not used; this provider only accepts system_prompt as a plain string, # matching the behavior of stream(). 
The caller always provides system_prompt alongside diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index abb56a441..6de821e90 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -1162,3 +1162,16 @@ async def test_fallback_logs_debug(self, model_with_client, anthropic_client, me await model_with_client.count_tokens(messages=messages) assert any("native token counting failed" in record.message for record in caplog.records) + + @pytest.mark.asyncio + async def test_skip_native_api_when_use_native_token_count_false( + self, anthropic_client, model_id, max_tokens, messages + ): + _ = anthropic_client + model = AnthropicModel(model_id=model_id, max_tokens=max_tokens, use_native_token_count=False) + + result = await model.count_tokens(messages=messages) + + anthropic_client.messages.count_tokens.assert_not_called() + assert isinstance(result, int) + assert result >= 0 diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index f177a8a17..2f1f7d1f1 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -3484,3 +3484,14 @@ async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, me # Second call should still attempt the API await model.count_tokens(messages=messages) assert bedrock_client.count_tokens.call_count == 2 + + @pytest.mark.asyncio + async def test_skip_native_api_when_use_native_token_count_false(self, bedrock_client, model_id, messages): + _ = bedrock_client + model = BedrockModel(model_id=model_id, use_native_token_count=False) + + result = await model.count_tokens(messages=messages) + + bedrock_client.count_tokens.assert_not_called() + assert isinstance(result, int) + assert result >= 0 diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index 91a55d899..b846bfcdf 100644 --- a/tests/strands/models/test_gemini.py +++ 
b/tests/strands/models/test_gemini.py @@ -1228,3 +1228,14 @@ async def test_fallback_logs_debug(self, model, gemini_client, messages, caplog) await model.count_tokens(messages=messages) assert any("native token counting failed" in record.message for record in caplog.records) + + @pytest.mark.asyncio + async def test_skip_native_api_when_use_native_token_count_false(self, gemini_client, messages): + _ = gemini_client + model = GeminiModel(model_id="m1", use_native_token_count=False) + + result = await model.count_tokens(messages=messages) + + gemini_client.aio.models.count_tokens.assert_not_called() + assert isinstance(result, int) + assert result >= 0 diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py index a891ec929..43fb03629 100644 --- a/tests/strands/models/test_llamacpp.py +++ b/tests/strands/models/test_llamacpp.py @@ -803,3 +803,14 @@ async def test_fallback_logs_debug(self, model, messages, caplog): await model.count_tokens(messages=messages) assert any("native token counting failed" in record.message for record in caplog.records) + + @pytest.mark.asyncio + async def test_skip_native_api_when_use_native_token_count_false(self, messages): + model = LlamaCppModel(base_url="http://localhost:8080", use_native_token_count=False) + model.client.post = AsyncMock() + + result = await model.count_tokens(messages=messages) + + model.client.post.assert_not_called() + assert isinstance(result, int) + assert result >= 0 diff --git a/tests/strands/models/test_model.py b/tests/strands/models/test_model.py index b362740b5..34f4ef328 100644 --- a/tests/strands/models/test_model.py +++ b/tests/strands/models/test_model.py @@ -509,7 +509,6 @@ async def test_count_tokens_all_inputs(model): assert result == 50 - class TestHeuristicEstimation: """Tests for _estimate_tokens_with_heuristic.""" diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py index 97ee9e305..47acfded4 100644 --- 
a/tests/strands/models/test_openai_responses.py +++ b/tests/strands/models/test_openai_responses.py @@ -1318,6 +1318,17 @@ async def test_fallback_logs_debug(self, model, openai_client, messages, caplog) assert any("native token counting failed" in record.message for record in caplog.records) + @pytest.mark.asyncio + async def test_skip_native_api_when_use_native_token_count_false(self, openai_client, messages): + _ = openai_client + model = OpenAIResponsesModel(model_id="gpt-4o", use_native_token_count=False) + + result = await model.count_tokens(messages=messages) + + openai_client.responses.input_tokens.count.assert_not_called() + assert isinstance(result, int) + assert result >= 0 + # ============================================================================= # Bedrock Mantle (bedrock_mantle_config) integration with OpenAIResponsesModel From 980bc91494aff282da476bcecad64ea7359aaef6 Mon Sep 17 00:00:00 2001 From: Jack Stevenson Date: Thu, 7 May 2026 06:51:35 +1000 Subject: [PATCH 266/279] fix: correct MCPClient.__exit__ and stop() type annotations (#2248) --- src/strands/tools/mcp/mcp_client.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 1884ce9bc..270012fde 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -178,7 +178,12 @@ def __enter__(self) -> "MCPClient": """ return self.start() - def __exit__(self, exc_type: BaseException, exc_val: BaseException, exc_tb: TracebackType) -> None: + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: """Context manager exit point that cleans up resources.""" self.stop(exc_type, exc_val, exc_tb) @@ -318,7 +323,12 @@ def remove_consumer(self, consumer_id: Any, **kwargs: Any) -> None: # MCP-specific methods - def stop(self, exc_type: BaseException | None, exc_val: BaseException | 
None, exc_tb: TracebackType | None) -> None: + def stop( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: """Signals the background thread to stop and waits for it to complete, ensuring proper cleanup of all resources. This method is defensive and can handle partial initialization states that may occur From fc386a3e56d2b07e392720f149494c28476d152c Mon Sep 17 00:00:00 2001 From: Agent of mkmeral Date: Thu, 7 May 2026 15:57:38 -0400 Subject: [PATCH 267/279] feat(a2a): implement full A2A task lifecycle state support (#2245) Co-authored-by: agent-of-mkmeral Co-authored-by: Strands Agent Co-authored-by: agent-of-mkmeral <217235299+strands-agent@users.noreply.github.com> --- src/strands/agent/a2a_agent.py | 23 +- src/strands/multiagent/a2a/_converters.py | 60 +- src/strands/multiagent/a2a/executor.py | 132 +++- tests/strands/agent/test_a2a_agent.py | 162 ++++- .../strands/multiagent/a2a/test_converters.py | 283 ++++++++ tests/strands/multiagent/a2a/test_executor.py | 622 +++++++++++++++++- 6 files changed, 1248 insertions(+), 34 deletions(-) diff --git a/src/strands/agent/a2a_agent.py b/src/strands/agent/a2a_agent.py index eef47e3b4..eeb96f7a2 100644 --- a/src/strands/agent/a2a_agent.py +++ b/src/strands/agent/a2a_agent.py @@ -15,10 +15,14 @@ import httpx from a2a.client import A2ACardResolver, ClientConfig, ClientFactory -from a2a.types import AgentCard, Message, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent +from a2a.types import AgentCard, Message, TaskArtifactUpdateEvent, TaskStatusUpdateEvent from .._async import run_async -from ..multiagent.a2a._converters import convert_input_to_message, convert_response_to_agent_result +from ..multiagent.a2a._converters import ( + _STATE_TO_STOP_REASON, + convert_input_to_message, + convert_response_to_agent_result, +) from ..types._events import AgentResultEvent from ..types.a2a import A2AResponse, A2AStreamEvent from ..types.agent import 
AgentInput @@ -29,6 +33,13 @@ _DEFAULT_TIMEOUT = 300 +# A2A task states that indicate the response stream is complete. +# Derived from the canonical _STATE_TO_STOP_REASON mapping in _converters. +# Terminal states (end_turn) mean no more events; input states (interrupt) mean execution is paused. +_TERMINAL_STATES = {state for state, reason in _STATE_TO_STOP_REASON.items() if reason == "end_turn"} +_INPUT_STATES = {state for state, reason in _STATE_TO_STOP_REASON.items() if reason == "interrupt"} +_COMPLETE_STATES = _TERMINAL_STATES | _INPUT_STATES + class A2AAgent(AgentBase): """Client wrapper for remote A2A agents.""" @@ -265,6 +276,9 @@ async def _send_message(self, prompt: AgentInput) -> AsyncIterator[A2AResponse]: def _is_complete_event(self, event: A2AResponse) -> bool: """Check if an A2A event represents a complete response. + Recognizes all terminal states (completed, failed, canceled, rejected) + and pausing states (input_required, auth_required) as complete events. + Args: event: A2A event. 
@@ -289,9 +303,10 @@ def _is_complete_event(self, event: A2AResponse) -> bool: return update_event.last_chunk return False - # Status update with completed state + # Status update - check for terminal or pausing states if isinstance(update_event, TaskStatusUpdateEvent): if update_event.status and hasattr(update_event.status, "state"): - return update_event.status.state == TaskState.completed + state = update_event.status.state + return state in _COMPLETE_STATES return False diff --git a/src/strands/multiagent/a2a/_converters.py b/src/strands/multiagent/a2a/_converters.py index 22c2ffb72..7808ae325 100644 --- a/src/strands/multiagent/a2a/_converters.py +++ b/src/strands/multiagent/a2a/_converters.py @@ -4,13 +4,24 @@ from uuid import uuid4 from a2a.types import Message as A2AMessage -from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, TextPart +from a2a.types import Part, Role, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent, TextPart from ...agent.agent_result import AgentResult from ...telemetry.metrics import EventLoopMetrics from ...types.a2a import A2AResponse from ...types.agent import AgentInput from ...types.content import ContentBlock, Message +from ...types.event_loop import StopReason + +# Mapping from A2A TaskState to Strands stop_reason +_STATE_TO_STOP_REASON: dict[TaskState, StopReason] = { + TaskState.completed: "end_turn", + TaskState.failed: "end_turn", + TaskState.canceled: "end_turn", + TaskState.rejected: "end_turn", + TaskState.input_required: "interrupt", + TaskState.auth_required: "interrupt", +} def convert_input_to_message(prompt: AgentInput) -> A2AMessage: @@ -79,9 +90,34 @@ def convert_content_blocks_to_parts(content_blocks: list[ContentBlock]) -> list[ return parts +def _extract_task_state(response: A2AResponse) -> TaskState | None: + """Extract the task state from an A2A response. + + Args: + response: A2A response (either A2AMessage or tuple of task and update event). 
+ + Returns: + The TaskState if available, None otherwise. + """ + if isinstance(response, tuple) and len(response) == 2: + _task, update_event = response + if isinstance(update_event, TaskStatusUpdateEvent): + if update_event.status and hasattr(update_event.status, "state"): + return update_event.status.state + return None + + def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: """Convert A2A response to AgentResult. + Maps A2A task lifecycle states to appropriate Strands stop_reasons: + - completed → end_turn + - failed → end_turn (with error content) + - canceled → end_turn (with cancellation info) + - rejected → end_turn (with rejection info) + - input_required → interrupt (agent needs user input) + - auth_required → interrupt (agent needs authentication) + Args: response: A2A response (either A2AMessage or tuple of task and update event). @@ -89,19 +125,26 @@ def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: AgentResult with extracted content and metadata. 
""" content: list[ContentBlock] = [] + task_state = _extract_task_state(response) + stop_reason: StopReason = _STATE_TO_STOP_REASON.get(task_state, "end_turn") if task_state else "end_turn" if isinstance(response, tuple) and len(response) == 2: task, update_event = response # Handle artifact updates if isinstance(update_event, TaskArtifactUpdateEvent): - if update_event.artifact and hasattr(update_event.artifact, "parts"): + if update_event.artifact and hasattr(update_event.artifact, "parts") and update_event.artifact.parts: for part in update_event.artifact.parts: if hasattr(part, "root") and hasattr(part.root, "text"): content.append({"text": part.root.text}) # Handle status updates with messages elif isinstance(update_event, TaskStatusUpdateEvent): - if update_event.status and hasattr(update_event.status, "message") and update_event.status.message: + if ( + update_event.status + and hasattr(update_event.status, "message") + and update_event.status.message + and update_event.status.message.parts + ): for part in update_event.status.message.parts: if hasattr(part, "root") and hasattr(part.root, "text"): content.append({"text": part.root.text}) @@ -109,7 +152,7 @@ def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: # Use task.artifacts when no content was extracted from the event if not content and task and hasattr(task, "artifacts") and task.artifacts is not None: for artifact in task.artifacts: - if hasattr(artifact, "parts"): + if hasattr(artifact, "parts") and artifact.parts: for part in artifact.parts: if hasattr(part, "root") and hasattr(part.root, "text"): content.append({"text": part.root.text}) @@ -123,9 +166,14 @@ def convert_response_to_agent_result(response: A2AResponse) -> AgentResult: "content": content, } + # Build state dict with A2A metadata + state: dict[str, str] = {} + if task_state is not None: + state["a2a_task_state"] = task_state.value + return AgentResult( - stop_reason="end_turn", + stop_reason=stop_reason, 
message=message, metrics=EventLoopMetrics(), - state={}, + state=state, ) diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index c8c00600b..7526386e8 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -8,6 +8,7 @@ streamed requests to the A2AServer. """ +import asyncio import base64 import json import logging @@ -42,7 +43,9 @@ class StrandsA2AExecutor(AgentExecutor): """Executor that adapts a Strands Agent to the A2A protocol. This executor uses streaming mode to handle the execution of agent requests - and converts Strands Agent responses to A2A protocol events. + and converts Strands Agent responses to A2A protocol events. It supports the + full A2A task lifecycle including error handling (failed state), cancellation, + and interrupt-based input_required flows. """ # Default formats for each file type when MIME type is unavailable or unrecognized @@ -75,14 +78,18 @@ async def execute( """Execute a request using the Strands Agent and send the response as A2A events. This method executes the user's input using the Strands Agent in streaming mode - and converts the agent's response to A2A events. + and converts the agent's response to A2A events. If the agent raises an exception, + the task transitions to the `failed` state. If the agent returns with interrupts, + the task transitions to the `input_required` state. Args: context: The A2A request context, containing the user's input and task metadata. event_queue: The A2A event queue used to send response events back to the client. Raises: - ServerError: If an error occurs during agent execution + ServerError: If an unrecoverable error occurs during agent execution setup + (e.g., missing input). Agent execution errors are handled gracefully + by transitioning the task to the failed state. 
""" task = context.current_task if not task: @@ -93,8 +100,34 @@ async def execute( try: await self._execute_streaming(context, updater) - except Exception as e: - raise ServerError(error=InternalError()) from e + except ServerError: + # Re-raise ServerErrors (setup failures like missing input) + raise + except asyncio.CancelledError: + # asyncio.CancelledError is a BaseException (not Exception) — raised when + # the asyncio task is cancelled (e.g., HTTP client disconnect, server shutdown). + # We transition to canceled state so the task doesn't remain a zombie in "working". + logger.warning("task_id=<%s> | asyncio task cancelled, transitioning to canceled state", task.id) + try: + await updater.cancel( + message=updater.new_agent_message( + parts=[Part(root=TextPart(text="Task cancelled due to connection termination"))] + ) + ) + except RuntimeError: + # Task already in terminal state + logger.debug("task_id=<%s> | task already in terminal state, cannot transition to canceled", task.id) + raise + except Exception: + # Agent execution failures transition to failed state + logger.exception("task_id=<%s> | agent execution failed, transitioning to failed state", task.id) + try: + await updater.failed( + message=updater.new_agent_message(parts=[Part(root=TextPart(text="Agent execution failed"))]) + ) + except RuntimeError: + # Task already in terminal state (e.g., completed before error in cleanup) + logger.debug("task_id=<%s> | task already in terminal state, cannot transition to failed", task.id) async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater) -> None: """Execute request in streaming mode. @@ -105,14 +138,19 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater Args: context: The A2A request context, containing the user's input and other metadata. updater: The task updater for managing task state and sending updates. + + Raises: + ServerError: If input conversion fails (missing or empty content). 
""" # Convert A2A message parts to Strands ContentBlocks if context.message and hasattr(context.message, "parts"): content_blocks = self._convert_a2a_parts_to_content_blocks(context.message.parts) if not content_blocks: - raise ValueError("No content blocks available") + raise ServerError( + error=InternalError(message="No valid content found in request message parts") + ) from None else: - raise ValueError("No content blocks available") + raise ServerError(error=InternalError(message="Request message is missing or has no parts")) from None if not self.enable_a2a_compliant_streaming: warnings.warn( @@ -133,8 +171,20 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater invocation_state: dict[str, Any] = {"a2a_request_context": context} try: + result: SAAgentResult | None = None async for event in self.agent.stream_async(content_blocks, invocation_state=invocation_state): - await self._handle_streaming_event(event, updater) + if "result" in event: + result = event["result"] + else: + await self._handle_streaming_event(event, updater) + + # Check if agent returned with interrupts (input_required) + # Note: stop_reason="interrupt" is the authoritative signal. Even if interrupts + # list is empty (edge case), the agent still indicated it needs input. + if result is not None and result.stop_reason == "interrupt": + await self._handle_interrupt_result(result, updater) + else: + await self._handle_agent_result(result, updater) except Exception: logger.exception("Error in streaming execution") raise @@ -143,6 +193,34 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater self._current_artifact_id = None self._is_first_chunk = True + async def _handle_interrupt_result(self, result: SAAgentResult, updater: TaskUpdater) -> None: + """Handle an agent result that contains interrupts. + + When the Strands Agent returns with stop_reason="interrupt", this maps to + the A2A `input_required` state. 
The interrupt details are communicated to + the client via the status message. + + Args: + result: The agent result containing interrupts. + updater: The task updater for managing task state. + """ + # Build a descriptive message about what input is needed + interrupt_descriptions = [] + for interrupt in result.interrupts or []: + desc = f"- {interrupt.name}" + if interrupt.reason: + desc += f": {interrupt.reason}" + interrupt_descriptions.append(desc) + + if interrupt_descriptions: + input_message = "Agent requires input:\n" + "\n".join(interrupt_descriptions) + else: + # Edge case: stop_reason="interrupt" but no interrupt details provided. + # Still transition to input_required — the agent signaled it needs input. + input_message = "Agent requires additional input to continue" + + await updater.requires_input(message=updater.new_agent_message(parts=[Part(root=TextPart(text=input_message))])) + async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpdater) -> None: """Handle a single streaming event from the Strands Agent. @@ -175,8 +253,6 @@ async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpda updater.task_id, ), ) - elif "result" in event: - await self._handle_agent_result(event["result"], updater) async def _handle_agent_result(self, result: SAAgentResult | None, updater: TaskUpdater) -> None: """Handle the final result from the Strands Agent. @@ -219,20 +295,42 @@ async def _handle_agent_result(self, result: SAAgentResult | None, updater: Task async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: """Cancel an ongoing execution. - This method is called when a request cancellation is requested. Currently, - cancellation is not supported by the Strands Agent executor, so this method - always raises an UnsupportedOperationError. + Transitions the task to the canceled state and attempts to stop the agent. 
+ The agent's cancel() method is called to signal cooperative cancellation + of in-flight execution. + + Note: This transitions the A2A task state. The underlying agent execution + may still complete its current model call before stopping. Args: context: The A2A request context. event_queue: The A2A event queue. Raises: - ServerError: Always raised with an UnsupportedOperationError, as cancellation - is not currently supported. + ServerError: If no current task exists or the task is already in a terminal state. """ - logger.warning("Cancellation requested but not supported") - raise ServerError(error=UnsupportedOperationError()) + task = context.current_task + if not task: + logger.warning("context_id=<%s> | cancel requested but no current task found", context.context_id) + raise ServerError(error=UnsupportedOperationError()) from None + + # Cooperatively cancel the agent's execution (best-effort). + # Agent.cancel() is always available since self.agent is typed as Agent. + try: + self.agent.cancel() + except Exception: + logger.debug("task_id=<%s> | agent cancel signal failed (non-critical)", task.id) + + updater = TaskUpdater(event_queue, task.id, task.context_id) + + try: + await updater.cancel( + message=updater.new_agent_message(parts=[Part(root=TextPart(text="Task cancelled by client request"))]) + ) + except RuntimeError: + # TaskUpdater raises RuntimeError when task is already in a terminal state + logger.warning("task_id=<%s> | cannot cancel, already in terminal state", task.id) + raise ServerError(error=UnsupportedOperationError()) from None def _get_file_type_from_mime_type(self, mime_type: str | None) -> Literal["document", "image", "video", "unknown"]: """Classify file type based on MIME type. 
diff --git a/tests/strands/agent/test_a2a_agent.py b/tests/strands/agent/test_a2a_agent.py index d918033e5..9c3be7917 100644 --- a/tests/strands/agent/test_a2a_agent.py +++ b/tests/strands/agent/test_a2a_agent.py @@ -7,7 +7,7 @@ import pytest from a2a.client import ClientConfig -from a2a.types import AgentCard, Message, Part, Role, TextPart +from a2a.types import AgentCard, Message, Part, Role, TaskState, TextPart from strands.agent.a2a_agent import A2AAgent from strands.agent.agent_result import AgentResult @@ -714,3 +714,163 @@ async def mock_send_message(*args, **kwargs): # Should have 1 stream event + 1 result event (falls back to last) assert len(events) == 2 assert "result" in events[1] + + +# ========================================================================= +# NEW TESTS: Client-side lifecycle state handling +# ========================================================================= + + +def test_is_complete_event_failed_state(a2a_agent): + """Test that failed state is recognized as complete.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.failed + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_canceled_state(a2a_agent): + """Test that canceled state is recognized as complete.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.canceled + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_rejected_state(a2a_agent): + """Test that rejected state is recognized as complete.""" + from unittest.mock import MagicMock + + from a2a.types import 
TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.rejected + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_input_required_state(a2a_agent): + """Test that input_required state is recognized as complete (pausing).""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.input_required + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_auth_required_state(a2a_agent): + """Test that auth_required state is recognized as complete (pausing).""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.auth_required + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is True + + +def test_is_complete_event_working_state_not_complete(a2a_agent): + """Test that working state is NOT recognized as complete.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = TaskState.working + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is False + + +def test_is_complete_event_submitted_state_not_complete(a2a_agent): + """Test that submitted state is NOT recognized as complete.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = 
TaskState.submitted + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is False + + +# ========================================================================= +# DEVIL'S ADVOCATE FINDINGS — Tests addressing review gaps +# ========================================================================= + + +@pytest.mark.parametrize( + "state,expected_complete", + [ + (TaskState.completed, True), + (TaskState.failed, True), + (TaskState.canceled, True), + (TaskState.rejected, True), + (TaskState.input_required, True), + (TaskState.auth_required, True), + (TaskState.working, False), + (TaskState.submitted, False), + (TaskState.unknown, False), + ], + ids=[ + "completed-is-complete", + "failed-is-complete", + "canceled-is-complete", + "rejected-is-complete", + "input_required-is-complete", + "auth_required-is-complete", + "working-not-complete", + "submitted-not-complete", + "unknown-not-complete", + ], +) +def test_is_complete_event_all_states_parametrized(a2a_agent, state, expected_complete): + """Minor Finding 7: Parametrized test covering ALL TaskState values. + + This replaces verbose individual tests with a single parameterized test that + covers all 9 TaskState values. When a2a-sdk adds new states, adding a row here + is trivial. 
+ """ + from unittest.mock import MagicMock + + from a2a.types import TaskStatusUpdateEvent + + task = MagicMock() + status = MagicMock() + status.state = state + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + assert a2a_agent._is_complete_event((task, update_event)) is expected_complete diff --git a/tests/strands/multiagent/a2a/test_converters.py b/tests/strands/multiagent/a2a/test_converters.py index c3b310065..fff48653b 100644 --- a/tests/strands/multiagent/a2a/test_converters.py +++ b/tests/strands/multiagent/a2a/test_converters.py @@ -243,3 +243,286 @@ def test_convert_response_handles_missing_data(): mock_task.artifacts = [mock_artifact] result = convert_response_to_agent_result((mock_task, None)) assert len(result.message["content"]) == 0 + + +# ========================================================================= +# NEW TESTS: Lifecycle State Mapping +# ========================================================================= + + +def test_convert_response_completed_state_maps_to_end_turn(): + """Test that completed state maps to end_turn stop_reason.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.completed, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + + +def test_convert_response_failed_state_maps_to_end_turn(): + """Test that failed state maps to end_turn stop_reason with error content.""" + from unittest.mock import MagicMock + + from a2a.types import Message, TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + # Create a status message with error info + error_part = MagicMock() + error_part.root = MagicMock() + error_part.root.text = "Agent execution 
failed: timeout" + + error_message = MagicMock(spec=Message) + error_message.parts = [error_part] + + status = TaskStatus(state=TaskState.failed, message=error_message) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert result.state.get("a2a_task_state") == "failed" + assert "Agent execution failed" in result.message["content"][0]["text"] + + +def test_convert_response_input_required_maps_to_interrupt(): + """Test that input_required state maps to interrupt stop_reason.""" + from unittest.mock import MagicMock + + from a2a.types import Message, TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + input_part = MagicMock() + input_part.root = MagicMock() + input_part.root.text = "Agent requires input:\n- approval: Need confirmation" + + input_message = MagicMock(spec=Message) + input_message.parts = [input_part] + + status = TaskStatus(state=TaskState.input_required, message=input_message) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "interrupt" + assert result.state.get("a2a_task_state") == "input-required" + assert "approval" in result.message["content"][0]["text"] + + +def test_convert_response_canceled_state_maps_to_end_turn(): + """Test that canceled state maps to end_turn stop_reason.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.canceled, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert 
result.state.get("a2a_task_state") == "canceled" + + +def test_convert_response_rejected_state_maps_to_end_turn(): + """Test that rejected state maps to end_turn stop_reason.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.rejected, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert result.state.get("a2a_task_state") == "rejected" + + +def test_convert_response_auth_required_maps_to_interrupt(): + """Test that auth_required state maps to interrupt stop_reason.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.auth_required, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "interrupt" + assert result.state.get("a2a_task_state") == "auth-required" + + +def test_extract_task_state_from_status_update(): + """Test _extract_task_state helper.""" + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + from strands.multiagent.a2a._converters import _extract_task_state + + task = MagicMock() + status = TaskStatus(state=TaskState.failed, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + state = _extract_task_state((task, update_event)) + assert state == TaskState.failed + + +def test_extract_task_state_from_message_returns_none(): + """Test _extract_task_state returns None for Message responses.""" + from unittest.mock import MagicMock + + from a2a.types 
import Message + + from strands.multiagent.a2a._converters import _extract_task_state + + message = MagicMock(spec=Message) + state = _extract_task_state(message) + assert state is None + + +# ========================================================================= +# DEVIL'S ADVOCATE FINDINGS — Tests addressing review gaps +# ========================================================================= + + +def test_convert_response_completed_state_includes_state_metadata(): + """Major Finding 3: The completed state test was missing state assertion. + + Every other state test asserts both stop_reason AND result.state, but the most + important one (completed — the happy path) was missing the state check. This ensures + downstream consumers relying on result.state["a2a_task_state"] won't break silently. + """ + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.completed, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert result.state.get("a2a_task_state") == "completed" # THIS WAS MISSING + + +def test_convert_response_unknown_state_defaults_to_end_turn(): + """Major Finding 4: TaskState.unknown should default to end_turn. + + The a2a-sdk has a TaskState.unknown value. Our code handles it via the .get() + default ("end_turn"). This test documents that this is an intentional design + decision: unknown states are treated as terminal completions rather than errors. + + Rationale: An unknown state from a remote server is ambiguous. Treating it as + end_turn (completed) is the safest default — the client won't hang waiting for + more events, and the result content (if any) is still accessible. 
+ """ + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.unknown, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + # unknown is NOT in _STATE_TO_STOP_REASON, so defaults to "end_turn" + assert result.stop_reason == "end_turn" + # state metadata should reflect the actual state value + assert result.state.get("a2a_task_state") == "unknown" + + +def test_convert_response_working_state_defaults_to_end_turn(): + """Test that working state (not in mapping) defaults to end_turn. + + This covers the edge case where a TaskStatusUpdateEvent with state=working + somehow reaches the converter (shouldn't normally happen since _is_complete_event + filters these out, but defense-in-depth). + """ + from unittest.mock import MagicMock + + from a2a.types import TaskState, TaskStatus, TaskStatusUpdateEvent + + task = MagicMock() + task.artifacts = None + + status = TaskStatus(state=TaskState.working, message=None) + update_event = MagicMock(spec=TaskStatusUpdateEvent) + update_event.status = status + + result = convert_response_to_agent_result((task, update_event)) + assert result.stop_reason == "end_turn" + assert result.state.get("a2a_task_state") == "working" + + +def test_extract_task_state_from_artifact_update_returns_none(): + """Minor Finding 5: _extract_task_state with TaskArtifactUpdateEvent returns None. + + This is the untested path where the update event is an artifact (not status). 
+ """ + from unittest.mock import MagicMock + + from a2a.types import TaskArtifactUpdateEvent + + from strands.multiagent.a2a._converters import _extract_task_state + + task = MagicMock() + mock_event = MagicMock(spec=TaskArtifactUpdateEvent) + + state = _extract_task_state((task, mock_event)) + assert state is None + + +def test_state_to_stop_reason_covers_all_lifecycle_states(): + """Verify _STATE_TO_STOP_REASON has mappings for all documented lifecycle states. + + Guards against future additions to the a2a-sdk that we miss. + """ + from a2a.types import TaskState + + from strands.multiagent.a2a._converters import _STATE_TO_STOP_REASON + + # These are the states we explicitly handle + expected_mapped = { + TaskState.completed, + TaskState.failed, + TaskState.canceled, + TaskState.rejected, + TaskState.input_required, + TaskState.auth_required, + } + assert set(_STATE_TO_STOP_REASON.keys()) == expected_mapped + + # These should NOT be in the mapping (they're non-terminal progress states) + assert TaskState.working not in _STATE_TO_STOP_REASON + assert TaskState.submitted not in _STATE_TO_STOP_REASON + assert TaskState.unknown not in _STATE_TO_STOP_REASON diff --git a/tests/strands/multiagent/a2a/test_executor.py b/tests/strands/multiagent/a2a/test_executor.py index dc90fbdd6..940d26f8c 100644 --- a/tests/strands/multiagent/a2a/test_executor.py +++ b/tests/strands/multiagent/a2a/test_executor.py @@ -583,7 +583,7 @@ async def mock_stream(content_blocks): async def test_execute_streaming_mode_handles_agent_exception( mock_strands_agent, mock_request_context, mock_event_queue ): - """Test that execute handles agent exceptions correctly in streaming mode.""" + """Test that execute transitions to failed state when agent raises exception.""" # Setup mock agent to raise exception when stream_async is called mock_strands_agent.stream_async = MagicMock(side_effect=Exception("Agent error")) @@ -608,18 +608,25 @@ async def test_execute_streaming_mode_handles_agent_exception( 
mock_message.parts = [part] mock_request_context.message = mock_message - with pytest.raises(ServerError): - await executor.execute(mock_request_context, mock_event_queue) + # Should NOT raise - instead transitions to failed state + await executor.execute(mock_request_context, mock_event_queue) # Verify agent was called mock_strands_agent.stream_async.assert_called_once() + # Verify a failed status event was enqueued + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + from a2a.types import TaskState, TaskStatusUpdateEvent -@pytest.mark.asyncio -async def test_cancel_raises_unsupported_operation_error(mock_strands_agent, mock_request_context, mock_event_queue): - """Test that cancel raises UnsupportedOperationError.""" + failed_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.failed + ] + assert len(failed_events) == 1 + assert "Agent execution failed" in failed_events[0].status.message.parts[0].root.text executor = StrandsA2AExecutor(mock_strands_agent) + # Cancel with no current_task raises UnsupportedOperationError + mock_request_context.current_task = None with pytest.raises(ServerError) as excinfo: await executor.cancel(mock_request_context, mock_event_queue) @@ -1331,3 +1338,606 @@ async def test_invocation_state_with_a2a_compliant_streaming( assert invocation_state is not None assert invocation_state["a2a_request_context"] is mock_request_context + + +# ========================================================================= +# NEW TESTS: A2A Lifecycle State Support +# ========================================================================= + + +@pytest.mark.asyncio +async def test_execute_transitions_to_failed_on_streaming_error( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Test that errors during streaming transition task to failed state.""" + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + async def 
mock_stream(content_blocks, **kwargs): + """Mock streaming that raises mid-stream.""" + yield {"data": "partial output"} + raise RuntimeError("Connection lost") + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-fail" + mock_task.context_id = "ctx-fail" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + # Should not raise + await executor.execute(mock_request_context, mock_event_queue) + + # Verify failed state was enqueued + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + failed_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.failed + ] + assert len(failed_events) == 1 + assert "Agent execution failed" in failed_events[0].status.message.parts[0].root.text + + +@pytest.mark.asyncio +async def test_cancel_with_valid_task(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel transitions task to canceled state when task exists.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancel" + mock_task.context_id = "ctx-cancel" + mock_request_context.current_task = mock_task + + await executor.cancel(mock_request_context, mock_event_queue) + + # Verify canceled state was enqueued + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1 + assert "cancelled" in 
canceled_events[0].status.message.parts[0].root.text.lower() + + +@pytest.mark.asyncio +async def test_cancel_without_task_raises_unsupported(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel raises UnsupportedOperationError when no task exists.""" + executor = StrandsA2AExecutor(mock_strands_agent) + mock_request_context.current_task = None + + with pytest.raises(ServerError) as excinfo: + await executor.cancel(mock_request_context, mock_event_queue) + + assert isinstance(excinfo.value.error, UnsupportedOperationError) + + +@pytest.mark.asyncio +async def test_execute_with_interrupt_transitions_to_input_required( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Test that agent interrupts map to input_required state.""" + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + from strands.interrupt import Interrupt + + # Create a mock result with interrupts + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "interrupt" + mock_interrupt = Interrupt(id="int-1", name="approval", reason="Need user approval") + mock_result.interrupts = [mock_interrupt] + + async def mock_stream(content_blocks, **kwargs): + yield {"data": "Processing..."} + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-interrupt" + mock_task.context_id = "ctx-interrupt" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "delete file X" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + # Verify input_required state was enqueued + enqueued_events = [call[0][0] for call in 
mock_event_queue.enqueue_event.call_args_list] + input_required_events = [ + e + for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + assert len(input_required_events) == 1 + msg_text = input_required_events[0].status.message.parts[0].root.text + assert "approval" in msg_text + assert "Need user approval" in msg_text + + +@pytest.mark.asyncio +async def test_execute_with_multiple_interrupts(mock_strands_agent, mock_request_context, mock_event_queue): + """Test handling of multiple interrupts in a single result.""" + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + from strands.interrupt import Interrupt + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "interrupt" + mock_result.interrupts = [ + Interrupt(id="int-1", name="confirm_delete", reason="Confirm deletion of file X"), + Interrupt(id="int-2", name="select_backup", reason="Choose backup location"), + ] + + async def mock_stream(content_blocks, **kwargs): + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-multi-int" + mock_task.context_id = "ctx-multi-int" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "delete with backup" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + input_required_events = [ + e + for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + assert len(input_required_events) == 1 + msg_text = 
input_required_events[0].status.message.parts[0].root.text + assert "confirm_delete" in msg_text + assert "select_backup" in msg_text + assert "Confirm deletion of file X" in msg_text + assert "Choose backup location" in msg_text + + +@pytest.mark.asyncio +async def test_execute_normal_completion_no_interrupts(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that normal completion (no interrupts) still works as before.""" + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "end_turn" + mock_result.interrupts = None + mock_result.__str__ = MagicMock(return_value="Task completed successfully") + + async def mock_stream(content_blocks, **kwargs): + yield {"data": "Working..."} + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-normal" + mock_task.context_id = "ctx-normal" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "do something" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + # Verify completed state was enqueued (not input_required) + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + completed_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.completed + ] + assert len(completed_events) == 1 + + # Verify no input_required events + input_required_events = [ + e + for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + assert len(input_required_events) == 0 + + 
+@pytest.mark.asyncio +async def test_execute_setup_failure_raises_server_error(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that setup failures (missing message) still raise ServerError.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-setup-fail" + mock_task.context_id = "ctx-setup-fail" + mock_request_context.current_task = mock_task + + # No message at all + mock_request_context.message = None + + with pytest.raises(ServerError) as excinfo: + await executor.execute(mock_request_context, mock_event_queue) + + assert isinstance(excinfo.value.error, InternalError) + + +@pytest.mark.asyncio +async def test_execute_error_when_task_already_terminal(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that error during execution is handled gracefully when task is already in terminal state.""" + from a2a.types import TextPart + + # Make stream_async raise to trigger the error path + mock_strands_agent.stream_async = MagicMock(side_effect=Exception("Agent error")) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-already-done" + mock_task.context_id = "ctx-already-done" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + # Patch TaskUpdater.failed to raise RuntimeError (simulating task already in terminal state) + with patch("strands.multiagent.a2a.executor.TaskUpdater") as MockTaskUpdater: + mock_updater = MagicMock() + mock_updater.failed = AsyncMock(side_effect=RuntimeError("Task is already in a terminal state")) + mock_updater.new_agent_message = MagicMock(return_value=MagicMock()) + MockTaskUpdater.return_value = mock_updater + + # Should NOT raise - handles 
RuntimeError gracefully + await executor.execute(mock_request_context, mock_event_queue) + + # Verify failed() was attempted + mock_updater.failed.assert_called_once() + + +@pytest.mark.asyncio +async def test_cancel_calls_agent_cancel_method(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel() attempts to call agent.cancel() if available.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + # Give the agent a cancel method + mock_strands_agent.cancel = MagicMock() + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancel-agent" + mock_task.context_id = "ctx-cancel-agent" + mock_request_context.current_task = mock_task + + await executor.cancel(mock_request_context, mock_event_queue) + + # Verify agent.cancel() was called + mock_strands_agent.cancel.assert_called_once() + + # Verify task state is canceled + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1 + + +@pytest.mark.asyncio +async def test_cancel_handles_agent_cancel_exception(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel() gracefully handles agent.cancel() raising an exception.""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + # Give the agent a cancel method that raises + mock_strands_agent.cancel = MagicMock(side_effect=RuntimeError("Cannot cancel")) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancel-err" + mock_task.context_id = "ctx-cancel-err" + mock_request_context.current_task = mock_task + + # Should still succeed (agent cancel is best-effort) + await executor.cancel(mock_request_context, mock_event_queue) + + # Task should still be transitioned to canceled + enqueued_events = 
[call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1 + + +@pytest.mark.asyncio +async def test_cancel_raises_when_task_already_terminal(mock_strands_agent, mock_request_context, mock_event_queue): + """Test that cancel() raises ServerError when task is already in a terminal state.""" + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-terminal" + mock_task.context_id = "ctx-terminal" + mock_request_context.current_task = mock_task + + # Patch TaskUpdater.cancel to raise RuntimeError (task already completed/failed) + with patch("strands.multiagent.a2a.executor.TaskUpdater") as MockTaskUpdater: + mock_updater = MagicMock() + mock_updater.cancel = AsyncMock(side_effect=RuntimeError("Task is already in a terminal state")) + mock_updater.new_agent_message = MagicMock(return_value=MagicMock()) + MockTaskUpdater.return_value = mock_updater + + with pytest.raises(ServerError) as excinfo: + await executor.cancel(mock_request_context, mock_event_queue) + + assert isinstance(excinfo.value.error, UnsupportedOperationError) + mock_updater.cancel.assert_called_once() + + +# ========================================================================= +# DEVIL'S ADVOCATE FINDINGS — Tests addressing review gaps +# ========================================================================= + + +@pytest.mark.asyncio +async def test_execute_handles_asyncio_cancelled_error(mock_strands_agent, mock_request_context, mock_event_queue): + """Critical Finding 1: asyncio.CancelledError transitions task to canceled state. + + asyncio.CancelledError is a BaseException (not Exception). It's raised when an asyncio + task is cancelled — e.g., HTTP client disconnect, server shutdown, task group cancellation. 
+ Without explicit handling, the task would remain stuck in 'working' state forever (zombie). + + This test verifies the task transitions to 'canceled' before re-raising CancelledError. + """ + import asyncio + + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + async def mock_stream(content_blocks, **kwargs): + """Mock streaming that gets cancelled mid-stream.""" + yield {"data": "partial output"} + raise asyncio.CancelledError() + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancelled" + mock_task.context_id = "ctx-cancelled" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + # CancelledError should be re-raised (framework needs to know task was cancelled) + with pytest.raises(asyncio.CancelledError): + await executor.execute(mock_request_context, mock_event_queue) + + # But BEFORE re-raising, the task should have been transitioned to canceled + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1 + assert ( + "cancelled" in canceled_events[0].status.message.parts[0].root.text.lower() + or "connection termination" in canceled_events[0].status.message.parts[0].root.text.lower() + ) + + +@pytest.mark.asyncio +async def test_execute_asyncio_cancelled_when_task_already_terminal( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Test CancelledError handling when task is already in a terminal state. 
+ + If the task completed right before the cancellation arrives, the updater.cancel() + will raise RuntimeError. We should handle this gracefully and still re-raise CancelledError. + """ + import asyncio + + from a2a.types import TextPart + + async def mock_stream(content_blocks, **kwargs): + """Async generator that immediately raises CancelledError.""" + yield {"data": "partial"} # Must yield to be async generator + raise asyncio.CancelledError() + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-cancelled-terminal" + mock_task.context_id = "ctx-cancelled-terminal" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "test" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + # Patch TaskUpdater to simulate task already in terminal state + with patch("strands.multiagent.a2a.executor.TaskUpdater") as MockTaskUpdater: + mock_updater = MagicMock() + mock_updater.cancel = AsyncMock(side_effect=RuntimeError("Task is already in a terminal state")) + mock_updater.update_status = AsyncMock() + mock_updater.add_artifact = AsyncMock() + mock_updater.new_agent_message = MagicMock(return_value=MagicMock()) + mock_updater.context_id = "ctx-cancelled-terminal" + mock_updater.task_id = "task-cancelled-terminal" + MockTaskUpdater.return_value = mock_updater + + # Should still re-raise CancelledError + with pytest.raises(asyncio.CancelledError): + await executor.execute(mock_request_context, mock_event_queue) + + # cancel() was attempted + mock_updater.cancel.assert_called_once() + + +@pytest.mark.asyncio +async def test_execute_with_interrupt_empty_list_transitions_to_input_required( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Critical Finding 2: 
stop_reason='interrupt' with empty interrupts list. + + The agent explicitly signaled it needs input (stop_reason="interrupt") but provided + no interrupt details. This should STILL transition to input_required — the stop_reason + is the authoritative signal. Previously this would silently complete the task. + """ + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "interrupt" + mock_result.interrupts = [] # Empty list — previously this was falsy and caused completion! + + async def mock_stream(content_blocks, **kwargs): + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-empty-interrupts" + mock_task.context_id = "ctx-empty-interrupts" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "do something" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + # Should transition to input_required, NOT completed + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + input_required_events = [ + e + for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + completed_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.completed + ] + + assert len(input_required_events) == 1, "Empty interrupts list should still trigger input_required" + assert len(completed_events) == 0, "Should NOT complete when stop_reason='interrupt'" + # Verify the fallback message is used + assert "additional input" in 
input_required_events[0].status.message.parts[0].root.text.lower() + + +@pytest.mark.asyncio +async def test_execute_with_interrupt_none_list_transitions_to_input_required( + mock_strands_agent, mock_request_context, mock_event_queue +): + """Edge case: stop_reason='interrupt' with interrupts=None. + + Same logic — the stop_reason is authoritative. None interrupts should + still result in input_required transition. + """ + from a2a.types import TaskState, TaskStatusUpdateEvent, TextPart + + mock_result = MagicMock(spec=SAAgentResult) + mock_result.stop_reason = "interrupt" + mock_result.interrupts = None # None, not empty list + + async def mock_stream(content_blocks, **kwargs): + yield {"result": mock_result} + + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) + + executor = StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-none-interrupts" + mock_task.context_id = "ctx-none-interrupts" + mock_request_context.current_task = mock_task + + mock_text_part = MagicMock(spec=TextPart) + mock_text_part.text = "do something" + mock_part = MagicMock() + mock_part.root = mock_text_part + mock_message = MagicMock() + mock_message.parts = [mock_part] + mock_request_context.message = mock_message + + await executor.execute(mock_request_context, mock_event_queue) + + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + input_required_events = [ + e + for e in enqueued_events + if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.input_required + ] + assert len(input_required_events) == 1 + + +@pytest.mark.asyncio +async def test_cancel_without_hasattr_cancel(mock_strands_agent, mock_request_context, mock_event_queue): + """Test cancel works when agent doesn't have cancel() method (AttributeError).""" + from a2a.types import TaskState, TaskStatusUpdateEvent + + # Remove cancel method entirely + del mock_strands_agent.cancel + + executor = 
StrandsA2AExecutor(mock_strands_agent) + + mock_task = MagicMock() + mock_task.id = "task-no-cancel-method" + mock_task.context_id = "ctx-no-cancel-method" + mock_request_context.current_task = mock_task + + # Should succeed — AttributeError from agent.cancel() is caught + await executor.cancel(mock_request_context, mock_event_queue) + + # Task should still be transitioned to canceled + enqueued_events = [call[0][0] for call in mock_event_queue.enqueue_event.call_args_list] + canceled_events = [ + e for e in enqueued_events if isinstance(e, TaskStatusUpdateEvent) and e.status.state == TaskState.canceled + ] + assert len(canceled_events) == 1 From ead3179b8e661657b30f07813d0e18cf2f79f5b7 Mon Sep 17 00:00:00 2001 From: mehtarac Date: Fri, 8 May 2026 08:56:24 -0400 Subject: [PATCH 268/279] fix: integration test updates (#2262) --- tests_integ/models/test_model_bedrock.py | 37 +++++++----------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index 73d67f414..06c72ef88 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -276,29 +276,6 @@ def test_structured_output_multi_modal_input(streaming_agent, yellow_img, yellow assert tru_color == exp_color -def test_redacted_content_handling(): - """Test redactedContent handling with thinking mode.""" - bedrock_model = BedrockModel( - model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", - additional_request_fields={ - "thinking": { - "type": "enabled", - "budget_tokens": 2000, - } - }, - ) - - agent = Agent(name="test_redact", model=bedrock_model) - # https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#example-working-with-redacted-thinking-blocks - result = agent( - "ANTHROPIC_MAGIC_STRING_TRIGGER_REDACTED_THINKING_46C9A13E193C177646C7398A98432ECCCE4C1253D5E2D82641AC0E52CC2876CB" - ) - - assert "reasoningContent" in result.message["content"][0] - assert 
"redactedContent" in result.message["content"][0]["reasoningContent"] - assert isinstance(result.message["content"][0]["reasoningContent"]["redactedContent"], bytes) - - def test_reasoning_content_in_messages_with_thinking_disabled(): """Test that messages with reasoningContent are accepted when thinking is explicitly disabled.""" # First, get a real reasoning response with thinking enabled @@ -489,14 +466,22 @@ def test_prompt_caching_with_ttl_in_messages(): ) -def test_prompt_caching_backward_compatibility_no_ttl(non_streaming_model): +def test_prompt_caching_backward_compatibility_no_ttl(): """Test that prompt caching works without TTL (backward compatibility). Verifies that cache points work correctly when TTL is not specified, maintaining backward compatibility with existing code. + + Uses Claude Haiku 4.5 which supports prompt caching on Bedrock. + Minimum 4096 tokens required for caching with Haiku 4.5. """ + model = BedrockModel( + model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", + streaming=False, + ) + unique_id = str(uuid.uuid4()) - large_context = f"Background information for test {unique_id}: " + ("This is important context. " * 200) + large_context = f"Background information for test {unique_id}: " + ("This is important context. 
" * 1000) system_prompt_with_cache = [ {"text": large_context}, @@ -505,7 +490,7 @@ def test_prompt_caching_backward_compatibility_no_ttl(non_streaming_model): ] agent = Agent( - model=non_streaming_model, + model=model, system_prompt=system_prompt_with_cache, load_tools_from_directory=False, ) From f8621853d3bbd1e69b59d3871f4ca363fdcec0e0 Mon Sep 17 00:00:00 2001 From: opieter-aws Date: Fri, 8 May 2026 10:02:44 -0400 Subject: [PATCH 269/279] feat: add proactive context compression to conversation managers (#2239) Co-authored-by: agent-of-mkmeral --- .../agent/conversation_manager/__init__.py | 4 +- .../conversation_manager.py | 144 +++++++-- .../null_conversation_manager.py | 11 +- .../sliding_window_conversation_manager.py | 43 ++- .../summarizing_conversation_manager.py | 92 ++++-- .../agent/test_conversation_manager.py | 297 +++++++++++++++++- .../test_summarizing_conversation_manager.py | 107 ++++++- 7 files changed, 605 insertions(+), 93 deletions(-) diff --git a/src/strands/agent/conversation_manager/__init__.py b/src/strands/agent/conversation_manager/__init__.py index c59623215..9f6d54ff9 100644 --- a/src/strands/agent/conversation_manager/__init__.py +++ b/src/strands/agent/conversation_manager/__init__.py @@ -3,6 +3,7 @@ It includes: - ConversationManager: Abstract base class defining the conversation management interface +- ProactiveCompressionConfig: Configuration type for proactive compression settings - NullConversationManager: A no-op implementation that does not modify conversation history - SlidingWindowConversationManager: An implementation that maintains a sliding window of messages to control context size while preserving conversation coherence @@ -13,7 +14,7 @@ is critical for effective agent interactions. 
""" -from .conversation_manager import ConversationManager +from .conversation_manager import ConversationManager, ProactiveCompressionConfig from .null_conversation_manager import NullConversationManager from .sliding_window_conversation_manager import SlidingWindowConversationManager from .summarizing_conversation_manager import SummarizingConversationManager @@ -21,6 +22,7 @@ __all__ = [ "ConversationManager", "NullConversationManager", + "ProactiveCompressionConfig", "SlidingWindowConversationManager", "SummarizingConversationManager", ] diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index 690ecbde5..7e2283883 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -1,14 +1,33 @@ """Abstract interface for conversation history management.""" +import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypedDict, Union +from ...hooks.events import BeforeModelCallEvent from ...hooks.registry import HookProvider, HookRegistry from ...types.content import Message if TYPE_CHECKING: from ...agent.agent import Agent +logger = logging.getLogger(__name__) + +DEFAULT_COMPRESSION_THRESHOLD = 0.7 +DEFAULT_CONTEXT_WINDOW_LIMIT = 200_000 + + +class ProactiveCompressionConfig(TypedDict, total=False): + """Configuration for proactive compression when passed as an object. + + Attributes: + compression_threshold: Ratio of context window usage that triggers proactive compression. + Value between 0 (exclusive) and 1 (inclusive). + Defaults to 0.7 (compress when 70% of the context window is used). + """ + + compression_threshold: float + class ConversationManager(ABC, HookProvider): """Abstract base class for managing conversation history. 
@@ -22,45 +41,122 @@ class ConversationManager(ABC, HookProvider): ConversationManager implements the HookProvider protocol, allowing derived classes to register hooks for agent lifecycle events. Derived classes that override register_hooks must call the base implementation to ensure proper - hook registration. + hook registration chain. + + The primary responsibility of a ConversationManager is overflow recovery: when the model encounters a context + window overflow, :meth:`reduce_context` is called with ``e`` set and MUST reduce the history enough for the next + model call to succeed. + + Subclasses can enable proactive compression by passing ``proactive_compression`` in the constructor. + When enabled, the base class registers a ``BeforeModelCallEvent`` hook that checks projected input tokens + against the model's context window limit and calls :meth:`reduce_context` (without ``e``) when the + threshold is exceeded. This is a best-effort operation — errors are swallowed so the model call can + still proceed. Example: ```python - class MyConversationManager(ConversationManager): - def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: - super().register_hooks(registry, **kwargs) - # Register additional hooks here + # Enable proactive compression with default threshold (0.7) + SlidingWindowConversationManager(window_size=50, proactive_compression=True) + + # Enable proactive compression with custom threshold + SummarizingConversationManager(proactive_compression={"compression_threshold": 0.8}) ``` """ - def __init__(self) -> None: + def __init__(self, *, proactive_compression: Union[bool, "ProactiveCompressionConfig", None] = None) -> None: """Initialize the ConversationManager. + Args: + proactive_compression: Enable proactive context compression before the model call. + - ``True``: compress when 70% of the context window is used (default threshold). + - ``{"compression_threshold": float}``: compress at the specified ratio (0, 1]. 
+ - ``False`` or ``None``: disabled, only reactive overflow recovery is used. + + Raises: + ValueError: If compression_threshold is not in the valid range (0, 1]. + Attributes: removed_message_count: The messages that have been removed from the agents messages array. These represent messages provided by the user or LLM that have been removed, not messages included by the conversation manager through something like summarization. """ + # Resolve the threshold from proactive_compression parameter + if proactive_compression is True: + threshold: float | None = DEFAULT_COMPRESSION_THRESHOLD + elif isinstance(proactive_compression, dict): + threshold = proactive_compression.get("compression_threshold", DEFAULT_COMPRESSION_THRESHOLD) + else: + threshold = None + + if threshold is not None and (threshold <= 0 or threshold > 1): + raise ValueError( + f"compression_threshold must be between 0 (exclusive) and 1 (inclusive), got {threshold}" + ) + self.removed_message_count = 0 + self._compression_threshold = threshold + self._context_window_limit_warned = False def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: """Register hooks for agent lifecycle events. + Always registers a ``BeforeModelCallEvent`` hook for proactive compression. + When ``proactive_compression`` is not configured, the handler is a no-op (early return). + Derived classes that override this method must call the base implementation to ensure proper hook registration chain. Args: registry: The hook registry to register callbacks with. **kwargs: Additional keyword arguments for future extensibility. 
+ """ + # Always subscribe — the threshold check happens inside the handler + registry.add_callback(BeforeModelCallEvent, self._on_before_model_call_threshold) - Example: - ```python - def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: - super().register_hooks(registry, **kwargs) - registry.add_callback(SomeEvent, self.on_some_event) - ``` + def _on_before_model_call_threshold(self, event: BeforeModelCallEvent) -> None: + """Handle BeforeModelCallEvent for proactive compression. + + When proactive compression is not configured, this is a no-op. + When configured, checks projected input tokens against the context window limit + and calls reduce_context() without error (best-effort) when threshold is exceeded. + + Args: + event: The before model call event. """ - pass + # Early return if proactive compression is not enabled + if self._compression_threshold is None: + return + + context_window_limit = event.agent.model.context_window_limit + if context_window_limit is None: + context_window_limit = DEFAULT_CONTEXT_WINDOW_LIMIT + if not self._context_window_limit_warned: + self._context_window_limit_warned = True + logger.warning( + "context_window_limit=<%s> | context_window_limit not set on model, using default." + " Set context_window_limit in your model config for accurate proactive compression", + DEFAULT_CONTEXT_WINDOW_LIMIT, + ) + + if event.projected_input_tokens is None: + logger.debug("projected_input_tokens= | skipping proactive compression") + return + + ratio = event.projected_input_tokens / context_window_limit + if ratio >= self._compression_threshold: + logger.debug( + "projected_tokens=<%s>, limit=<%s>, ratio=<%.2f>, compression_threshold=<%s>" + " | compression threshold exceeded, reducing context", + event.projected_input_tokens, + context_window_limit, + ratio, + self._compression_threshold, + ) + # Proactive compression is best-effort: swallow errors so the model call can still proceed. 
+ try: + self.reduce_context(agent=event.agent) + except Exception: + logger.debug("proactive compression failed, will proceed with model call", exc_info=True) def restore_from_session(self, state: dict[str, Any]) -> list[Message] | None: """Restore the Conversation Manager's state from a session. @@ -99,22 +195,24 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: @abstractmethod def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: - """Called when the model's context window is exceeded. - - This method should implement the specific strategy for reducing the window size when a context overflow occurs. - It is typically called after a ContextWindowOverflowException is caught. + """Reduce the conversation history. - Implementations might use strategies such as: + Called in two scenarios: + 1. **Reactive** (e is set): A context window overflow occurred. The implementation + MUST remove enough history for the next model call to succeed, or re-raise the error. + 2. **Proactive** (e is None): The compression threshold was exceeded. This is best-effort — + returning without reduction or raising is acceptable; the model call proceeds regardless. - - Removing the N oldest messages - - Summarizing older context - - Applying importance-based filtering - - Maintaining critical conversation markers + Implementations should modify ``agent.messages`` in-place. Args: agent: The agent whose conversation history will be reduced. This list is modified in-place. e: The exception that triggered the context reduction, if any. + When set, this is a reactive overflow recovery call — the implementation MUST + reduce enough history for the next model call to succeed. + When None, this is a proactive compression call — best-effort reduction to avoid + hitting the context window limit. **kwargs: Additional keyword arguments for future extensibility. 
""" pass diff --git a/src/strands/agent/conversation_manager/null_conversation_manager.py b/src/strands/agent/conversation_manager/null_conversation_manager.py index 11632525d..4077cb08b 100644 --- a/src/strands/agent/conversation_manager/null_conversation_manager.py +++ b/src/strands/agent/conversation_manager/null_conversation_manager.py @@ -5,7 +5,6 @@ if TYPE_CHECKING: from ...agent.agent import Agent -from ...types.exceptions import ContextWindowOverflowException from .conversation_manager import ConversationManager @@ -29,7 +28,10 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: pass def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: - """Does not reduce context and raises an exception. + """Does not reduce context. + + When called reactively (e is not None), re-raises the overflow exception since this + manager cannot reduce context. When called proactively (e is None), returns silently. Args: agent: The agent whose conversation history will remain unmodified. @@ -37,10 +39,7 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A **kwargs: Additional keyword arguments for future extensibility. Raises: - e: If provided. - ContextWindowOverflowException: If e is None. + e: If provided (reactive overflow). 
""" if e: raise e - else: - raise ContextWindowOverflowException("Context window overflowed!") diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 1b45dd42c..1ad8edc24 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -10,7 +10,7 @@ from ...types.content import ContentBlock, Messages from ...types.exceptions import ContextWindowOverflowException from ...types.tools import ToolResultContent -from .conversation_manager import ConversationManager +from .conversation_manager import ConversationManager, ProactiveCompressionConfig logger = logging.getLogger(__name__) @@ -37,6 +37,7 @@ def __init__( should_truncate_results: bool = True, *, per_turn: bool | int = False, + proactive_compression: bool | ProactiveCompressionConfig | None = None, ): """Initialize the sliding window conversation manager. @@ -54,6 +55,10 @@ def __init__( manage message history and prevent the agent loop from slowing down. Start with per_turn=True and adjust to a specific frequency (e.g., per_turn=5) if needed for performance tuning. + proactive_compression: Enable proactive context compression before the model call. + - ``True``: compress when 70% of the context window is used (default threshold). + - ``{"compression_threshold": float}``: compress at the specified ratio (0, 1]. + - ``False`` or ``None``: disabled, only reactive overflow recovery is used. Raises: ValueError: If window_size is negative, or if per_turn is 0 or a negative integer. 
@@ -63,7 +68,7 @@ def __init__( if isinstance(per_turn, int) and not isinstance(per_turn, bool) and per_turn <= 0: raise ValueError(f"per_turn must be a positive integer, True, or False, got {per_turn}") - super().__init__() + super().__init__(proactive_compression=proactive_compression) self.window_size = window_size self.should_truncate_results = should_truncate_results @@ -158,6 +163,12 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Trim the oldest messages to reduce the conversation context size. + When ``e`` is set (reactive overflow recovery), attempts to truncate large tool results + first before falling back to message trimming. + + When ``e`` is None (proactive compression or routine management), only trims messages + without attempting tool result truncation. + The method handles special cases where trimming the messages leads to: - toolResult with no corresponding toolUse - toolUse with no corresponding toolResult @@ -166,12 +177,14 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A agent: The agent whose messages will be reduce. This list is modified in-place. e: The exception that triggered the context reduction, if any. + When set, this is a reactive overflow recovery call. + When None, this is a proactive or routine management call. **kwargs: Additional keyword arguments for future extensibility. Raises: ContextWindowOverflowException: If the context cannot be reduced further and a context overflow - error was provided (e is not None). When called during routine window management (e is None), - logs a warning and returns without modification. + error was provided (e is not None). When called during routine window management or + proactive compression (e is None), logs a warning and returns without modification. 
""" messages = agent.messages @@ -181,16 +194,18 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A messages[:] = [] return - # Try to truncate the tool result first - oldest_message_idx_with_tool_results = self._find_oldest_message_with_tool_results(messages) - if oldest_message_idx_with_tool_results is not None and self.should_truncate_results: - logger.debug( - "message_index=<%s> | found message with tool results at index", oldest_message_idx_with_tool_results - ) - results_truncated = self._truncate_tool_results(messages, oldest_message_idx_with_tool_results) - if results_truncated: - logger.debug("message_index=<%s> | tool results truncated", oldest_message_idx_with_tool_results) - return + # Try to truncate the tool result first (only for reactive overflow, not proactive compression) + if e is not None: + oldest_message_idx_with_tool_results = self._find_oldest_message_with_tool_results(messages) + if oldest_message_idx_with_tool_results is not None and self.should_truncate_results: + logger.debug( + "message_index=<%s> | found message with tool results at index", + oldest_message_idx_with_tool_results, + ) + results_truncated = self._truncate_tool_results(messages, oldest_message_idx_with_tool_results) + if results_truncated: + logger.debug("message_index=<%s> | tool results truncated", oldest_message_idx_with_tool_results) + return # Try to trim index id when tool result cannot be truncated anymore # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index abd4d08b5..2030e1d3b 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -12,7 +12,7 @@ from ...types.content import Message from 
...types.exceptions import ContextWindowOverflowException from ...types.tools import AgentTool -from .conversation_manager import ConversationManager +from .conversation_manager import ConversationManager, ProactiveCompressionConfig if TYPE_CHECKING: from ..agent import Agent @@ -65,6 +65,8 @@ def __init__( preserve_recent_messages: int = 10, summarization_agent: Optional["Agent"] = None, summarization_system_prompt: str | None = None, + *, + proactive_compression: bool | ProactiveCompressionConfig | None = None, ): """Initialize the summarizing conversation manager. @@ -77,8 +79,12 @@ def __init__( If provided, this agent can use tools as part of the summarization process. summarization_system_prompt: Optional system prompt override for summarization. If None, uses the default summarization prompt. + proactive_compression: Enable proactive context compression before the model call. + - ``True``: compress when 70% of the context window is used (default threshold). + - ``{"compression_threshold": float}``: compress at the specified ratio (0, 1]. + - ``False`` or ``None``: disabled, only reactive overflow recovery is used. """ - super().__init__() + super().__init__(proactive_compression=proactive_compression) if summarization_agent is not None and summarization_system_prompt is not None: raise ValueError( "Cannot provide both summarization_agent and summarization_system_prompt. " @@ -126,54 +132,76 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Reduce context using summarization. + When ``e`` is set (reactive overflow recovery), summarization failure is re-raised — + the agent loop must not proceed with an overflow. + + When ``e`` is None (proactive compression), summarization failure is logged and + returns silently — the model call proceeds regardless. + Args: agent: The agent whose conversation history will be reduced. 
The agent's messages list is modified in-place. e: The exception that triggered the context reduction, if any. + When set, this is a reactive overflow recovery call. + When None, this is a proactive compression call (best-effort). **kwargs: Additional keyword arguments for future extensibility. Raises: - ContextWindowOverflowException: If the context cannot be summarized. + Exception: If summarization fails during reactive overflow recovery (e is set). """ try: - # Calculate how many messages to summarize - messages_to_summarize_count = max(1, int(len(agent.messages) * self.summary_ratio)) + self._summarize_oldest(agent) + except Exception as summarization_error: + if e is not None: + # Reactive: rethrow so the ContextWindowOverflowException propagates + logger.error("Summarization failed: %s", summarization_error) + raise summarization_error from e + # Proactive: best-effort, swallow errors so the model call can still proceed. + logger.warning("Proactive summarization failed, continuing: %s", summarization_error) - # Ensure we don't summarize recent messages - messages_to_summarize_count = min( - messages_to_summarize_count, len(agent.messages) - self.preserve_recent_messages - ) + def _summarize_oldest(self, agent: "Agent") -> None: + """Summarize the oldest messages and replace them with a summary. - if messages_to_summarize_count <= 0: - raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") + Args: + agent: The agent instance. - # Adjust split point to avoid breaking ToolUse/ToolResult pairs - messages_to_summarize_count = self._adjust_split_point_for_tool_pairs( - agent.messages, messages_to_summarize_count - ) + Raises: + ContextWindowOverflowException: If there are insufficient messages for summarization. 
+ """ + # Calculate how many messages to summarize + messages_to_summarize_count = max(1, int(len(agent.messages) * self.summary_ratio)) + + # Ensure we don't summarize recent messages + messages_to_summarize_count = min( + messages_to_summarize_count, len(agent.messages) - self.preserve_recent_messages + ) - if messages_to_summarize_count <= 0: - raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") + if messages_to_summarize_count <= 0: + raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") - # Extract messages to summarize - messages_to_summarize = agent.messages[:messages_to_summarize_count] - remaining_messages = agent.messages[messages_to_summarize_count:] + # Adjust split point to avoid breaking ToolUse/ToolResult pairs + messages_to_summarize_count = self._adjust_split_point_for_tool_pairs( + agent.messages, messages_to_summarize_count + ) - # Keep track of the number of messages that have been summarized thus far. - self.removed_message_count += len(messages_to_summarize) - # If there is a summary message, don't count it in the removed_message_count. - if self._summary_message: - self.removed_message_count -= 1 + if messages_to_summarize_count <= 0: + raise ContextWindowOverflowException("Cannot summarize: insufficient messages for summarization") - # Generate summary - self._summary_message = self._generate_summary(messages_to_summarize, agent) + # Extract messages to summarize + messages_to_summarize = agent.messages[:messages_to_summarize_count] + remaining_messages = agent.messages[messages_to_summarize_count:] - # Replace the summarized messages with the summary - agent.messages[:] = [self._summary_message] + remaining_messages + # Keep track of the number of messages that have been summarized thus far. + self.removed_message_count += len(messages_to_summarize) + # If there is a summary message, don't count it in the removed_message_count. 
+ if self._summary_message: + self.removed_message_count -= 1 - except Exception as summarization_error: - logger.error("Summarization failed: %s", summarization_error) - raise summarization_error from e + # Generate summary + self._summary_message = self._generate_summary(messages_to_summarize, agent) + + # Replace the summarized messages with the summary + agent.messages[:] = [self._summary_message] + remaining_messages def _generate_summary(self, messages: list[Message], agent: "Agent") -> Message: """Generate a summary of the provided messages. diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index 8679e6fd7..df748241e 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -4,6 +4,7 @@ from strands import tool from strands.agent.agent import Agent +from strands.agent.conversation_manager.conversation_manager import ConversationManager from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.hooks.events import BeforeModelCallEvent @@ -300,7 +301,7 @@ def test_sliding_window_conversation_manager_with_tool_results_truncated(): ] test_agent = Agent(messages=messages) - manager.reduce_context(test_agent) + manager.reduce_context(test_agent, e=RuntimeError("context overflow")) result_text = messages[1]["content"][0]["toolResult"]["content"][0]["text"] assert result_text.startswith("A" * 200) @@ -310,8 +311,35 @@ def test_sliding_window_conversation_manager_with_tool_results_truncated(): assert messages[1]["content"][0]["toolResult"]["status"] == "success" -def test_null_conversation_manager_reduce_context_raises_context_window_overflow_exception(): - """Test that NullConversationManager doesn't modify messages.""" +def 
test_sliding_window_proactive_compression_skips_tool_result_truncation(): + """Proactive compression (e=None) should only trim messages, not truncate tool results.""" + large_text = "A" * 300 + "B" * 300 + "C" * 300 + manager = SlidingWindowConversationManager(window_size=2) + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "789", "content": [{"text": large_text}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"text": "Done"}]}, + {"role": "user", "content": [{"text": "Next question"}]}, + ] + test_agent = Agent(messages=messages) + + manager.reduce_context(test_agent) # e=None (proactive) + + # Tool results should NOT be truncated during proactive compression + for msg in messages: + for content in msg.get("content", []): + if "toolResult" in content: + for item in content["toolResult"].get("content", []): + if "text" in item: + assert "... 
[truncated:" not in item["text"] + + +def test_null_conversation_manager_reduce_context_proactive_returns_silently(): + """Proactive compression (e=None) returns silently without raising.""" manager = NullConversationManager() messages = [ {"role": "user", "content": [{"text": "Hello"}]}, @@ -322,12 +350,25 @@ def test_null_conversation_manager_reduce_context_raises_context_window_overflow manager.apply_management(test_agent) - with pytest.raises(ContextWindowOverflowException): - manager.reduce_context(messages) + # Proactive call (e=None) should not raise + manager.reduce_context(test_agent) assert messages == original_messages +def test_null_conversation_manager_reduce_context_reactive_raises_overflow(): + """Reactive overflow (e is not None) re-raises the exception.""" + manager = NullConversationManager() + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there"}]}, + ] + test_agent = Agent(messages=messages) + + with pytest.raises(ContextWindowOverflowException): + manager.reduce_context(test_agent, e=ContextWindowOverflowException("overflow")) + + def test_null_conversation_manager_reduce_context_with_exception_raises_same_exception(): """Test that NullConversationManager doesn't modify messages.""" manager = NullConversationManager() @@ -400,9 +441,10 @@ def reduce_context(self, agent, e=None, **kwargs): manager = MinimalConversationManager() registry = HookRegistry() - # Should work without error + # Should work without error — the base class always registers the hook manager.register_hooks(registry) - assert not registry.has_callbacks() + # Base class always registers the proactive compression hook + assert registry.has_callbacks() def test_per_turn_hooks_registration(): @@ -555,7 +597,7 @@ def test_truncation_targets_oldest_message_first(): ] test_agent = Agent(messages=messages) - manager.reduce_context(test_agent) + manager.reduce_context(test_agent, e=RuntimeError("context overflow")) # 
The oldest tool result (index 1) must be truncated oldest_text = messages[1]["content"][0]["toolResult"]["content"][0]["text"] @@ -755,3 +797,242 @@ def test_window_size_zero_clears_on_overflow(): manager.reduce_context(test_agent, e=Exception("overflow")) assert messages == [] + + +# ============================================================================== +# Proactive Compression Tests (proactive_compression parameter) +# ============================================================================== + + +class _MinimalManager(ConversationManager): + """Manager that only implements abstract methods.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.reduce_context_call_count = 0 + + def apply_management(self, agent, **kwargs): + pass + + def reduce_context(self, agent, e=None, **kwargs): + self.reduce_context_call_count += 1 + if agent.messages: + agent.messages.pop(0) + + +def _make_mock_agent(messages=None, context_window_limit=1000): + agent = MagicMock() + agent.messages = messages if messages is not None else [] + agent.model = MagicMock() + agent.model.context_window_limit = context_window_limit + return agent + + +def _make_threshold_event(agent, projected_input_tokens=None): + return BeforeModelCallEvent( + agent=agent, + invocation_state={}, + projected_input_tokens=projected_input_tokens, + ) + + +def test_proactive_compression_rejects_zero(): + with pytest.raises(ValueError, match="compression_threshold must be between 0"): + _MinimalManager(proactive_compression={"compression_threshold": 0}) + + +def test_proactive_compression_rejects_negative(): + with pytest.raises(ValueError, match="compression_threshold must be between 0"): + _MinimalManager(proactive_compression={"compression_threshold": -0.5}) + + +def test_proactive_compression_rejects_greater_than_one(): + with pytest.raises(ValueError, match="compression_threshold must be between 0"): + _MinimalManager(proactive_compression={"compression_threshold": 1.5}) + + +def 
test_proactive_compression_accepts_exactly_one(): + manager = _MinimalManager(proactive_compression={"compression_threshold": 1.0}) + assert manager._compression_threshold == 1.0 + + +def test_proactive_compression_none_by_default(): + manager = _MinimalManager() + assert manager._compression_threshold is None + + +def test_proactive_compression_true_uses_default_threshold(): + """proactive_compression=True uses default threshold of 0.7.""" + manager = _MinimalManager(proactive_compression=True) + assert manager._compression_threshold == 0.7 + + +def test_proactive_compression_false_disables(): + """proactive_compression=False means no compression.""" + manager = _MinimalManager(proactive_compression=False) + assert manager._compression_threshold is None + + +def test_proactive_compression_always_registers_hook(): + """Hook is always registered regardless of proactive_compression setting.""" + manager = _MinimalManager() + registry = HookRegistry() + manager.register_hooks(registry) + # Always registers the hook + assert registry.has_callbacks() + + +def test_proactive_compression_hook_is_noop_when_not_configured(): + """BeforeModelCallEvent handler is a no-op when proactive_compression is not set.""" + manager = _MinimalManager() + agent = _make_mock_agent(context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=900) + registry.invoke_callbacks(event) + + assert manager.reduce_context_call_count == 0 + + +def test_proactive_compression_calls_reduce_context_when_exceeded(): + manager = _MinimalManager(proactive_compression={"compression_threshold": 0.7}) + agent = _make_mock_agent(messages=[{"role": "user", "content": [{"text": "msg"}]}], context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=800) + registry.invoke_callbacks(event) + + assert 
manager.reduce_context_call_count == 1 + + +def test_proactive_compression_no_call_when_below(): + manager = _MinimalManager(proactive_compression={"compression_threshold": 0.7}) + agent = _make_mock_agent(context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=500) + registry.invoke_callbacks(event) + + assert manager.reduce_context_call_count == 0 + + +def test_proactive_compression_no_call_when_projected_tokens_none(): + manager = _MinimalManager(proactive_compression=True) + agent = _make_mock_agent(context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=None) + registry.invoke_callbacks(event) + + assert manager.reduce_context_call_count == 0 + + +def test_proactive_compression_uses_default_when_context_window_limit_not_set(): + manager = _MinimalManager(proactive_compression={"compression_threshold": 0.7}) + agent = _make_mock_agent(context_window_limit=None) + registry = HookRegistry() + manager.register_hooks(registry) + + # projected_input_tokens=150_000 is 75% of the 200k default, exceeding 0.7 threshold + event = _make_threshold_event(agent, projected_input_tokens=150_000) + with patch("strands.agent.conversation_manager.conversation_manager.logger") as mock_logger: + registry.invoke_callbacks(event) + mock_logger.warning.assert_called_once() + assert "using default" in mock_logger.warning.call_args[0][0] + + assert manager.reduce_context_call_count == 1 + + +def test_proactive_compression_warns_only_once_per_instance(): + """Second invocation on the same manager instance suppresses the context_window_limit warning.""" + manager = _MinimalManager(proactive_compression={"compression_threshold": 0.7}) + agent = _make_mock_agent(context_window_limit=None) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, 
projected_input_tokens=150_000) + with patch("strands.agent.conversation_manager.conversation_manager.logger") as mock_logger: + registry.invoke_callbacks(event) + registry.invoke_callbacks(event) + assert mock_logger.warning.call_count == 1 + + +def test_proactive_compression_exception_swallowed(): + """Exceptions in reduce_context during proactive compression should not propagate.""" + + class _FailingManager(ConversationManager): + def apply_management(self, agent, **kwargs): + pass + + def reduce_context(self, agent, e=None, **kwargs): + raise RuntimeError("boom") + + manager = _FailingManager(proactive_compression={"compression_threshold": 0.7}) + agent = _make_mock_agent(context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=800) + registry.invoke_callbacks(event) + + +def test_proactive_compression_true_default_threshold_behavior(): + """proactive_compression=True uses 0.7 — triggered at 0.7+ but not below.""" + manager = _MinimalManager(proactive_compression=True) + agent = _make_mock_agent( + messages=[{"role": "user", "content": [{"text": "msg"}]}], context_window_limit=1000 + ) + registry = HookRegistry() + manager.register_hooks(registry) + + # 650/1000 = 0.65 < 0.7 — should NOT trigger + event = _make_threshold_event(agent, projected_input_tokens=650) + registry.invoke_callbacks(event) + assert manager.reduce_context_call_count == 0 + + # 800/1000 = 0.8 >= 0.7 — should trigger + event2 = _make_threshold_event(agent, projected_input_tokens=800) + registry.invoke_callbacks(event2) + assert manager.reduce_context_call_count == 1 + + +def test_sliding_window_proactive_compression_trims(): + manager = SlidingWindowConversationManager( + window_size=4, should_truncate_results=False, proactive_compression={"compression_threshold": 0.7} + ) + messages = [ + {"role": "user", "content": [{"text": f"Message {i}"}]} + if i % 2 == 0 + else {"role": "assistant", 
"content": [{"text": f"Response {i}"}]} + for i in range(6) + ] + agent = _make_mock_agent(messages=messages, context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=800) + registry.invoke_callbacks(event) + + assert len(agent.messages) == 4 + + +def test_sliding_window_proactive_compression_no_trim_below(): + manager = SlidingWindowConversationManager( + window_size=4, should_truncate_results=False, proactive_compression={"compression_threshold": 0.7} + ) + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi"}]}, + ] + agent = _make_mock_agent(messages=messages, context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = _make_threshold_event(agent, projected_input_tokens=500) + registry.invoke_callbacks(event) + + assert len(agent.messages) == 2 diff --git a/tests/strands/agent/test_summarizing_conversation_manager.py b/tests/strands/agent/test_summarizing_conversation_manager.py index c49c69de6..dbd225e9b 100644 --- a/tests/strands/agent/test_summarizing_conversation_manager.py +++ b/tests/strands/agent/test_summarizing_conversation_manager.py @@ -1,5 +1,5 @@ from typing import cast -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pytest @@ -8,6 +8,8 @@ DEFAULT_SUMMARIZATION_PROMPT, SummarizingConversationManager, ) +from strands.hooks.events import BeforeModelCallEvent +from strands.hooks.registry import HookRegistry from strands.types.content import Messages from strands.types.exceptions import ContextWindowOverflowException from tests.fixtures.mocked_model_provider import MockedModelProvider @@ -101,7 +103,7 @@ def test_init_clamps_summary_ratio(): def test_reduce_context_raises_when_no_agent(): - """Test that reduce_context raises exception when agent has no messages.""" + """Test that reduce_context raises 
exception when agent has no messages (reactive mode).""" manager = SummarizingConversationManager() # Create a mock agent with no messages @@ -109,8 +111,9 @@ def test_reduce_context_raises_when_no_agent(): empty_messages: Messages = [] mock_agent.messages = empty_messages + # Reactive mode (e is set) should raise with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): - manager.reduce_context(mock_agent) + manager.reduce_context(mock_agent, e=RuntimeError("overflow")) def test_reduce_context_with_summarization(summarizing_manager, mock_agent): @@ -155,8 +158,9 @@ def test_reduce_context_too_few_messages_raises_exception(summarizing_manager, m ] mock_agent.messages = insufficient_test_messages # 5 messages, preserve_recent_messages=5, so nothing to summarize + # Reactive mode (e is set) should raise with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): - manager.reduce_context(mock_agent) + manager.reduce_context(mock_agent, e=RuntimeError("overflow")) def test_reduce_context_insufficient_messages_for_summarization(mock_agent): @@ -173,9 +177,9 @@ def test_reduce_context_insufficient_messages_for_summarization(mock_agent): ] mock_agent.messages = insufficient_messages - # This should raise an exception since there aren't enough messages to summarize + # Reactive mode (e is set) should raise with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): - manager.reduce_context(mock_agent) + manager.reduce_context(mock_agent, e=RuntimeError("overflow")) def test_reduce_context_raises_on_summarization_failure(): @@ -197,8 +201,9 @@ def test_reduce_context_raises_on_summarization_failure(): ) with patch("strands.agent.conversation_manager.summarizing_conversation_manager.logger") as mock_logger: + # Reactive mode (e is set) should raise with pytest.raises(Exception, match="Agent failed"): - manager.reduce_context(failing_agent) + 
manager.reduce_context(failing_agent, e=RuntimeError("overflow")) # Should log the error mock_logger.error.assert_called_once() @@ -675,9 +680,10 @@ def mock_adjust(messages, split_point): ] mock_agent.messages = simple_messages - # The adjustment method will return 0, which should trigger line 122-123 + # The adjustment method will return 0, which should trigger the <= 0 check + # Reactive mode (e is set) should raise with pytest.raises(ContextWindowOverflowException, match="insufficient messages for summarization"): - manager.reduce_context(mock_agent) + manager.reduce_context(mock_agent, e=RuntimeError("overflow")) def test_summarizing_conversation_manager_properly_records_removed_message_count(): @@ -802,3 +808,86 @@ def tracking_call(self, prompt): assert observed_values == [None], "structured output should be disabled during summarization" assert summary_agent._default_structured_output_model is structured_output_model, "should be restored after" + + +# ============================================================================== +# Compression Threshold Tests +# ============================================================================== + + +def _make_summarizing_threshold_agent(messages, summary_response="Summary of conversation", context_window_limit=1000): + agent = MagicMock() + agent.messages = messages + agent.model = MagicMock() + agent.model.context_window_limit = context_window_limit + agent.model.stream = Mock(side_effect=lambda *a, **kw: _mock_model_stream(summary_response)) + return agent + + +def test_proactive_compression_summarizes_when_exceeded(): + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + proactive_compression={"compression_threshold": 0.7}, + ) + messages = [ + {"role": "user", "content": [{"text": f"Message {i}"}]} + if i % 2 == 0 + else {"role": "assistant", "content": [{"text": f"Response {i}"}]} + for i in range(20) + ] + agent = _make_summarizing_threshold_agent(messages, 
context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = BeforeModelCallEvent(agent=agent, invocation_state={}, projected_input_tokens=800) + registry.invoke_callbacks(event) + + # 20 * 0.5 = 10 summarized → 1 summary + 10 remaining = 11 + assert len(agent.messages) == 11 + assert agent.messages[0]["role"] == "user" + + +def test_proactive_compression_no_summarize_when_below(): + manager = SummarizingConversationManager(proactive_compression={"compression_threshold": 0.7}) + messages = [ + {"role": "user", "content": [{"text": f"Message {i}"}]} + if i % 2 == 0 + else {"role": "assistant", "content": [{"text": f"Response {i}"}]} + for i in range(20) + ] + agent = _make_summarizing_threshold_agent(messages, context_window_limit=1000) + registry = HookRegistry() + manager.register_hooks(registry) + + event = BeforeModelCallEvent(agent=agent, invocation_state={}, projected_input_tokens=500) + registry.invoke_callbacks(event) + + assert len(agent.messages) == 20 + + +def test_proactive_compression_swallows_errors(): + manager = SummarizingConversationManager( + summary_ratio=0.5, + preserve_recent_messages=2, + proactive_compression={"compression_threshold": 0.7}, + ) + messages = [ + {"role": "user", "content": [{"text": f"Message {i}"}]} + if i % 2 == 0 + else {"role": "assistant", "content": [{"text": f"Response {i}"}]} + for i in range(20) + ] + agent = MagicMock() + agent.messages = messages + agent.model = MagicMock() + agent.model.context_window_limit = 1000 + agent.model.stream = Mock(side_effect=lambda *a, **kw: _mock_model_stream_error(RuntimeError("model failed"))) + + registry = HookRegistry() + manager.register_hooks(registry) + + event = BeforeModelCallEvent(agent=agent, invocation_state={}, projected_input_tokens=800) + # Should not throw — proactive compression is best-effort + registry.invoke_callbacks(event) + assert len(agent.messages) == 20 From 1847faec4fd37d2458156b6147c996826259377a Mon Sep 17 00:00:00 
2001 From: opieter-aws Date: Mon, 11 May 2026 11:09:27 -0400 Subject: [PATCH 270/279] feat: cache AccessDenied error for count tokens (#2279) --- src/strands/models/bedrock.py | 25 +++++++++---- tests/strands/models/test_bedrock.py | 54 ++++++++++++++++++++++++++-- 2 files changed, 69 insertions(+), 10 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index c74a63a3b..b9aae9b06 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -55,13 +55,13 @@ "anthropic.claude", ] -# Cache of model IDs that do not support the CountTokens API. -_UNSUPPORTED_COUNT_TOKENS_MODELS: set[str] = set() +# Cache of model IDs for which CountTokens API calls should be skipped. +_SKIP_COUNT_TOKENS_MODELS: set[str] = set() -def _clear_unsupported_count_tokens_cache() -> None: - """Clear the cache of model IDs that do not support the CountTokens API.""" - _UNSUPPORTED_COUNT_TOKENS_MODELS.clear() +def _clear_skip_count_tokens_cache() -> None: + """Clear the cache of model IDs for which CountTokens API calls should be skipped.""" + _SKIP_COUNT_TOKENS_MODELS.clear() T = TypeVar("T", bound=BaseModel) @@ -803,7 +803,7 @@ async def count_tokens( model_id: str = self.config["model_id"] - if model_id in _UNSUPPORTED_COUNT_TOKENS_MODELS: + if model_id in _SKIP_COUNT_TOKENS_MODELS: return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) try: @@ -833,6 +833,17 @@ async def count_tokens( return total_tokens except Exception as e: if ( + isinstance(e, ClientError) + and e.response.get("Error", {}).get("Code") == "AccessDeniedException" + ): + logger.warning( + "model_id=<%s> | bedrock:CountTokens permission denied," + " falling back to heuristic estimation: %s", + model_id, + e, + ) + _SKIP_COUNT_TOKENS_MODELS.add(model_id) + elif ( isinstance(e, ClientError) and e.response.get("Error", {}).get("Code") == "ValidationException" and "doesn't support counting tokens" in str(e) @@ -842,7 +853,7 @@ async 
def count_tokens( " falling back to estimation", model_id, ) - _UNSUPPORTED_COUNT_TOKENS_MODELS.add(model_id) + _SKIP_COUNT_TOKENS_MODELS.add(model_id) else: logger.debug( "model_id=<%s>, error=<%s> | native token counting failed, falling back to estimation", diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 2f1f7d1f1..b65d77234 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -19,7 +19,7 @@ DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION, DEFAULT_READ_TIMEOUT, - _clear_unsupported_count_tokens_cache, + _clear_skip_count_tokens_cache, ) from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException from strands.types.tools import ToolSpec @@ -3336,9 +3336,9 @@ class TestCountTokens: @pytest.fixture(autouse=True) def clean_cache(self): - _clear_unsupported_count_tokens_cache() + _clear_skip_count_tokens_cache() yield - _clear_unsupported_count_tokens_cache() + _clear_skip_count_tokens_cache() @pytest.fixture def model_with_client(self, bedrock_client, model_id): @@ -3473,6 +3473,54 @@ async def test_caches_model_id_when_count_tokens_unsupported(self, bedrock_clien await model.count_tokens(messages=messages) assert bedrock_client.count_tokens.call_count == 1 + @pytest.mark.asyncio + async def test_caches_model_id_when_access_denied(self, bedrock_client, messages): + model = BedrockModel(model_id="access-denied-cache-test-model") + bedrock_client.count_tokens.side_effect = ClientError( + { + "Error": { + "Code": "AccessDeniedException", + "Message": "User: arn:aws:sts::123456789012:assumed-role/role is not authorized" + " to perform: bedrock:CountTokens", + } + }, + "CountTokens", + ) + + # First call: hits API, gets error, caches + await model.count_tokens(messages=messages) + bedrock_client.count_tokens.assert_called_once() + + # Reset mock to clearly verify second call doesn't hit the API + bedrock_client.count_tokens.reset_mock() + + # Second 
call: skips API entirely due to caching + result = await model.count_tokens(messages=messages) + bedrock_client.count_tokens.assert_not_called() + assert isinstance(result, int) + assert result >= 0 + + @pytest.mark.asyncio + async def test_access_denied_logs_warning_with_full_error( + self, model_with_client, bedrock_client, messages, caplog + ): + error_message = ( + "User: arn:aws:sts::123456789012:assumed-role/role is not authorized" + " to perform: bedrock:CountTokens" + ) + bedrock_client.count_tokens.side_effect = ClientError( + {"Error": {"Code": "AccessDeniedException", "Message": error_message}}, + "CountTokens", + ) + + with caplog.at_level(logging.WARNING, logger="strands.models.bedrock"): + await model_with_client.count_tokens(messages=messages) + + warning_records = [r for r in caplog.records if r.levelno == logging.WARNING] + assert len(warning_records) == 1 + assert "bedrock:CountTokens permission denied" in warning_records[0].message + assert error_message in warning_records[0].message + @pytest.mark.asyncio async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, messages): model = BedrockModel(model_id="transient-error-test-model") From b1a3f037c561422d6bd084a22e2e8bc515e23411 Mon Sep 17 00:00:00 2001 From: "Vimal Paliwal (vim)" Date: Tue, 12 May 2026 16:52:17 +0100 Subject: [PATCH 271/279] fix(ollama): update return type of latencyMs metric for ollama model provider (#2236) --- src/strands/models/ollama.py | 2 +- tests/strands/models/test_ollama.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 54805ac16..cf7108c3a 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -280,7 +280,7 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: "totalTokens": event["data"].eval_count + event["data"].prompt_eval_count, }, "metrics": { - "latencyMs": event["data"].total_duration / 1e6, + "latencyMs": 
int(event["data"].total_duration / 1e6), }, }, } diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py index 7a6bbf97c..360683d08 100644 --- a/tests/strands/models/test_ollama.py +++ b/tests/strands/models/test_ollama.py @@ -407,7 +407,7 @@ def test_format_chunk_metadata(model): "totalTokens": 150, }, "metrics": { - "latencyMs": 1.0, + "latencyMs": 1, }, }, } @@ -447,7 +447,7 @@ async def test_stream(ollama_client, model, agenerator, alist, captured_warnings { "metadata": { "usage": {"inputTokens": 5, "outputTokens": 10, "totalTokens": 15}, - "metrics": {"latencyMs": 1.0}, + "metrics": {"latencyMs": 1}, } }, ] @@ -525,7 +525,7 @@ async def test_stream_with_tool_calls(ollama_client, model, agenerator, alist): assert tru_events[8] == { "metadata": { "usage": {"inputTokens": 8, "outputTokens": 15, "totalTokens": 23}, - "metrics": {"latencyMs": 2.0}, + "metrics": {"latencyMs": 2}, } } expected_request = { From 6b539285c85e2af78ee9fa877295b10b9cf055d9 Mon Sep 17 00:00:00 2001 From: Albert Zhao <67480168+Albertozhao@users.noreply.github.com> Date: Tue, 12 May 2026 17:52:01 -0400 Subject: [PATCH 272/279] feat: add official Discord link (#2285) --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 173adc006..7e1612858 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ License PyPI version Python versions + Strands Discord

@@ -316,6 +317,9 @@ We welcome contributions! See our [Contributing Guide](CONTRIBUTING.md) for deta - Code of Conduct - Reporting of security issues +## Stay in touch with the team +Come meet the Strands team and other users on [**Discord**](https://discord.com/invite/strands) + ## License This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details. From 305a0051b25e2425c4388f9b1faf7dc6c48e9b20 Mon Sep 17 00:00:00 2001 From: opieter-aws Date: Wed, 13 May 2026 09:09:03 -0400 Subject: [PATCH 273/279] fix: set use_native_token_count default to false (#2284) --- src/strands/models/anthropic.py | 6 +++--- src/strands/models/bedrock.py | 6 +++--- src/strands/models/gemini.py | 6 +++--- src/strands/models/llamacpp.py | 6 +++--- src/strands/models/openai_responses.py | 6 +++--- tests/strands/models/test_anthropic.py | 13 ++++++++++++- tests/strands/models/test_bedrock.py | 19 +++++++++++++++---- tests/strands/models/test_gemini.py | 13 ++++++++++++- tests/strands/models/test_llamacpp.py | 13 ++++++++++++- tests/strands/models/test_openai_responses.py | 13 ++++++++++++- 10 files changed, 78 insertions(+), 23 deletions(-) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 04fae220d..812171a0c 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -58,8 +58,8 @@ class AnthropicConfig(BaseModelConfig, total=False): params: Additional model parameters (e.g., temperature). For a complete list of supported parameters, see https://docs.anthropic.com/en/api/messages. use_native_token_count: Whether to use the native Anthropic count_tokens API. - When True (default), count_tokens() calls the Anthropic API for accurate counts. - When False, skips the API call and uses the local estimator. + When True, count_tokens() calls the Anthropic API for accurate counts. + When False (default), skips the API call and uses the local estimator. 
""" max_tokens: Required[int] @@ -398,7 +398,7 @@ async def count_tokens( Returns: Total input token count. """ - if self.config.get("use_native_token_count") is False: + if self.config.get("use_native_token_count") is not True: return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) try: diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index b9aae9b06..ab9adb67a 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -118,8 +118,8 @@ class BedrockConfig(BaseModelConfig, total=False): temperature: Controls randomness in generation (higher = more random) top_p: Controls diversity via nucleus sampling (alternative to temperature) use_native_token_count: Whether to use the native Bedrock CountTokens API. - When True (default), count_tokens() calls the Bedrock API for accurate counts. - When False, skips the API call and uses the local estimator. + When True, count_tokens() calls the Bedrock API for accurate counts. + When False (default), skips the API call and uses the local estimator. """ additional_args: dict[str, Any] | None @@ -798,7 +798,7 @@ async def count_tokens( Returns: Total input token count. """ - if self.config.get("use_native_token_count") is False: + if self.config.get("use_native_token_count") is not True: return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) model_id: str = self.config["model_id"] diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 8ed579d38..43e4f0349 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -50,8 +50,8 @@ class GeminiConfig(BaseModelConfig, total=False): For a complete list of supported tools, see https://ai.google.dev/api/caching#Tool use_native_token_count: Whether to use the native Gemini count_tokens API. - When True (default), count_tokens() calls the Gemini API for accurate counts. 
- When False, skips the API call and uses the local estimator. + When True, count_tokens() calls the Gemini API for accurate counts. + When False (default), skips the API call and uses the local estimator. """ model_id: Required[str] @@ -461,7 +461,7 @@ async def count_tokens( Returns: Total input token count. """ - if self.config.get("use_native_token_count") is False: + if self.config.get("use_native_token_count") is not True: return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) try: diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index 531cf6b50..5dd25729d 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -126,8 +126,8 @@ class LlamaCppConfig(BaseModelConfig, total=False): - slot_id: Slot ID for parallel inference - samplers: Custom sampler order use_native_token_count: Whether to use the native llama.cpp /tokenize endpoint. - When True (default), count_tokens() calls the server's tokenize endpoint for accurate counts. - When False, skips the API call and uses the local estimator. + When True, count_tokens() calls the server's tokenize endpoint for accurate counts. + When False (default), skips the API call and uses the local estimator. """ model_id: str @@ -537,7 +537,7 @@ async def count_tokens( Returns: Total input token count. 
""" - if self.config.get("use_native_token_count") is False: + if self.config.get("use_native_token_count") is not True: return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) try: diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index c6ddbb9d6..8914fb01c 100644 --- a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -137,8 +137,8 @@ class OpenAIResponsesConfig(BaseModelConfig, total=False): When True, the server stores conversation history and the client does not need to send the full message history with each request. Defaults to False. use_native_token_count: Whether to use the native OpenAI input_tokens.count API. - When True (default), count_tokens() calls the OpenAI API for accurate counts. - When False, skips the API call and uses the local estimator. + When True, count_tokens() calls the OpenAI API for accurate counts. + When False (default), skips the API call and uses the local estimator. """ model_id: str @@ -242,7 +242,7 @@ async def count_tokens( Returns: Total input token count. 
""" - if self.config.get("use_native_token_count") is False: + if self.config.get("use_native_token_count") is not True: return await super().count_tokens(messages, tool_specs, system_prompt, system_prompt_content) try: diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 6de821e90..0ebdb161c 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -1072,7 +1072,7 @@ class TestCountTokens: @pytest.fixture def model_with_client(self, anthropic_client, model_id, max_tokens): _ = anthropic_client - return AnthropicModel(model_id=model_id, max_tokens=max_tokens) + return AnthropicModel(model_id=model_id, max_tokens=max_tokens, use_native_token_count=True) @pytest.fixture def messages(self): @@ -1175,3 +1175,14 @@ async def test_skip_native_api_when_use_native_token_count_false( anthropic_client.messages.count_tokens.assert_not_called() assert isinstance(result, int) assert result >= 0 + + @pytest.mark.asyncio + async def test_skip_native_api_by_default(self, anthropic_client, model_id, max_tokens, messages): + _ = anthropic_client + model = AnthropicModel(model_id=model_id, max_tokens=max_tokens) + + result = await model.count_tokens(messages=messages) + + anthropic_client.messages.count_tokens.assert_not_called() + assert isinstance(result, int) + assert result >= 0 diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index b65d77234..2e105d64a 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -3343,7 +3343,7 @@ def clean_cache(self): @pytest.fixture def model_with_client(self, bedrock_client, model_id): _ = bedrock_client - return BedrockModel(model_id=model_id) + return BedrockModel(model_id=model_id, use_native_token_count=True) @pytest.fixture def messages(self): @@ -3459,7 +3459,7 @@ async def test_fallback_logs_debug(self, model_with_client, bedrock_client, mess @pytest.mark.asyncio async def 
test_caches_model_id_when_count_tokens_unsupported(self, bedrock_client, messages): - model = BedrockModel(model_id="unsupported-cache-test-model") + model = BedrockModel(model_id="unsupported-cache-test-model", use_native_token_count=True) bedrock_client.count_tokens.side_effect = ClientError( {"Error": {"Code": "ValidationException", "Message": "The provided model doesn't support counting tokens"}}, "CountTokens", @@ -3475,7 +3475,7 @@ async def test_caches_model_id_when_count_tokens_unsupported(self, bedrock_clien @pytest.mark.asyncio async def test_caches_model_id_when_access_denied(self, bedrock_client, messages): - model = BedrockModel(model_id="access-denied-cache-test-model") + model = BedrockModel(model_id="access-denied-cache-test-model", use_native_token_count=True) bedrock_client.count_tokens.side_effect = ClientError( { "Error": { @@ -3523,7 +3523,7 @@ async def test_access_denied_logs_warning_with_full_error( @pytest.mark.asyncio async def test_does_not_cache_model_id_for_other_errors(self, bedrock_client, messages): - model = BedrockModel(model_id="transient-error-test-model") + model = BedrockModel(model_id="transient-error-test-model", use_native_token_count=True) bedrock_client.count_tokens.side_effect = RuntimeError("Transient network error") await model.count_tokens(messages=messages) @@ -3543,3 +3543,14 @@ async def test_skip_native_api_when_use_native_token_count_false(self, bedrock_c bedrock_client.count_tokens.assert_not_called() assert isinstance(result, int) assert result >= 0 + + @pytest.mark.asyncio + async def test_skip_native_api_by_default(self, bedrock_client, model_id, messages): + _ = bedrock_client + model = BedrockModel(model_id=model_id) + + result = await model.count_tokens(messages=messages) + + bedrock_client.count_tokens.assert_not_called() + assert isinstance(result, int) + assert result >= 0 diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index b846bfcdf..a8ff38b99 100644 --- 
a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -1144,7 +1144,7 @@ def gemini_client(self): @pytest.fixture def model(self, gemini_client): _ = gemini_client - return GeminiModel(model_id="m1") + return GeminiModel(model_id="m1", use_native_token_count=True) @pytest.fixture def messages(self): @@ -1239,3 +1239,14 @@ async def test_skip_native_api_when_use_native_token_count_false(self, gemini_cl gemini_client.aio.models.count_tokens.assert_not_called() assert isinstance(result, int) assert result >= 0 + + @pytest.mark.asyncio + async def test_skip_native_api_by_default(self, gemini_client, messages): + _ = gemini_client + model = GeminiModel(model_id="m1") + + result = await model.count_tokens(messages=messages) + + gemini_client.aio.models.count_tokens.assert_not_called() + assert isinstance(result, int) + assert result >= 0 diff --git a/tests/strands/models/test_llamacpp.py b/tests/strands/models/test_llamacpp.py index 43fb03629..6868e490b 100644 --- a/tests/strands/models/test_llamacpp.py +++ b/tests/strands/models/test_llamacpp.py @@ -713,7 +713,7 @@ class TestCountTokens: @pytest.fixture def model(self): - return LlamaCppModel(base_url="http://localhost:8080") + return LlamaCppModel(base_url="http://localhost:8080", use_native_token_count=True) @pytest.fixture def messages(self): @@ -814,3 +814,14 @@ async def test_skip_native_api_when_use_native_token_count_false(self, messages) model.client.post.assert_not_called() assert isinstance(result, int) assert result >= 0 + + @pytest.mark.asyncio + async def test_skip_native_api_by_default(self, messages): + model = LlamaCppModel(base_url="http://localhost:8080") + model.client.post = AsyncMock() + + result = await model.count_tokens(messages=messages) + + model.client.post.assert_not_called() + assert isinstance(result, int) + assert result >= 0 diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py index 47acfded4..697508339 
100644 --- a/tests/strands/models/test_openai_responses.py +++ b/tests/strands/models/test_openai_responses.py @@ -1224,7 +1224,7 @@ def openai_client(self): @pytest.fixture def model(self, openai_client): _ = openai_client - return OpenAIResponsesModel(model_id="gpt-4o") + return OpenAIResponsesModel(model_id="gpt-4o", use_native_token_count=True) @pytest.fixture def messages(self): @@ -1329,6 +1329,17 @@ async def test_skip_native_api_when_use_native_token_count_false(self, openai_cl assert isinstance(result, int) assert result >= 0 + @pytest.mark.asyncio + async def test_skip_native_api_by_default(self, openai_client, messages): + _ = openai_client + model = OpenAIResponsesModel(model_id="gpt-4o") + + result = await model.count_tokens(messages=messages) + + openai_client.responses.input_tokens.count.assert_not_called() + assert isinstance(result, int) + assert result >= 0 + # ============================================================================= # Bedrock Mantle (bedrock_mantle_config) integration with OpenAIResponsesModel From fa74d803eab2f9fa46886b679580d16336676620 Mon Sep 17 00:00:00 2001 From: mehtarac Date: Wed, 13 May 2026 10:28:13 -0400 Subject: [PATCH 274/279] fix: swarm bug "Failed to detach context" with opentelemetry (#2281) --- src/strands/multiagent/swarm.py | 34 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index f5731a371..b0f9f4a86 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -17,6 +17,7 @@ import copy import json import logging +import sys import time from collections.abc import AsyncIterator, Callable, Mapping from dataclasses import dataclass, field @@ -439,28 +440,25 @@ async def _stream_with_timeout( Exception: If total execution time exceeds timeout """ if timeout is None: - # No timeout - just pass through async for event in async_generator: yield event + elif sys.version_info 
>= (3, 11): + try: + async with asyncio.timeout(timeout): + async for event in async_generator: + yield event + except asyncio.TimeoutError as err: + raise Exception(timeout_message) from err else: - # Track start time for total timeout - start_time = asyncio.get_event_loop().time() - - while True: - # Calculate remaining time from total timeout budget - elapsed = asyncio.get_event_loop().time() - start_time - remaining = timeout - elapsed - - if remaining <= 0: + # Python 3.10 fallback: timeout is only checked between yielded events. + # A generator that hangs mid-await won't be interrupted until the next event. + # Remove once Python 3.10 support is dropped (Oct 2026). + start_time = asyncio.get_running_loop().time() + async for event in async_generator: + elapsed = asyncio.get_running_loop().time() - start_time + if elapsed > timeout: raise Exception(timeout_message) - - try: - event = await asyncio.wait_for(async_generator.__anext__(), timeout=remaining) - yield event - except StopAsyncIteration: - break - except asyncio.TimeoutError as err: - raise Exception(timeout_message) from err + yield event def _setup_swarm(self, nodes: list[Agent]) -> None: """Initialize swarm configuration.""" From afb0dd936bf3e72753bbaa7ff10a067afaf90fdf Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Thu, 14 May 2026 13:22:35 -0400 Subject: [PATCH 275/279] feat(plugins): add MultiAgentPlugin for Swarm and Graph orchestrators (#2280) Co-authored-by: Mackenzie Zastrow --- AGENTS.md | 5 +- src/strands/__init__.py | 3 +- src/strands/multiagent/base.py | 15 + src/strands/multiagent/graph.py | 33 +- src/strands/multiagent/swarm.py | 22 +- src/strands/plugins/__init__.py | 7 +- src/strands/plugins/_discovery.py | 103 ++++ src/strands/plugins/multiagent_plugin.py | 119 ++++ src/strands/plugins/multiagent_registry.py | 113 ++++ src/strands/plugins/plugin.py | 40 +- src/strands/plugins/registry.py | 19 +- 
.../multiagent/test_multiagent_plugins.py | 283 +++++++++ .../strands/plugins/test_multiagent_plugin.py | 563 ++++++++++++++++++ 13 files changed, 1274 insertions(+), 51 deletions(-) create mode 100644 src/strands/plugins/_discovery.py create mode 100644 src/strands/plugins/multiagent_plugin.py create mode 100644 src/strands/plugins/multiagent_registry.py create mode 100644 tests/strands/multiagent/test_multiagent_plugins.py create mode 100644 tests/strands/plugins/test_multiagent_plugin.py diff --git a/AGENTS.md b/AGENTS.md index 0b877ea98..daddbbb2d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -130,8 +130,11 @@ strands-agents/ │ │ │ ├── plugins/ # Plugin system │ │ ├── plugin.py # Plugin base class +│ │ ├── multiagent_plugin.py # MultiAgentPlugin base class │ │ ├── decorator.py # @hook decorator -│ │ └── registry.py # PluginRegistry for tracking plugins +│ │ ├── registry.py # PluginRegistry for tracking agent plugins +│ │ ├── multiagent_registry.py # Registry for tracking orchestrator plugins +│ │ └── _discovery.py # Shared hook/tool discovery utilities │ │ │ ├── handlers/ # Event handlers │ │ └── callback_handler.py # Callback handling diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 6625ac41f..00e32ead3 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -4,7 +4,7 @@ from .agent.agent import Agent from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy -from .plugins import Plugin +from .plugins import MultiAgentPlugin, Plugin from .tools.decorator import tool from .types._snapshot import Snapshot from .types.tools import ToolContext @@ -17,6 +17,7 @@ "agent", "models", "ModelRetryStrategy", + "MultiAgentPlugin", "Plugin", "Skill", "Snapshot", diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index dc3258f68..14c4d0d14 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -13,6 +13,7 @@ from .._async import run_async from ..agent import 
AgentResult +from ..hooks.registry import HookCallback from ..interrupt import Interrupt from ..types.event_loop import Metrics, Usage from ..types.multiagent import MultiAgentInput @@ -254,6 +255,20 @@ def deserialize_state(self, payload: dict[str, Any]) -> None: """Restore orchestrator state from a session dict.""" raise NotImplementedError + def add_hook(self, callback: HookCallback, event_type: type | list[type] | None = None) -> None: + """Register a hook callback with the orchestrator. + + Subclasses that support hooks should override this method to register + the callback with their hook registry. + + Args: + callback: The callback function to invoke when events of this type occur. + event_type: The class type(s) of events this callback should handle. + Can be a single type, a list of types, or None to infer from + the callback's first parameter type hint. + """ + raise NotImplementedError(f"{type(self).__name__} must implement add_hook() to support plugins") + def _parse_trace_attributes( self, attributes: Mapping[str, AttributeValue] | None = None ) -> dict[str, AttributeValue]: diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 8da8314ea..146a31563 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -35,8 +35,10 @@ BeforeNodeCallEvent, MultiAgentInitializedEvent, ) -from ..hooks.registry import HookProvider, HookRegistry +from ..hooks.registry import HookCallback, HookProvider, HookRegistry from ..interrupt import Interrupt, _InterruptState +from ..plugins.multiagent_plugin import MultiAgentPlugin +from ..plugins.multiagent_registry import _MultiAgentPluginRegistry from ..session import SessionManager from ..telemetry import get_tracer from ..types._events import ( @@ -253,6 +255,7 @@ def __init__(self) -> None: self._id: str = _DEFAULT_GRAPH_ID self._session_manager: SessionManager | None = None self._hooks: list[HookProvider] | None = None + self._plugins: list[MultiAgentPlugin] | None 
= None def add_node(self, executor: AgentBase | MultiAgentBase, node_id: str | None = None) -> GraphNode: """Add an AgentBase or MultiAgentBase instance as a node to the graph.""" @@ -370,6 +373,15 @@ def set_hook_providers(self, hooks: list[HookProvider]) -> "GraphBuilder": self._hooks = hooks return self + def set_plugins(self, plugins: list[MultiAgentPlugin]) -> "GraphBuilder": + """Set plugins for the graph. + + Args: + plugins: List of multi-agent plugins for extending graph behavior + """ + self._plugins = plugins + return self + def build(self) -> "Graph": """Build and validate the graph with configured settings.""" if not self.nodes: @@ -398,6 +410,7 @@ def build(self) -> "Graph": session_manager=self._session_manager, hooks=self._hooks, id=self._id, + plugins=self._plugins, ) def _validate_graph(self) -> None: @@ -429,6 +442,7 @@ def __init__( hooks: list[HookProvider] | None = None, id: str = _DEFAULT_GRAPH_ID, trace_attributes: Mapping[str, AttributeValue] | None = None, + plugins: list[MultiAgentPlugin] | None = None, ) -> None: """Initialize Graph with execution limits and reset behavior. 
@@ -444,6 +458,7 @@ def __init__( hooks: List of hook providers for monitoring and extending graph execution behavior (default: None) id: Unique graph id (default: None) trace_attributes: Custom trace attributes to apply to the agent's trace span (default: None) + plugins: List of multi-agent plugins for extending graph behavior (default: None) """ super().__init__() @@ -469,12 +484,28 @@ def __init__( for hook in hooks: self.hooks.add_hook(hook) + self._plugin_registry = _MultiAgentPluginRegistry(self) + if plugins: + for plugin in plugins: + self._plugin_registry.add_and_init(plugin) + self._resume_next_nodes: list[GraphNode] = [] self._resume_from_session = False self.id = id run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) + def add_hook(self, callback: HookCallback, event_type: type | list[type] | None = None) -> None: + """Register a hook callback with the graph. + + Args: + callback: The callback function to invoke when events of this type occur. + event_type: The class type(s) of events this callback should handle. + Can be a single type, a list of types, or None to infer from + the callback's first parameter type hint. 
+ """ + self.hooks.add_callback(event_type, callback) + def __call__( self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> GraphResult: diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index b0f9f4a86..2eeb38694 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -35,8 +35,10 @@ BeforeNodeCallEvent, MultiAgentInitializedEvent, ) -from ..hooks.registry import HookProvider, HookRegistry +from ..hooks.registry import HookCallback, HookProvider, HookRegistry from ..interrupt import Interrupt, _InterruptState +from ..plugins.multiagent_plugin import MultiAgentPlugin +from ..plugins.multiagent_registry import _MultiAgentPluginRegistry from ..session import SessionManager from ..telemetry import get_tracer from ..tools.decorator import tool @@ -250,6 +252,7 @@ def __init__( hooks: list[HookProvider] | None = None, id: str = _DEFAULT_SWARM_ID, trace_attributes: Mapping[str, AttributeValue] | None = None, + plugins: list[MultiAgentPlugin] | None = None, ) -> None: """Initialize Swarm with agents and configuration. 
@@ -268,6 +271,7 @@ def __init__( session_manager: Session manager for persisting graph state and execution history (default: None) hooks: List of hook providers for monitoring and extending graph execution behavior (default: None) trace_attributes: Custom trace attributes to apply to the agent's trace span (default: None) + plugins: List of multi-agent plugins for extending swarm behavior (default: None) """ super().__init__() self.id = id @@ -300,12 +304,28 @@ def __init__( if self.session_manager: self.hooks.add_hook(self.session_manager) + self._plugin_registry = _MultiAgentPluginRegistry(self) + if plugins: + for plugin in plugins: + self._plugin_registry.add_and_init(plugin) + self._resume_from_session = False self._setup_swarm(nodes) self._inject_swarm_tools() run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) + def add_hook(self, callback: HookCallback, event_type: type | list[type] | None = None) -> None: + """Register a hook callback with the swarm. + + Args: + callback: The callback function to invoke when events of this type occur. + event_type: The class type(s) of events this callback should handle. + Can be a single type, a list of types, or None to infer from + the callback's first parameter type hint. + """ + self.hooks.add_callback(event_type, callback) + def __call__( self, task: MultiAgentInput, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> SwarmResult: diff --git a/src/strands/plugins/__init__.py b/src/strands/plugins/__init__.py index c4b7c72c7..7a3d5fa17 100644 --- a/src/strands/plugins/__init__.py +++ b/src/strands/plugins/__init__.py @@ -1,13 +1,16 @@ -"""Plugin system for extending agent functionality. +"""Plugin system for extending agent and orchestrator functionality. This module provides a composable mechanism for building objects that can -extend agent behavior through automatic hook and tool registration. 
+extend agent and multi-agent orchestrator behavior through automatic hook +and tool registration. """ from .decorator import hook +from .multiagent_plugin import MultiAgentPlugin from .plugin import Plugin __all__ = [ + "MultiAgentPlugin", "Plugin", "hook", ] diff --git a/src/strands/plugins/_discovery.py b/src/strands/plugins/_discovery.py new file mode 100644 index 000000000..eda955030 --- /dev/null +++ b/src/strands/plugins/_discovery.py @@ -0,0 +1,103 @@ +"""Shared utility for discovering decorated methods on plugin instances. + +This module provides helper functions used by both Plugin and MultiAgentPlugin +to scan for @hook (and optionally @tool) decorated methods, and shared registry +utilities for plugin initialization and hook registration. +""" + +import inspect +import logging +from collections.abc import Awaitable, Callable +from typing import Any, cast + +from .._async import run_async +from ..hooks.registry import HookCallback +from ..tools.decorator import DecoratedFunctionTool + +logger = logging.getLogger(__name__) + + +def _discover_methods(instance: object, plugin_name: str, predicate: Callable[[object], bool], label: str) -> list[Any]: + """Scan an instance's class hierarchy for methods matching a predicate. + + Walks the MRO in reverse so parent class methods come first, but child + overrides win (only the child's version is included). + + Args: + instance: The plugin instance to scan. + plugin_name: The plugin name (used for debug logging). + predicate: Function that returns True for attributes to collect. + label: Label for debug logging (e.g., "hook", "tool"). + + Returns: + List of matching bound methods/descriptors in declaration order. 
+ """ + results: list[Any] = [] + seen: set[str] = set() + + for cls in reversed(type(instance).__mro__): + for attr_name in cls.__dict__: + if attr_name in seen: + continue + seen.add(attr_name) + + try: + bound = getattr(instance, attr_name) + except Exception: + continue + + if predicate(bound): + results.append(bound) + logger.debug("plugin=<%s>, %s=<%s> | discovered", plugin_name, label, attr_name) + + return results + + +def discover_hooks(instance: object, plugin_name: str) -> list[HookCallback]: + """Scan an instance's class hierarchy for @hook decorated methods. + + Args: + instance: The plugin instance to scan. + plugin_name: The plugin name (used for debug logging). + + Returns: + List of bound hook callback methods in declaration order. + """ + return _discover_methods( + instance, + plugin_name, + predicate=lambda bound: hasattr(bound, "_hook_event_types") and callable(bound), + label="hook", + ) + + +def discover_tools(instance: object, plugin_name: str) -> list[DecoratedFunctionTool]: + """Scan an instance's class hierarchy for @tool decorated methods. + + Args: + instance: The plugin instance to scan. + plugin_name: The plugin name (used for debug logging). + + Returns: + List of DecoratedFunctionTool instances in declaration order. + """ + return _discover_methods( + instance, + plugin_name, + predicate=lambda bound: isinstance(bound, DecoratedFunctionTool), + label="tool", + ) + + +def call_init_method(init_method: Callable[..., Any], target: Any) -> None: + """Call a plugin's init method, handling both sync and async implementations. + + Args: + init_method: The init_agent or init_multi_agent method to call. + target: The agent or orchestrator instance to pass to the init method. 
+ """ + if inspect.iscoroutinefunction(init_method): + async_init = cast(Callable[..., Awaitable[None]], init_method) + run_async(lambda: async_init(target)) + else: + init_method(target) diff --git a/src/strands/plugins/multiagent_plugin.py b/src/strands/plugins/multiagent_plugin.py new file mode 100644 index 000000000..89bd9e0e5 --- /dev/null +++ b/src/strands/plugins/multiagent_plugin.py @@ -0,0 +1,119 @@ +"""MultiAgentPlugin base class for extending multi-agent orchestrator functionality. + +This module defines the MultiAgentPlugin base class, which provides a composable way to +add behavior changes to multi-agent orchestrators (Swarm, Graph) through automatic hook +registration and custom initialization. + +MultiAgentPlugin is the orchestrator-level counterpart to Plugin (which targets individual agents). +A class can implement both Plugin and MultiAgentPlugin to provide functionality at both levels. +""" + +from abc import ABC, abstractmethod +from collections.abc import Awaitable +from typing import TYPE_CHECKING + +from ..hooks.registry import HookCallback +from ._discovery import discover_hooks + +if TYPE_CHECKING: + from ..multiagent.base import MultiAgentBase + + +class MultiAgentPlugin(ABC): + """Base class for objects that extend multi-agent orchestrator functionality. + + MultiAgentPlugins provide a composable way to add behavior changes to orchestrators + (Swarm, Graph). They support automatic discovery and registration of methods decorated + with @hook. + + Unlike agent-level Plugin, MultiAgentPlugin does not support @tool decorated methods + since orchestrators do not have tool registries. 
+ + Attributes: + name: A stable string identifier for the plugin (must be provided by subclass) + hooks: Hooks attached to the orchestrator, auto-discovered from @hook decorated methods + + Example using decorators (recommended): + ```python + from strands.plugins import MultiAgentPlugin, hook + from strands.hooks import BeforeNodeCallEvent, AfterNodeCallEvent + + class MonitoringPlugin(MultiAgentPlugin): + name = "monitoring" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + print(f"Node {event.node_id} starting") + + @hook + def on_after_node(self, event: AfterNodeCallEvent): + print(f"Node {event.node_id} completed") + ``` + + Example with custom initialization: + ```python + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + def init_multi_agent(self, orchestrator: MultiAgentBase) -> None: + # Custom initialization logic + pass + ``` + + Dual-use example (both agent and orchestrator): + ```python + from strands.plugins import Plugin, MultiAgentPlugin, hook + from strands.hooks import BeforeInvocationEvent, BeforeNodeCallEvent + + class ObservabilityPlugin(Plugin, MultiAgentPlugin): + name = "observability" + + @hook + def on_agent_invocation(self, event: BeforeInvocationEvent): + print("Agent invocation started") + + @hook + def on_node_call(self, event: BeforeNodeCallEvent): + print(f"Node {event.node_id} starting") + + def init_agent(self, agent): + pass # Agent-level setup + + def init_multi_agent(self, orchestrator): + pass # Orchestrator-level setup + ``` + """ + + @property + @abstractmethod + def name(self) -> str: + """A stable string identifier for the plugin.""" + ... + + def __init__(self) -> None: + """Initialize the plugin and discover decorated hook methods. + + Scans the class for methods decorated with @hook and stores references + for later registration when the plugin is attached to an orchestrator. 
+ + Uses a guard to prevent double-discovery when used with multiple inheritance + (e.g., a class that inherits from both Plugin and MultiAgentPlugin). + """ + if not hasattr(self, "_hooks"): + self._hooks: list[HookCallback] = discover_hooks(self, self.name) + + @property + def hooks(self) -> list[HookCallback]: + """List of hooks the plugin provides, auto-discovered from @hook decorated methods.""" + return self._hooks + + def init_multi_agent(self, orchestrator: "MultiAgentBase") -> None | Awaitable[None]: + """Initialize the plugin with the orchestrator instance. + + Override this method to add custom initialization logic. Decorated + hooks are automatically registered by the plugin registry. + + Args: + orchestrator: The multi-agent orchestrator instance to initialize with. + """ + return None diff --git a/src/strands/plugins/multiagent_registry.py b/src/strands/plugins/multiagent_registry.py new file mode 100644 index 000000000..365c8f9c5 --- /dev/null +++ b/src/strands/plugins/multiagent_registry.py @@ -0,0 +1,113 @@ +"""MultiAgentPlugin registry for managing plugins attached to a multi-agent orchestrator. + +This module provides the _MultiAgentPluginRegistry class for tracking and managing +plugins that have been initialized with an orchestrator instance. +""" + +import logging +import weakref +from typing import TYPE_CHECKING + +from ._discovery import call_init_method +from .multiagent_plugin import MultiAgentPlugin + +if TYPE_CHECKING: + from ..multiagent.base import MultiAgentBase + +logger = logging.getLogger(__name__) + + +class _MultiAgentPluginRegistry: + """Registry for managing plugins attached to a multi-agent orchestrator. + + The _MultiAgentPluginRegistry tracks plugins that have been initialized with an + orchestrator, providing methods to add plugins and invoke their initialization. + + The registry handles: + 1. Calling the plugin's init_multi_agent() method for custom initialization + 2. 
Auto-registering discovered @hook decorated methods with the orchestrator + + Example: + ```python + registry = _MultiAgentPluginRegistry(orchestrator) + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_event(self, event: BeforeNodeCallEvent): + pass # Auto-registered by registry + + def init_multi_agent(self, orchestrator: MultiAgentBase) -> None: + # Custom logic + pass + + plugin = MyPlugin() + registry.add_and_init(plugin) + ``` + """ + + def __init__(self, orchestrator: "MultiAgentBase") -> None: + """Initialize a plugin registry with an orchestrator reference. + + Args: + orchestrator: The orchestrator instance that plugins will be initialized with. + """ + self._orchestrator_ref = weakref.ref(orchestrator) + self._plugins: dict[str, MultiAgentPlugin] = {} + + @property + def _orchestrator(self) -> "MultiAgentBase": + """Return the orchestrator, raising ReferenceError if it has been garbage collected.""" + orchestrator = self._orchestrator_ref() + if orchestrator is None: + raise ReferenceError("Orchestrator has been garbage collected") + return orchestrator + + def add_and_init(self, plugin: MultiAgentPlugin) -> None: + """Add and initialize a plugin with the orchestrator. + + This method: + 1. Registers the plugin in the registry + 2. Calls the plugin's init_multi_agent method for custom initialization + 3. Auto-registers all discovered @hook methods with the orchestrator's hook registry + + Handles both sync and async init_multi_agent implementations automatically. + + Args: + plugin: The plugin to add and initialize. + + Raises: + ValueError: If a plugin with the same name is already registered. 
+ """ + if plugin.name in self._plugins: + raise ValueError(f"plugin_name=<{plugin.name}> | plugin already registered") + + logger.debug("plugin_name=<%s> | registering and initializing multi-agent plugin", plugin.name) + self._plugins[plugin.name] = plugin + + # Call user's init_multi_agent for custom initialization + call_init_method(plugin.init_multi_agent, self._orchestrator) + + # Auto-register discovered hooks with the orchestrator + self._register_hooks(plugin) + + def _register_hooks(self, plugin: MultiAgentPlugin) -> None: + """Register all discovered hooks from the plugin with the orchestrator. + + Uses orchestrator.add_hook() so that the orchestrator can track + registrations through its public API. + + Args: + plugin: The plugin whose hooks should be registered. + """ + for hook_callback in plugin.hooks: + event_types = getattr(hook_callback, "_hook_event_types", []) + for event_type in event_types: + self._orchestrator.add_hook(hook_callback, event_type) + logger.debug( + "plugin=<%s>, hook=<%s>, event_type=<%s> | registered hook", + plugin.name, + getattr(hook_callback, "__name__", repr(hook_callback)), + event_type.__name__, + ) diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index b670de297..35633a30e 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -4,19 +4,17 @@ add behavior changes to agents through automatic hook and tool registration. """ -import logging from abc import ABC, abstractmethod from collections.abc import Awaitable from typing import TYPE_CHECKING from ..hooks.registry import HookCallback from ..tools.decorator import DecoratedFunctionTool +from ._discovery import discover_hooks, discover_tools if TYPE_CHECKING: from ..agent import Agent -logger = logging.getLogger(__name__) - class Plugin(ABC): """Base class for objects that extend agent functionality. 
@@ -79,10 +77,14 @@ def __init__(self) -> None: Scans the class for methods decorated with @hook and @tool and stores references for later registration when the plugin is attached to an agent. + + Uses a guard to prevent double-discovery when used with multiple inheritance + (e.g., a class that inherits from both Plugin and MultiAgentPlugin). """ - self._hooks: list[HookCallback] = [] - self._tools: list[DecoratedFunctionTool] = [] - self._discover_decorated_methods() + if not hasattr(self, "_hooks"): + self._hooks: list[HookCallback] = discover_hooks(self, self.name) + if not hasattr(self, "_tools"): + self._tools: list[DecoratedFunctionTool] = discover_tools(self, self.name) @property def hooks(self) -> list[HookCallback]: @@ -94,32 +96,6 @@ def tools(self) -> list[DecoratedFunctionTool]: """List of tools the plugin provides, auto-discovered from @tool decorated methods.""" return self._tools - def _discover_decorated_methods(self) -> None: - """Scan class for @hook and @tool decorated methods in declaration order.""" - seen: set[str] = set() - # Walk MRO so parent class hooks come first, child overrides win - for cls in reversed(type(self).__mro__): - for name in cls.__dict__: - if name in seen: - continue - seen.add(name) - - # Get the bound method from self - try: - bound = getattr(self, name) - except Exception: - continue - - # Check for @hook decorated methods - if hasattr(bound, "_hook_event_types") and callable(bound): - self._hooks.append(bound) - logger.debug("plugin=<%s>, hook=<%s> | discovered hook method", self.name, name) - - # Check for @tool decorated methods (DecoratedFunctionTool instances) - if isinstance(bound, DecoratedFunctionTool): - self._tools.append(bound) - logger.debug("plugin=<%s>, tool=<%s> | discovered tool method", self.name, name) - def init_agent(self, agent: "Agent") -> None | Awaitable[None]: """Initialize the agent instance. 
diff --git a/src/strands/plugins/registry.py b/src/strands/plugins/registry.py index e994b5591..ca5d654c9 100644 --- a/src/strands/plugins/registry.py +++ b/src/strands/plugins/registry.py @@ -4,13 +4,11 @@ plugins that have been initialized with an agent instance. """ -import inspect import logging import weakref -from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING -from .._async import run_async +from ._discovery import call_init_method from .plugin import Plugin if TYPE_CHECKING: @@ -91,13 +89,9 @@ def add_and_init(self, plugin: Plugin) -> None: self._plugins[plugin.name] = plugin # Call user's init_agent for custom initialization - if inspect.iscoroutinefunction(plugin.init_agent): - async_plugin_init = cast(Callable[..., Awaitable[None]], plugin.init_agent) - run_async(lambda: async_plugin_init(self._agent)) - else: - plugin.init_agent(self._agent) + call_init_method(plugin.init_agent, self._agent) - # Auto-register discovered hooks with the agent's hook registry + # Auto-register discovered hooks with the agent self._register_hooks(plugin) # Auto-register discovered tools with the agent's tool registry @@ -106,9 +100,8 @@ def add_and_init(self, plugin: Plugin) -> None: def _register_hooks(self, plugin: Plugin) -> None: """Register all discovered hooks from the plugin with the agent. - Warns if a hook callback is already registered for an event type, - which can happen when init_agent() manually registers a hook that - is also decorated with @hook. + Uses agent.add_hook() rather than the hook registry directly, so that + the agent can track registrations through its public API. Args: plugin: The plugin whose hooks should be registered. 
diff --git a/tests/strands/multiagent/test_multiagent_plugins.py b/tests/strands/multiagent/test_multiagent_plugins.py new file mode 100644 index 000000000..85cc8d817 --- /dev/null +++ b/tests/strands/multiagent/test_multiagent_plugins.py @@ -0,0 +1,283 @@ +"""Tests for MultiAgentPlugin integration with Swarm and Graph.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from strands.hooks import BeforeNodeCallEvent +from strands.hooks.registry import HookProvider +from strands.multiagent import GraphBuilder, Swarm +from strands.multiagent.graph import Graph, GraphNode +from strands.plugins import MultiAgentPlugin, hook + +# --- Fixtures --- + + +@pytest.fixture +def mock_swarm_agent(): + """Create a mock agent suitable for Swarm construction.""" + agent = MagicMock() + agent.name = "agent1" + agent.description = "Test agent" + agent.messages = [] + agent.state = MagicMock() + agent.state.get.return_value = {} + agent._model_state = {} + agent._session_manager = None + agent.tool_registry = MagicMock() + agent.tool_registry.get_all_tools_config.return_value = {} + return agent + + +@pytest.fixture +def mock_graph_agent(): + """Create a mock agent suitable for Graph construction.""" + agent = MagicMock() + agent.name = "agent1" + agent.messages = [] + agent.state = MagicMock() + agent.state.get.return_value = {} + agent._model_state = {} + agent._session_manager = None + return agent + + +def _make_swarm(agent, **kwargs): + """Helper to construct a Swarm with tracer patched out.""" + with patch("strands.multiagent.swarm.get_tracer"): + return Swarm(nodes=[agent], **kwargs) + + +def _make_graph(agent, **kwargs): + """Helper to construct a Graph with tracer patched out.""" + with patch("strands.multiagent.graph.get_tracer"): + node = GraphNode(node_id="agent1", executor=agent) + return Graph(nodes={"agent1": node}, edges=set(), entry_points={node}, **kwargs) + + +# --- Swarm plugin integration tests --- + + +def 
test_swarm_accepts_plugins_parameter(mock_swarm_agent): + """Test that Swarm constructor accepts a plugins parameter.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + swarm = _make_swarm(mock_swarm_agent, plugins=[MyPlugin()]) + assert swarm._plugin_registry is not None + + +def test_swarm_initializes_plugins(mock_swarm_agent): + """Test that Swarm calls init_multi_agent on plugins during construction.""" + init_called = False + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + def init_multi_agent(self, orchestrator): + nonlocal init_called + init_called = True + + _make_swarm(mock_swarm_agent, plugins=[MyPlugin()]) + assert init_called + + +def test_swarm_registers_plugin_hooks(mock_swarm_agent): + """Test that Swarm registers plugin hooks with its hook registry.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + swarm = _make_swarm(mock_swarm_agent, plugins=[MyPlugin()]) + assert len(swarm.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_swarm_plugins_coexist_with_hooks(mock_swarm_agent): + """Test that plugins and legacy hooks parameter work together.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + class MyHookProvider(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeNodeCallEvent, self.on_before_node) + + def on_before_node(self, event): + pass + + swarm = _make_swarm(mock_swarm_agent, plugins=[MyPlugin()], hooks=[MyHookProvider()]) + assert len(swarm.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 2 + + +def test_swarm_duplicate_plugin_raises_error(mock_swarm_agent): + """Test that duplicate plugin names raise an error in Swarm.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + with pytest.raises(ValueError, match="plugin already registered"): + 
_make_swarm(mock_swarm_agent, plugins=[MyPlugin(), MyPlugin()]) + + +def test_swarm_no_plugins_parameter(mock_swarm_agent): + """Test that Swarm works without plugins parameter (backward compat).""" + swarm = _make_swarm(mock_swarm_agent) + assert swarm._plugin_registry is not None + + +# --- Graph plugin integration tests --- + + +def test_graph_builder_accepts_plugins(): + """Test that GraphBuilder has a set_plugins method.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + builder = GraphBuilder() + result = builder.set_plugins([MyPlugin()]) + assert result is builder + + +def test_graph_accepts_plugins_parameter(mock_graph_agent): + """Test that Graph constructor accepts a plugins parameter.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + graph = _make_graph(mock_graph_agent, plugins=[MyPlugin()]) + assert graph._plugin_registry is not None + + +def test_graph_initializes_plugins(mock_graph_agent): + """Test that Graph calls init_multi_agent on plugins during construction.""" + init_called = False + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + def init_multi_agent(self, orchestrator): + nonlocal init_called + init_called = True + + _make_graph(mock_graph_agent, plugins=[MyPlugin()]) + assert init_called + + +def test_graph_registers_plugin_hooks(mock_graph_agent): + """Test that Graph registers plugin hooks with its hook registry.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + graph = _make_graph(mock_graph_agent, plugins=[MyPlugin()]) + assert len(graph.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_graph_plugins_coexist_with_hooks(mock_graph_agent): + """Test that plugins and legacy hooks parameter work together in Graph.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + class 
MyHookProvider(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeNodeCallEvent, self.on_before_node) + + def on_before_node(self, event): + pass + + graph = _make_graph(mock_graph_agent, plugins=[MyPlugin()], hooks=[MyHookProvider()]) + assert len(graph.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 2 + + +def test_graph_builder_passes_plugins_to_graph(mock_graph_agent): + """Test that GraphBuilder.build() passes plugins to the Graph constructor.""" + init_called = False + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + def init_multi_agent(self, orchestrator): + nonlocal init_called + init_called = True + + with patch("strands.multiagent.graph.get_tracer"): + builder = GraphBuilder() + builder.add_node(mock_graph_agent, node_id="agent1") + builder.set_entry_point("agent1") + builder.set_plugins([MyPlugin()]) + graph = builder.build() + + assert init_called + assert graph._plugin_registry is not None + + +# --- add_hook method tests --- + + +def test_swarm_add_hook_registers_callback(mock_swarm_agent): + """Test that Swarm.add_hook registers a callback directly.""" + events_received = [] + + def on_before_node(event: BeforeNodeCallEvent): + events_received.append(event) + + swarm = _make_swarm(mock_swarm_agent) + swarm.add_hook(on_before_node, BeforeNodeCallEvent) + + assert len(swarm.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_graph_add_hook_registers_callback(mock_graph_agent): + """Test that Graph.add_hook registers a callback directly.""" + events_received = [] + + def on_before_node(event: BeforeNodeCallEvent): + events_received.append(event) + + graph = _make_graph(mock_graph_agent) + graph.add_hook(on_before_node, BeforeNodeCallEvent) + + assert len(graph.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_swarm_add_hook_infers_event_type(mock_swarm_agent): + """Test that Swarm.add_hook can infer event type from type hint.""" + + def 
on_before_node(event: BeforeNodeCallEvent): + pass + + swarm = _make_swarm(mock_swarm_agent) + swarm.add_hook(on_before_node) + + assert len(swarm.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_graph_add_hook_infers_event_type(mock_graph_agent): + """Test that Graph.add_hook can infer event type from type hint.""" + + def on_before_node(event: BeforeNodeCallEvent): + pass + + graph = _make_graph(mock_graph_agent) + graph.add_hook(on_before_node) + + assert len(graph.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 diff --git a/tests/strands/plugins/test_multiagent_plugin.py b/tests/strands/plugins/test_multiagent_plugin.py new file mode 100644 index 000000000..b7e16c9eb --- /dev/null +++ b/tests/strands/plugins/test_multiagent_plugin.py @@ -0,0 +1,563 @@ +"""Tests for the MultiAgentPlugin base class and registry.""" + +import gc +import unittest.mock + +import pytest + +from strands.hooks import AfterNodeCallEvent, BeforeNodeCallEvent, HookRegistry +from strands.plugins import Plugin, hook +from strands.plugins.multiagent_plugin import MultiAgentPlugin +from strands.plugins.multiagent_registry import _MultiAgentPluginRegistry +from strands.plugins.registry import _PluginRegistry + +# --- Fixtures --- + + +@pytest.fixture +def mock_orchestrator(): + """Create a mock orchestrator with a working hook registry.""" + orch = unittest.mock.MagicMock() + orch.hooks = HookRegistry() + orch.add_hook = unittest.mock.Mock( + side_effect=lambda callback, event_type=None: orch.hooks.add_callback(event_type, callback) + ) + return orch + + +@pytest.fixture +def registry(mock_orchestrator): + """Create a _MultiAgentPluginRegistry backed by the mock orchestrator.""" + return _MultiAgentPluginRegistry(mock_orchestrator) + + +@pytest.fixture +def mock_agent(): + """Create a mock agent with a working hook registry for dual-plugin tests.""" + agent = unittest.mock.MagicMock() + agent.hooks = HookRegistry() + agent.add_hook = 
unittest.mock.Mock( + side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback) + ) + agent.tool_registry = unittest.mock.MagicMock() + return agent + + +# --- MultiAgentPlugin base class tests --- + + +def test_multiagent_plugin_is_class_not_protocol(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + assert isinstance(MyPlugin(), MultiAgentPlugin) + + +def test_multiagent_plugin_requires_name_attribute(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + assert MyPlugin().name == "my-plugin" + + +def test_multiagent_plugin_name_as_property(): + class MyPlugin(MultiAgentPlugin): + @property + def name(self) -> str: + return "property-plugin" + + assert MyPlugin().name == "property-plugin" + + +def test_multiagent_plugin_requires_name(): + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + + class PluginWithoutName(MultiAgentPlugin): + def init_multi_agent(self, orchestrator): + pass + + PluginWithoutName() + + +def test_multiagent_plugin_provides_default_init_multi_agent(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + assert MyPlugin().init_multi_agent(unittest.mock.MagicMock()) is None + + +# --- Auto-discovery tests --- + + +def test_discovers_hook_decorated_methods(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "on_before_node" + + +def test_discovers_multiple_hooks(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def hook1(self, event: BeforeNodeCallEvent): + pass + + @hook + def hook2(self, event: AfterNodeCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 2 + assert {h.__name__ for h in plugin.hooks} == {"hook1", "hook2"} + + +def test_hooks_preserve_definition_order(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def 
z_last(self, event: BeforeNodeCallEvent): + pass + + @hook + def a_first(self, event: BeforeNodeCallEvent): + pass + + plugin = MyPlugin() + assert [h.__name__ for h in plugin.hooks] == ["z_last", "a_first"] + + +def test_ignores_non_decorated_methods(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + def regular_method(self): + pass + + @hook + def decorated_hook(self, event: BeforeNodeCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "decorated_hook" + + +def test_no_tool_support(): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + assert not hasattr(MyPlugin(), "tools") + + +# --- Registry tests --- + + +def test_registry_add_and_init_calls_init_multi_agent(registry): + class TestPlugin(MultiAgentPlugin): + name = "test-plugin" + + def __init__(self): + super().__init__() + self.initialized = False + + def init_multi_agent(self, orchestrator): + self.initialized = True + + plugin = TestPlugin() + registry.add_and_init(plugin) + assert plugin.initialized + + +def test_registry_add_duplicate_raises_error(registry): + class TestPlugin(MultiAgentPlugin): + name = "test-plugin" + + registry.add_and_init(TestPlugin()) + with pytest.raises(ValueError, match="plugin_name= | plugin already registered"): + registry.add_and_init(TestPlugin()) + + +def test_registry_registers_discovered_hooks(mock_orchestrator, registry): + class TestPlugin(MultiAgentPlugin): + name = "test-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + registry.add_and_init(TestPlugin()) + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_registry_registers_multiple_hooks(mock_orchestrator, registry): + class TestPlugin(MultiAgentPlugin): + name = "test-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + @hook + def on_after_node(self, event: AfterNodeCallEvent): + pass + + 
registry.add_and_init(TestPlugin()) + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + assert len(mock_orchestrator.hooks._registered_callbacks.get(AfterNodeCallEvent, [])) == 1 + + +def test_registry_async_init_multi_agent_supported(registry): + async_init_called = False + + class AsyncPlugin(MultiAgentPlugin): + name = "async-plugin" + + async def init_multi_agent(self, orchestrator): + nonlocal async_init_called + async_init_called = True + + registry.add_and_init(AsyncPlugin()) + assert async_init_called + + +def test_registry_hooks_are_bound_to_instance(mock_orchestrator, registry): + class TestPlugin(MultiAgentPlugin): + name = "test-plugin" + + def __init__(self): + super().__init__() + self.events_received = [] + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + self.events_received.append(event) + + plugin = TestPlugin() + registry.add_and_init(plugin) + + mock_event = unittest.mock.MagicMock(spec=BeforeNodeCallEvent) + mock_orchestrator.hooks._registered_callbacks[BeforeNodeCallEvent][0](mock_event) + + assert plugin.events_received == [mock_event] + + +def test_registry_raises_reference_error_after_orchestrator_collected(): + orch = unittest.mock.MagicMock() + orch.hooks = HookRegistry() + reg = _MultiAgentPluginRegistry(orch) + del orch + gc.collect() + + with pytest.raises(ReferenceError, match="Orchestrator has been garbage collected"): + _ = reg._orchestrator + + +def test_registry_init_multi_agent_called_before_hook_registration(mock_orchestrator): + call_order = [] + + class TestPlugin(MultiAgentPlugin): + name = "test-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + def init_multi_agent(self, orchestrator): + call_order.append("init") + + original = mock_orchestrator.hooks.add_callback + + def tracking(event_type, callback): + call_order.append("hook") + return original(event_type, callback) + + mock_orchestrator.hooks.add_callback = tracking + + 
registry = _MultiAgentPluginRegistry(mock_orchestrator) + registry.add_and_init(TestPlugin()) + + assert call_order == ["init", "hook"] + + +# --- Union type tests --- + + +def test_registers_hook_for_union_types(mock_orchestrator, registry): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_node_events(self, event: BeforeNodeCallEvent | AfterNodeCallEvent): + pass + + registry.add_and_init(MyPlugin()) + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + assert len(mock_orchestrator.hooks._registered_callbacks.get(AfterNodeCallEvent, [])) == 1 + + +# --- Subclass override tests --- + + +def test_subclass_can_override_init_multi_agent(mock_orchestrator, registry): + custom_init_called = False + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + def init_multi_agent(self, orchestrator): + nonlocal custom_init_called + custom_init_called = True + + registry.add_and_init(MyPlugin()) + assert custom_init_called + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_subclass_can_add_manual_hooks_in_init(mock_orchestrator, registry): + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + @hook + def auto_hook(self, event: BeforeNodeCallEvent): + pass + + def manual_hook(self, event: AfterNodeCallEvent): + pass + + def init_multi_agent(self, orchestrator): + orchestrator.hooks.add_callback(AfterNodeCallEvent, self.manual_hook) + + registry.add_and_init(MyPlugin()) + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + assert len(mock_orchestrator.hooks._registered_callbacks.get(AfterNodeCallEvent, [])) == 1 + + +# --- Inheritance tests --- + + +def test_child_inherits_parent_hooks(): + class ParentPlugin(MultiAgentPlugin): + name = "parent-plugin" + + @hook + def parent_hook(self, event: BeforeNodeCallEvent): + pass + + class 
ChildPlugin(ParentPlugin): + name = "child-plugin" + + @hook + def child_hook(self, event: AfterNodeCallEvent): + pass + + plugin = ChildPlugin() + assert len(plugin.hooks) == 2 + assert {h.__name__ for h in plugin.hooks} == {"parent_hook", "child_hook"} + + +def test_child_can_override_parent_hook(): + class ParentPlugin(MultiAgentPlugin): + name = "parent-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + class ChildPlugin(ParentPlugin): + name = "child-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + assert len(ChildPlugin().hooks) == 1 + + +# --- Dual plugin tests --- + + +def test_dual_plugin_isinstance_checks(): + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + plugin = DualPlugin() + assert isinstance(plugin, Plugin) + assert isinstance(plugin, MultiAgentPlugin) + + +def test_dual_plugin_discovers_hooks_once(): + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + assert len(DualPlugin().hooks) == 1 + + +def test_dual_plugin_discover_hooks_called_once(monkeypatch): + """Verify the hasattr guard prevents discover_hooks from running twice in dual inheritance.""" + import strands.plugins.plugin as plugin_mod + + call_count = 0 + original = plugin_mod.discover_hooks + + def counting_discover_hooks(instance, plugin_name): + nonlocal call_count + call_count += 1 + return original(instance, plugin_name) + + monkeypatch.setattr(plugin_mod, "discover_hooks", counting_discover_hooks) + + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + pass + + DualPlugin() + # Plugin.__init__ calls discover_hooks once; MultiAgentPlugin.__init__ skips due to hasattr guard + assert call_count == 1 + + +def test_dual_plugin_has_both_init_methods(mock_agent, mock_orchestrator): + agent_init_called = False + 
multi_agent_init_called = False + + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + def init_agent(self, agent): + nonlocal agent_init_called + agent_init_called = True + + def init_multi_agent(self, orchestrator): + nonlocal multi_agent_init_called + multi_agent_init_called = True + + _PluginRegistry(mock_agent).add_and_init(DualPlugin()) + assert agent_init_called + + _MultiAgentPluginRegistry(mock_orchestrator).add_and_init(DualPlugin()) + assert multi_agent_init_called + + +def test_dual_plugin_registers_hooks_in_both_contexts(mock_agent, mock_orchestrator): + from strands.hooks import BeforeModelCallEvent + + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + @hook + def on_model_call(self, event: BeforeModelCallEvent): + pass + + @hook + def on_node_call(self, event: BeforeNodeCallEvent): + pass + + _PluginRegistry(mock_agent).add_and_init(DualPlugin()) + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_agent.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + _MultiAgentPluginRegistry(mock_orchestrator).add_and_init(DualPlugin()) + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_orchestrator.hooks._registered_callbacks.get(BeforeNodeCallEvent, [])) == 1 + + +def test_dual_plugin_shared_state(mock_agent, mock_orchestrator): + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + def __init__(self): + super().__init__() + self.call_count = 0 + + @hook + def on_before_node(self, event: BeforeNodeCallEvent): + self.call_count += 1 + + def init_agent(self, agent): + self.call_count += 10 + + def init_multi_agent(self, orchestrator): + self.call_count += 100 + + plugin = DualPlugin() + _PluginRegistry(mock_agent).add_and_init(plugin) + assert plugin.call_count == 10 + + _MultiAgentPluginRegistry(mock_orchestrator).add_and_init(plugin) + assert plugin.call_count == 110 + + 
+def test_dual_plugin_tools_only_for_agent(mock_agent, mock_orchestrator): + from strands.tools.decorator import tool + + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + _PluginRegistry(mock_agent).add_and_init(DualPlugin()) + mock_agent.tool_registry.process_tools.assert_called_once() + + # Orchestrator has no tool registration + _MultiAgentPluginRegistry(mock_orchestrator).add_and_init(DualPlugin()) + + +# --- Double-discovery guard tests --- + + +def test_dual_plugin_hasattr_guard_prevents_double_discovery(): + """Test that the hasattr guard in __init__ prevents hooks from being discovered twice.""" + + class DualPlugin(Plugin, MultiAgentPlugin): + name = "dual-plugin" + + @hook + def shared_hook(self, event: BeforeNodeCallEvent): + pass + + plugin = DualPlugin() + # If double-discovery occurred, we'd see 2 hooks instead of 1 + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "shared_hook" + + +def test_multiagent_plugin_hasattr_guard_with_pre_set_hooks(): + """Test that MultiAgentPlugin.__init__ skips discovery if _hooks already set.""" + + class MyPlugin(MultiAgentPlugin): + name = "my-plugin" + + def __init__(self): + # Pre-set _hooks before super().__init__ + self._hooks = [] + super().__init__() + + @hook + def should_not_be_discovered(self, event: BeforeNodeCallEvent): + pass + + plugin = MyPlugin() + # The guard should have skipped discovery since _hooks was already set + assert len(plugin.hooks) == 0 From 1232230daa4385fd6470be29013eb2375d7a307c Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 15 May 2026 10:46:33 -0400 Subject: [PATCH 276/279] feat: bump starlette dependency to 1.x (#2297) --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8a017cd07..cdc09fe45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,8 +68,8 @@ a2a = [ 
"a2a-sdk[sql]>=0.3.0,<0.4.0", "uvicorn>=0.34.2,<1.0.0", "httpx>=0.28.1,<1.0.0", - "fastapi>=0.115.12,<1.0.0", - "starlette>=0.46.2,<1.0.0", + "fastapi>=0.133.0,<1.0.0", + "starlette>=1.0.0,<2.0.0", ] bidi = [ From 46ce50b3027ab34dc969a08cbba3ea57f0accdcf Mon Sep 17 00:00:00 2001 From: Kien Pham <22681+kpx-dev@users.noreply.github.com> Date: Tue, 19 May 2026 10:54:04 -0700 Subject: [PATCH 277/279] feat(bedrock): add TTL support to auto-injected tool and system/user cache points (#2232) --- src/strands/models/__init__.py | 3 +- src/strands/models/bedrock.py | 38 +++++-- src/strands/models/model.py | 16 +++ tests/strands/models/test_bedrock.py | 68 +++++++++++- tests_integ/models/test_model_bedrock.py | 135 ++++++++++++++++++++++- 5 files changed, 248 insertions(+), 12 deletions(-) diff --git a/src/strands/models/__init__.py b/src/strands/models/__init__.py index 3a23e257a..8ae660da0 100644 --- a/src/strands/models/__init__.py +++ b/src/strands/models/__init__.py @@ -7,7 +7,7 @@ from . import bedrock, model from .bedrock import BedrockModel -from .model import BaseModelConfig, CacheConfig, Model +from .model import BaseModelConfig, CacheConfig, CacheToolsConfig, Model __all__ = [ "bedrock", @@ -15,6 +15,7 @@ "BaseModelConfig", "BedrockModel", "CacheConfig", + "CacheToolsConfig", "Model", ] diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index ab9adb67a..4cd6f7fbc 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -34,7 +34,7 @@ from ._defaults import resolve_config_metadata from ._strict_schema import ensure_strict_json_schema from ._validation import validate_config_keys -from .model import BaseModelConfig, CacheConfig, Model +from .model import BaseModelConfig, CacheConfig, CacheToolsConfig, Model logger = logging.getLogger(__name__) @@ -90,7 +90,8 @@ class BedrockConfig(BaseModelConfig, total=False): additional_response_field_paths: Additional response field paths to extract cache_prompt: Cache point type 
for the system prompt (deprecated, use cache_config) cache_config: Configuration for prompt caching. Use CacheConfig(strategy="auto") for automatic caching. - cache_tools: Cache point type for tools + cache_tools: Cache point type for tools. Pass a string (e.g. "default") for the default 5m TTL, + or a CacheToolsConfig instance to set both type and TTL (e.g. "1h"). guardrail_id: ID of the guardrail to apply guardrail_trace: Guardrail trace mode. Defaults to enabled. guardrail_version: Version of the guardrail to apply @@ -127,7 +128,7 @@ class BedrockConfig(BaseModelConfig, total=False): additional_response_field_paths: list[str] | None cache_prompt: str | None cache_config: CacheConfig | None - cache_tools: str | None + cache_tools: str | CacheToolsConfig | None guardrail_id: str | None guardrail_trace: Literal["enabled", "disabled", "enabled_full"] | None guardrail_stream_processing_mode: Literal["sync", "async"] | None @@ -292,11 +293,7 @@ def _format_request( } for tool_spec in tool_specs ], - *( - [{"cachePoint": {"type": self.config["cache_tools"]}}] - if self.config.get("cache_tools") - else [] - ), + *self._build_tools_cache_point(), ], **({"toolChoice": tool_choice if tool_choice else {"auto": {}}}), } @@ -371,6 +368,25 @@ def _get_additional_request_fields(self, tool_choice: ToolChoice | None) -> dict return {"additionalModelRequestFields": additional_fields} + def _build_tools_cache_point(self) -> list[dict[str, Any]]: + """Build the cache point block appended to ``toolConfig.tools`` if ``cache_tools`` is configured. + + Returns: + A single-element list containing the cache point block, or an empty list if no cache_tools is set. 
+ """ + cache_tools = self.config.get("cache_tools") + if not cache_tools: + return [] + + if isinstance(cache_tools, CacheToolsConfig): + cache_point: dict[str, Any] = {"type": cache_tools.type} + if cache_tools.ttl: + cache_point["ttl"] = cache_tools.ttl + else: + cache_point = {"type": cache_tools} + + return [{"cachePoint": cache_point}] + def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None: """Inject a cache point at the end of the last user message. @@ -395,7 +411,11 @@ def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None: last_user_idx = msg_idx if last_user_idx is not None and messages[last_user_idx].get("content"): - messages[last_user_idx]["content"].append({"cachePoint": {"type": "default"}}) + cache_point: dict[str, Any] = {"type": "default"} + cache_config = self.config.get("cache_config") + if cache_config and cache_config.ttl: + cache_point["ttl"] = cache_config.ttl + messages[last_user_idx]["content"].append({"cachePoint": cache_point}) logger.debug("msg_idx=<%s> | added cache point to last user message", last_user_idx) def _find_last_user_text_message_index(self, messages: Messages) -> int | None: diff --git a/src/strands/models/model.py b/src/strands/models/model.py index dd2f9eed2..77ef1df40 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -134,9 +134,25 @@ class CacheConfig: strategy: Caching strategy to use. - "auto": Automatically detect model support and inject cachePoint to maximize cache coverage - "anthropic": Inject cachePoint in Anthropic-compatible format without model support check + ttl: Optional TTL duration for cache entries (e.g. "5m", "1h"). + When specified, auto-injected cache points will include this TTL value. """ strategy: Literal["auto", "anthropic"] = "auto" + ttl: str | None = None + + +@dataclass +class CacheToolsConfig: + """Configuration for the toolConfig cache point. + + Attributes: + type: Cache point type (e.g. "default"). 
+ ttl: Optional TTL duration for the cache entry (e.g. "5m", "1h"). + """ + + type: str = "default" + ttl: str | None = None class Model(abc.ABC): diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 2e105d64a..319b5574f 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -14,7 +14,7 @@ import strands from strands import _exception_notes -from strands.models import BedrockModel, CacheConfig +from strands.models import BedrockModel, CacheConfig, CacheToolsConfig from strands.models.bedrock import ( DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION, @@ -3554,3 +3554,69 @@ async def test_skip_native_api_by_default(self, bedrock_client, model_id, messag bedrock_client.count_tokens.assert_not_called() assert isinstance(result, int) assert result >= 0 + + +def test_inject_cache_point_with_ttl(bedrock_client): + """Test that _inject_cache_point includes TTL when cache_config has ttl set.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", + cache_config=CacheConfig(strategy="auto", ttl="5m"), + ) + + cleaned_messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + model._inject_cache_point(cleaned_messages) + + cache_point = cleaned_messages[0]["content"][-1]["cachePoint"] + assert cache_point["type"] == "default" + assert cache_point["ttl"] == "5m" + + +def test_inject_cache_point_without_ttl(bedrock_client): + """Test that _inject_cache_point omits TTL when cache_config has no ttl.""" + model = BedrockModel( + model_id="us.anthropic.claude-sonnet-4-20250514-v1:0", + cache_config=CacheConfig(strategy="auto"), + ) + + cleaned_messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + model._inject_cache_point(cleaned_messages) + + cache_point = cleaned_messages[0]["content"][-1]["cachePoint"] + assert cache_point["type"] == "default" + assert "ttl" not in cache_point + + +def 
test_format_request_cache_tools_config_with_ttl(model, messages, model_id, tool_spec, cache_type): + """Test that CacheToolsConfig propagates type and ttl into toolConfig cachePoint.""" + model.update_config(cache_tools=CacheToolsConfig(type=cache_type, ttl="5m")) + + tru_request = model._format_request(messages, tool_specs=[tool_spec]) + + exp_cache_point = {"cachePoint": {"type": cache_type, "ttl": "5m"}} + assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point + + +def test_format_request_cache_tools_config_without_ttl(model, messages, model_id, tool_spec, cache_type): + """Test that CacheToolsConfig without ttl produces a cachePoint with only type.""" + model.update_config(cache_tools=CacheToolsConfig(type=cache_type)) + + tru_request = model._format_request(messages, tool_specs=[tool_spec]) + + exp_cache_point = {"cachePoint": {"type": cache_type}} + assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point + + +def test_format_request_cache_tools_string_backward_compat(model, messages, model_id, tool_spec, cache_type): + """Test that passing cache_tools as a string still produces a cachePoint with only type.""" + model.update_config(cache_tools=cache_type) + + tru_request = model._format_request(messages, tool_specs=[tool_spec]) + + exp_cache_point = {"cachePoint": {"type": cache_type}} + assert tru_request["toolConfig"]["tools"][-1] == exp_cache_point diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index 06c72ef88..509a300f3 100644 --- a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -6,9 +6,18 @@ import strands from strands import Agent -from strands.models import BedrockModel +from strands.models import BedrockModel, CacheConfig, CacheToolsConfig from strands.types.content import ContentBlock +# Model ID used for prompt-caching TTL integration tests. 
Per +# https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html +# the models that officially support 1h TTL on CachePoint are Claude Opus 4.5, +# Claude Haiku 4.5, and Claude Sonnet 4.5. Haiku 4.5 is the newest Haiku +# available and is preferred for CI due to lower latency and cost relative to +# the same-version Sonnet 4.5. Bump this when a newer Haiku is released that +# supports CachePoint TTL. +_CACHE_TTL_MODEL_ID = "us.anthropic.claude-haiku-4-5-20251001-v1:0" + @pytest.fixture def system_prompt(): @@ -561,3 +570,127 @@ def calculator(expression: str) -> float: agent('Search for "python" with tags ["programming", "language"] using the search tool.') assert "search" in tools_called + + +def test_prompt_caching_cache_tools_ttl(): + """Test that CacheToolsConfig(ttl=...) propagates into the auto-injected toolConfig cache point. + + Verifies that BedrockModel(cache_tools=CacheToolsConfig(type="default", ttl="5m")) produces a + Bedrock request with cachePoint.ttl on the toolConfig checkpoint, and that the call + completes without a ValidationException on the TTL field. + + Note: we intentionally do not assert specific cacheWriteInputTokens on the toolConfig + prefix because Bedrock's tool-prefix cache threshold varies by model and region. + The critical behavior under test here is that the TTL field is accepted end-to-end. + + Uses Claude Haiku 4.5 which supports TTL in CachePointBlock on Bedrock per + https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + (Claude Opus 4.5, Claude Haiku 4.5, and Claude Sonnet 4.5 all support 1h TTL). + """ + model = BedrockModel( + model_id=_CACHE_TTL_MODEL_ID, + streaming=False, + cache_tools=CacheToolsConfig(type="default", ttl="5m"), + ) + + @strands.tool + def lookup_fact(topic: str) -> str: + """Look up a fact about the given topic. + + This tool is useful when you need authoritative information. 
+ """ + return f"Fact about {topic}: example" + + agent = Agent( + model=model, + tools=[lookup_fact], + load_tools_from_directory=False, + ) + + # The call must succeed — Bedrock must accept cachePoint.ttl on the toolConfig checkpoint + # without raising a ValidationException. + result = agent("Use the lookup_fact tool to look up 'python'.") + assert len(str(result)) > 0 + + +def test_prompt_caching_cache_config_auto_with_ttl(): + """Test that CacheConfig(strategy="auto", ttl="5m") propagates TTL to the auto-injected message cache point. + + Verifies that the cache point appended to the last user message by _inject_cache_point + carries the configured TTL, and that Bedrock accepts the request. + + Uses Claude Haiku 4.5 which supports TTL in CachePointBlock on Bedrock per + https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + """ + model = BedrockModel( + model_id=_CACHE_TTL_MODEL_ID, + streaming=False, + cache_config=CacheConfig(strategy="auto", ttl="5m"), + ) + + unique_id = str(uuid.uuid4()) + # Minimum 4096 tokens required for caching with Haiku 4.5 + large_message = f"Context for test {unique_id}: " + ("This is important context. " * 1000) + " What is 2+2?" + + agent = Agent( + model=model, + load_tools_from_directory=False, + ) + + # First call: auto-injected cache point on the last user message must include ttl and be accepted + result1 = agent(large_message) + assert len(str(result1)) > 0 + + # Verify cache write occurred with auto-inject + ttl + assert result1.metrics.accumulated_usage.get("cacheWriteInputTokens", 0) > 0, ( + "Expected cacheWriteInputTokens > 0 with CacheConfig(strategy='auto', ttl='5m')" + ) + + +def test_prompt_caching_aligned_1h_ttl_across_checkpoints(): + """Regression test for Bedrock TTL non-increasing ordering rule (Issue #2121). + + Bedrock processes cache checkpoints in order: toolConfig -> system -> messages, + and requires TTLs to be non-increasing. 
Before this change, cache_tools hardcoded + an implicit 5m TTL, so any 1h TTL on a later checkpoint would raise a + ValidationException. + + This test sets 1h TTL on all three checkpoints simultaneously and verifies the + call succeeds. + + Uses Claude Haiku 4.5 which supports 1h TTL per + https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + """ + model = BedrockModel( + model_id=_CACHE_TTL_MODEL_ID, + streaming=False, + cache_tools=CacheToolsConfig(type="default", ttl="1h"), + cache_config=CacheConfig(strategy="auto", ttl="1h"), + ) + + # Timestamp-based uniqueness to avoid cache conflicts across CI runs + unique_id = str(int(time.time() * 1000000)) + large_context = f"Background context for test {unique_id}: " + ("This is important context. " * 1000) + + # User-supplied 1h cache point on system prompt — third checkpoint also at 1h + system_prompt_with_cache = [ + {"text": large_context}, + {"cachePoint": {"type": "default", "ttl": "1h"}}, + {"text": "You are a helpful assistant."}, + ] + + @strands.tool + def echo(value: str) -> str: + """Echo the given value back.""" + return value + + agent = Agent( + model=model, + system_prompt=system_prompt_with_cache, + tools=[echo], + load_tools_from_directory=False, + ) + + # Must succeed without ValidationException on the non-increasing TTL rule + result = agent("What is 2+2?") + assert len(str(result)) > 0 From 64a6862b8c3053368544a78d4c1c6c31d560d258 Mon Sep 17 00:00:00 2001 From: Liz <91279165+lizradway@users.noreply.github.com> Date: Thu, 21 May 2026 11:18:42 -0400 Subject: [PATCH 278/279] fix(tests): add use_native_token_count=True when expected (#2311) --- tests_integ/models/test_model_bedrock.py | 2 +- tests_integ/models/test_model_gemini.py | 1 + tests_integ/models/test_model_openai.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests_integ/models/test_model_bedrock.py b/tests_integ/models/test_model_bedrock.py index 509a300f3..c1c4adf6f 100644 --- 
a/tests_integ/models/test_model_bedrock.py +++ b/tests_integ/models/test_model_bedrock.py @@ -516,7 +516,7 @@ def test_prompt_caching_backward_compatibility_no_ttl(): class TestCountTokens: @pytest.fixture def model(self): - return BedrockModel(model_id="anthropic.claude-sonnet-4-20250514-v1:0") + return BedrockModel(model_id="anthropic.claude-sonnet-4-20250514-v1:0", use_native_token_count=True) @pytest.fixture def messages(self): diff --git a/tests_integ/models/test_model_gemini.py b/tests_integ/models/test_model_gemini.py index ac1943382..1057757da 100644 --- a/tests_integ/models/test_model_gemini.py +++ b/tests_integ/models/test_model_gemini.py @@ -227,6 +227,7 @@ def model(self): return GeminiModel( model_id="gemini-2.0-flash", client_args={"api_key": os.environ["GOOGLE_API_KEY"]}, + use_native_token_count=True, ) @pytest.fixture diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py index bef526427..6011f2b71 100644 --- a/tests_integ/models/test_model_openai.py +++ b/tests_integ/models/test_model_openai.py @@ -408,6 +408,7 @@ def model(self): return OpenAIResponsesModel( model_id="gpt-4o", client_args={"api_key": os.environ["OPENAI_API_KEY"]}, + use_native_token_count=True, ) @pytest.fixture From f6c3b571eda8e5ae2eeb3c997db5d1f7bc2ed986 Mon Sep 17 00:00:00 2001 From: Liz <91279165+lizradway@users.noreply.github.com> Date: Fri, 22 May 2026 15:22:51 -0400 Subject: [PATCH 279/279] fix(tests): fix flaky tests to accept string or number (#2319) --- tests_integ/models/test_model_mantle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests_integ/models/test_model_mantle.py b/tests_integ/models/test_model_mantle.py index 7cc032146..9a432d993 100644 --- a/tests_integ/models/test_model_mantle.py +++ b/tests_integ/models/test_model_mantle.py @@ -40,13 +40,13 @@ def test_chat_completions_agent_invoke(chat_completions_model): """OpenAIModel (Chat Completions) reaches Mantle via bedrock_mantle_config.""" 
agent = Agent(model=chat_completions_model, system_prompt="Reply in one short sentence.", callback_handler=None) result = agent("What is 2+2?") - assert "4" in str(result) + assert "4" in str(result) or "four" in str(result).lower() def test_agent_invoke(model): agent = Agent(model=model, system_prompt="Reply in one short sentence.", callback_handler=None) result = agent("What is 2+2?") - assert "4" in str(result) + assert "4" in str(result) or "four" in str(result).lower() def test_responses_server_side_conversation(stateful_model):