Skip to content

Commit eb05701

Browse files
authored
Populate artifact for subagent failures (#116)
1 parent 8181104 commit eb05701

File tree

3 files changed

+112
-12
lines changed

3 files changed

+112
-12
lines changed

splunklib/ai/engines/langchain.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -203,15 +203,12 @@ async def awrap_tool_call(
203203
assert resp.name, "missing tool name"
204204

205205
if resp.status == "error":
206-
# This assertion asserts the current behaviour, but can be removed safely,
207-
# if we at some point decide to raise a LC_ToolException in a subagent invocation.
208-
# Also in such case we would need to populate artifact with SubagentFailureResult.
209-
assert not resp.name.startswith(AGENT_PREFIX), (
210-
"subagent produced a non-fatal error"
211-
)
212-
213206
assert resp.artifact is None, "artifact is already populated"
214-
resp.artifact = ToolFailureResult(str(resp.content)) # pyright: ignore[reportUnknownArgumentType]
207+
208+
if resp.name.startswith(AGENT_PREFIX):
209+
resp.artifact = SubagentFailureResult(str(resp.content)) # pyright: ignore[reportUnknownArgumentType]
210+
else:
211+
resp.artifact = ToolFailureResult(str(resp.content)) # pyright: ignore[reportUnknownArgumentType]
215212

216213
return resp
217214

splunklib/ai/messages.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,6 @@ class SubagentFailureResult:
126126
127127
This type of failure is non-fatal, i.e. it does not stop the agent loop.
128128
Instead, the error information is returned to the LLM.
129-
130-
Currently this result is not produced by the subagent call, but can be leveraged
131-
in middlewares e.g. to reject subagent calls in a non-fatal way.
132129
"""
133130

134131
error_message: str

tests/integration/ai/test_agent.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,28 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414

15+
from dataclasses import replace
1516
import pytest
1617
from pydantic import BaseModel, Field
1718

1819
from splunklib.ai import Agent
19-
from splunklib.ai.messages import AIMessage, HumanMessage, SubagentCall, SubagentMessage
20+
from splunklib.ai.messages import (
21+
AIMessage,
22+
HumanMessage,
23+
SubagentCall,
24+
SubagentFailureResult,
25+
SubagentMessage,
26+
)
27+
from splunklib.ai.middleware import (
28+
ModelMiddlewareHandler,
29+
ModelRequest,
30+
ModelResponse,
31+
SubagentMiddlewareHandler,
32+
SubagentRequest,
33+
SubagentResponse,
34+
model_middleware,
35+
subagent_middleware,
36+
)
2037
from tests.ai_testlib import AITestCase
2138

2239
OPENAI_BASE_URL = "http://localhost:11434/v1"
@@ -411,3 +428,92 @@ async def test_duplicated_subagent_name(self) -> None:
411428
agents=[subagent1_empty_name, subagent2_empty_name],
412429
):
413430
pass
431+
432+
@pytest.mark.asyncio
433+
async def test_subagent_soft_failure_with_invalid_args(self) -> None:
434+
pytest.importorskip("langchain_openai")
435+
436+
# Regression test - In case invalid schema is provided to the
437+
# subagent during execution, we should not fail the entire agent.
438+
439+
class SubagentInput(BaseModel):
440+
name: str = Field(description="person name", min_length=1)
441+
442+
after_subagent_call = False
443+
444+
@subagent_middleware
445+
async def _subagent_call_middleware(
446+
request: SubagentRequest, handler: SubagentMiddlewareHandler
447+
) -> SubagentResponse:
448+
nonlocal after_subagent_call
449+
450+
# Override the arguments, such that are invalid.
451+
resp = await handler(replace(request, call=replace(request.call, args={})))
452+
assert isinstance(resp.result, SubagentFailureResult), (
453+
"subagent call did not fail"
454+
)
455+
456+
after_subagent_call = True
457+
return resp
458+
459+
@model_middleware
460+
async def _model_call_middleware(
461+
req: ModelRequest, _handler: ModelMiddlewareHandler
462+
) -> ModelResponse:
463+
if after_subagent_call:
464+
msgs = req.state.response.messages
465+
assert isinstance(msgs[-1], SubagentMessage)
466+
assert isinstance(msgs[-1].result, SubagentFailureResult)
467+
468+
return ModelResponse(
469+
message=AIMessage(
470+
content="End of the agent loop",
471+
calls=[],
472+
),
473+
structured_output=None,
474+
)
475+
else:
476+
return ModelResponse(
477+
message=AIMessage(
478+
content="I need to call tools",
479+
calls=[
480+
SubagentCall(
481+
id="call-1",
482+
name="NicknameGeneratorAgent",
483+
args=SubagentInput(name="Chris").model_dump(),
484+
)
485+
],
486+
),
487+
structured_output=None,
488+
)
489+
490+
async with (
491+
Agent(
492+
model=(await self.model()),
493+
system_prompt=(
494+
"You are a helpful assistant that generates nicknames"
495+
"If prompted for nickname you MUST append '-zilla' to provided name to create nickname."
496+
"Remember the dash and lowercase zilla. Example: Stefan -> Stefan-zilla"
497+
),
498+
service=self.service,
499+
input_schema=SubagentInput,
500+
name="NicknameGeneratorAgent",
501+
description="Generates nicknames for people. Pass a name and get a nickname",
502+
) as subagent,
503+
Agent(
504+
model=(await self.model()),
505+
system_prompt="You are a supervisor agent that MUST use other agents",
506+
agents=[subagent],
507+
service=self.service,
508+
middleware=[_subagent_call_middleware, _model_call_middleware],
509+
) as supervisor,
510+
):
511+
await supervisor.invoke(
512+
[
513+
HumanMessage(
514+
content="Hi, my name is Chris. Generate a nickname for me",
515+
)
516+
]
517+
)
518+
519+
assert after_subagent_call, "subagent was not called"

0 commit comments

Comments
 (0)