|
13 | 13 | # under the License. |
14 | 14 |
|
15 | 15 | import os |
| 16 | +from dataclasses import replace |
16 | 17 | from typing import Any, override |
17 | 18 | from unittest.mock import patch |
18 | 19 |
|
@@ -686,6 +687,117 @@ async def test_middleware( |
686 | 687 | ] |
687 | 688 | ) |
688 | 689 |
|
| 690 | + @pytest.mark.asyncio |
| 691 | + async def test_model_middleware_message_mutation_reaches_llm(self) -> None: |
| 692 | + pytest.importorskip("langchain_openai") |
| 693 | + |
| 694 | + # Regression test for DVPL-13038: message mutations in model middleware must reach the LLM. |
| 695 | + |
| 696 | + @model_middleware |
| 697 | + async def mutating_middleware( |
| 698 | + request: ModelRequest, handler: ModelMiddlewareHandler |
| 699 | + ) -> ModelResponse: |
| 700 | + new_state = replace( |
| 701 | + request.state, |
| 702 | + response=replace( |
| 703 | + request.state.response, |
| 704 | + messages=[HumanMessage(content="What is the capital of France?")], |
| 705 | + ), |
| 706 | + ) |
| 707 | + return await handler(replace(request, state=new_state)) |
| 708 | + |
| 709 | + async with Agent( |
| 710 | + model=await self.model(), |
| 711 | + system_prompt="You are a geography assistant. Answer concisely.", |
| 712 | + service=self.service, |
| 713 | + middleware=[mutating_middleware], |
| 714 | + ) as agent: |
| 715 | + res = await agent.invoke( |
| 716 | + [HumanMessage(content="What is the capital of Germany?")] |
| 717 | + ) |
| 718 | + assert "Paris" in res.final_message.content |
| 719 | + |
| 720 | + @patch( |
| 721 | + "splunklib.ai.agent._testing_local_tools_path", |
| 722 | + os.path.join(os.path.dirname(__file__), "testdata", "weather.py"), |
| 723 | + ) |
| 724 | + @patch("splunklib.ai.agent._testing_app_id", "app_id") |
| 725 | + @pytest.mark.asyncio |
| 726 | + async def test_tool_middleware_arg_mutation_reaches_tool(self) -> None: |
| 727 | + pytest.importorskip("langchain_openai") |
| 728 | + |
| 729 | + # Tool call arg mutations in tool_middleware must reach the actual tool execution. |
| 730 | + |
| 731 | + @tool_middleware |
| 732 | + async def mutating_middleware( |
| 733 | + request: ToolRequest, handler: ToolMiddlewareHandler |
| 734 | + ) -> ToolResponse: |
| 735 | + mutated = replace( |
| 736 | + request, |
| 737 | + call=replace(request.call, args={"city": "Krakow"}), |
| 738 | + ) |
| 739 | + return await handler(mutated) |
| 740 | + |
| 741 | + async with Agent( |
| 742 | + model=await self.model(), |
| 743 | + system_prompt=( |
| 744 | + "You are a helpful assistant. " |
| 745 | + "You MUST use available tools when asked about weather." |
| 746 | + ), |
| 747 | + service=self.service, |
| 748 | + middleware=[mutating_middleware], |
| 749 | + tool_settings=ToolSettings(local=True, remote=None), |
| 750 | + ) as agent: |
| 751 | + res = await agent.invoke( |
| 752 | + [HumanMessage(content="What is the weather like today in Berlin?")] |
| 753 | + ) |
| 754 | + # Berlin returns 22.1C; Krakow returns 31.5C |
| 755 | + assert "31.5" in res.final_message.content |
| 756 | + |
| 757 | + @pytest.mark.asyncio |
| 758 | + async def test_subagent_middleware_arg_mutation_reaches_subagent(self) -> None: |
| 759 | + pytest.importorskip("langchain_openai") |
| 760 | + |
| 761 | + # Subagent call arg mutations in subagent_middleware must reach the actual subagent. |
| 762 | + |
| 763 | + class NicknameGeneratorInput(BaseModel): |
| 764 | + name: str = Field(description="The person's full name", min_length=1) |
| 765 | + |
| 766 | + @subagent_middleware |
| 767 | + async def mutating_middleware( |
| 768 | + request: SubagentRequest, handler: SubagentMiddlewareHandler |
| 769 | + ) -> SubagentResponse: |
| 770 | + mutated = replace( |
| 771 | + request, |
| 772 | + call=replace(request.call, args={"name": "Alice"}), |
| 773 | + ) |
| 774 | + return await handler(mutated) |
| 775 | + |
| 776 | + async with ( |
| 777 | + Agent( |
| 778 | + model=await self.model(), |
| 779 | + system_prompt=( |
| 780 | + "You are a helpful assistant that generates nicknames. A valid " |
| 781 | + "nickname consists of the provided name suffixed with '-zilla.'" |
| 782 | + ), |
| 783 | + service=self.service, |
| 784 | + name="NicknameGeneratorAgent", |
| 785 | + description="Pass a name and get a nickname", |
| 786 | + input_schema=NicknameGeneratorInput, |
| 787 | + ) as subagent, |
| 788 | + Agent( |
| 789 | + model=await self.model(), |
| 790 | + system_prompt="You are a supervisor agent that MUST use other agents", |
| 791 | + agents=[subagent], |
| 792 | + service=self.service, |
| 793 | + middleware=[mutating_middleware], |
| 794 | + ) as supervisor, |
| 795 | + ): |
| 796 | + result = await supervisor.invoke( |
| 797 | + [HumanMessage(content="Generate a nickname for Bob")] |
| 798 | + ) |
| 799 | + assert "Alice-zilla" in result.final_message.content |
| 800 | + |
689 | 801 | @pytest.mark.asyncio |
690 | 802 | async def test_model_middleware_structured_output(self) -> None: |
691 | 803 | pytest.importorskip("langchain_openai") |
|
0 commit comments