Skip to content

Commit 1bae6e2

Browse files
authored
Fix middleware state changes (#122)
1 parent 3026733 commit 1bae6e2

2 files changed

Lines changed: 119 additions & 1 deletion

File tree

splunklib/ai/engines/langchain.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -837,9 +837,15 @@ def _convert_model_request_to_lc(
837837
request: ModelRequest,
838838
original_request: LC_ModelRequest,
839839
) -> LC_ModelRequest:
840+
state = _convert_agent_state_to_lc(request.state)
841+
# LC_ModelRequest has `messages` and `state` as independent fields.
842+
# LangChain uses `messages` (not state["messages"]) when calling the LLM,
843+
# so we must override both to ensure middleware mutations (e.g. PII
844+
# redaction) actually reach the model.
840845
return original_request.override(
841846
system_message=LC_SystemMessage(content=request.system_message),
842-
state=_convert_agent_state_to_lc(request.state),
847+
messages=state["messages"],
848+
state=state,
843849
)
844850

845851

tests/integration/ai/test_middleware.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# under the License.
1414

1515
import os
16+
from dataclasses import replace
1617
from typing import Any, override
1718
from unittest.mock import patch
1819

@@ -686,6 +687,117 @@ async def test_middleware(
686687
]
687688
)
688689

690+
@pytest.mark.asyncio
691+
async def test_model_middleware_message_mutation_reaches_llm(self) -> None:
692+
pytest.importorskip("langchain_openai")
693+
694+
# Regression test for DVPL-13038: message mutations in model middleware must reach the LLM.
695+
696+
@model_middleware
697+
async def mutating_middleware(
698+
request: ModelRequest, handler: ModelMiddlewareHandler
699+
) -> ModelResponse:
700+
new_state = replace(
701+
request.state,
702+
response=replace(
703+
request.state.response,
704+
messages=[HumanMessage(content="What is the capital of France?")],
705+
),
706+
)
707+
return await handler(replace(request, state=new_state))
708+
709+
async with Agent(
710+
model=await self.model(),
711+
system_prompt="You are a geography assistant. Answer concisely.",
712+
service=self.service,
713+
middleware=[mutating_middleware],
714+
) as agent:
715+
res = await agent.invoke(
716+
[HumanMessage(content="What is the capital of Germany?")]
717+
)
718+
assert "Paris" in res.final_message.content
719+
720+
@patch(
721+
"splunklib.ai.agent._testing_local_tools_path",
722+
os.path.join(os.path.dirname(__file__), "testdata", "weather.py"),
723+
)
724+
@patch("splunklib.ai.agent._testing_app_id", "app_id")
725+
@pytest.mark.asyncio
726+
async def test_tool_middleware_arg_mutation_reaches_tool(self) -> None:
727+
pytest.importorskip("langchain_openai")
728+
729+
# Tool call arg mutations in tool_middleware must reach the actual tool execution.
730+
731+
@tool_middleware
732+
async def mutating_middleware(
733+
request: ToolRequest, handler: ToolMiddlewareHandler
734+
) -> ToolResponse:
735+
mutated = replace(
736+
request,
737+
call=replace(request.call, args={"city": "Krakow"}),
738+
)
739+
return await handler(mutated)
740+
741+
async with Agent(
742+
model=await self.model(),
743+
system_prompt=(
744+
"You are a helpful assistant. "
745+
"You MUST use available tools when asked about weather."
746+
),
747+
service=self.service,
748+
middleware=[mutating_middleware],
749+
tool_settings=ToolSettings(local=True, remote=None),
750+
) as agent:
751+
res = await agent.invoke(
752+
[HumanMessage(content="What is the weather like today in Berlin?")]
753+
)
754+
# Berlin returns 22.1C; Krakow returns 31.5C
755+
assert "31.5" in res.final_message.content
756+
757+
@pytest.mark.asyncio
758+
async def test_subagent_middleware_arg_mutation_reaches_subagent(self) -> None:
759+
pytest.importorskip("langchain_openai")
760+
761+
# Subagent call arg mutations in subagent_middleware must reach the actual subagent.
762+
763+
class NicknameGeneratorInput(BaseModel):
764+
name: str = Field(description="The person's full name", min_length=1)
765+
766+
@subagent_middleware
767+
async def mutating_middleware(
768+
request: SubagentRequest, handler: SubagentMiddlewareHandler
769+
) -> SubagentResponse:
770+
mutated = replace(
771+
request,
772+
call=replace(request.call, args={"name": "Alice"}),
773+
)
774+
return await handler(mutated)
775+
776+
async with (
777+
Agent(
778+
model=await self.model(),
779+
system_prompt=(
780+
"You are a helpful assistant that generates nicknames. A valid "
781+
"nickname consists of the provided name suffixed with '-zilla.'"
782+
),
783+
service=self.service,
784+
name="NicknameGeneratorAgent",
785+
description="Pass a name and get a nickname",
786+
input_schema=NicknameGeneratorInput,
787+
) as subagent,
788+
Agent(
789+
model=await self.model(),
790+
system_prompt="You are a supervisor agent that MUST use other agents",
791+
agents=[subagent],
792+
service=self.service,
793+
middleware=[mutating_middleware],
794+
) as supervisor,
795+
):
796+
result = await supervisor.invoke(
797+
[HumanMessage(content="Generate a nickname for Bob")]
798+
)
799+
assert "Alice-zilla" in result.final_message.content
800+
689801
@pytest.mark.asyncio
690802
async def test_model_middleware_structured_output(self) -> None:
691803
pytest.importorskip("langchain_openai")

0 commit comments

Comments
 (0)