|
12 | 12 | # License for the specific language governing permissions and limitations |
13 | 13 | # under the License. |
14 | 14 |
|
| 15 | +from dataclasses import replace |
15 | 16 | import pytest |
16 | 17 | from pydantic import BaseModel, Field |
17 | 18 |
|
18 | 19 | from splunklib.ai import Agent |
19 | | -from splunklib.ai.messages import AIMessage, HumanMessage, SubagentCall, SubagentMessage |
| 20 | +from splunklib.ai.messages import ( |
| 21 | + AIMessage, |
| 22 | + HumanMessage, |
| 23 | + SubagentCall, |
| 24 | + SubagentFailureResult, |
| 25 | + SubagentMessage, |
| 26 | +) |
| 27 | +from splunklib.ai.middleware import ( |
| 28 | + ModelMiddlewareHandler, |
| 29 | + ModelRequest, |
| 30 | + ModelResponse, |
| 31 | + SubagentMiddlewareHandler, |
| 32 | + SubagentRequest, |
| 33 | + SubagentResponse, |
| 34 | + model_middleware, |
| 35 | + subagent_middleware, |
| 36 | +) |
20 | 37 | from tests.ai_testlib import AITestCase |
21 | 38 |
|
22 | 39 | OPENAI_BASE_URL = "http://localhost:11434/v1" |
@@ -411,3 +428,92 @@ async def test_duplicated_subagent_name(self) -> None: |
411 | 428 | agents=[subagent1_empty_name, subagent2_empty_name], |
412 | 429 | ): |
413 | 430 | pass |
| 431 | + |
| 432 | + @pytest.mark.asyncio |
| 433 | + async def test_subagent_soft_failure_with_invalid_args(self) -> None: |
| 434 | + pytest.importorskip("langchain_openai") |
| 435 | + |
| 436 | + # Regression test - In case invalid schema is provided to the |
| 437 | + # subagent during execution, we should not fail the entire agent. |
| 438 | + |
| 439 | + class SubagentInput(BaseModel): |
| 440 | + name: str = Field(description="person name", min_length=1) |
| 441 | + |
| 442 | + after_subagent_call = False |
| 443 | + |
| 444 | + @subagent_middleware |
| 445 | + async def _subagent_call_middleware( |
| 446 | + request: SubagentRequest, handler: SubagentMiddlewareHandler |
| 447 | + ) -> SubagentResponse: |
| 448 | + nonlocal after_subagent_call |
| 449 | + |
| 450 | + # Override the arguments, such that are invalid. |
| 451 | + resp = await handler(replace(request, call=replace(request.call, args={}))) |
| 452 | + assert isinstance(resp.result, SubagentFailureResult), ( |
| 453 | + "subagent call did not fail" |
| 454 | + ) |
| 455 | + |
| 456 | + after_subagent_call = True |
| 457 | + return resp |
| 458 | + |
| 459 | + @model_middleware |
| 460 | + async def _model_call_middleware( |
| 461 | + req: ModelRequest, _handler: ModelMiddlewareHandler |
| 462 | + ) -> ModelResponse: |
| 463 | + if after_subagent_call: |
| 464 | + msgs = req.state.response.messages |
| 465 | + assert isinstance(msgs[-1], SubagentMessage) |
| 466 | + assert isinstance(msgs[-1].result, SubagentFailureResult) |
| 467 | + |
| 468 | + return ModelResponse( |
| 469 | + message=AIMessage( |
| 470 | + content="End of the agent loop", |
| 471 | + calls=[], |
| 472 | + ), |
| 473 | + structured_output=None, |
| 474 | + ) |
| 475 | + else: |
| 476 | + return ModelResponse( |
| 477 | + message=AIMessage( |
| 478 | + content="I need to call tools", |
| 479 | + calls=[ |
| 480 | + SubagentCall( |
| 481 | + id="call-1", |
| 482 | + name="NicknameGeneratorAgent", |
| 483 | + args=SubagentInput(name="Chris").model_dump(), |
| 484 | + ) |
| 485 | + ], |
| 486 | + ), |
| 487 | + structured_output=None, |
| 488 | + ) |
| 489 | + |
| 490 | + async with ( |
| 491 | + Agent( |
| 492 | + model=(await self.model()), |
| 493 | + system_prompt=( |
| 494 | + "You are a helpful assistant that generates nicknames" |
| 495 | + "If prompted for nickname you MUST append '-zilla' to provided name to create nickname." |
| 496 | + "Remember the dash and lowercase zilla. Example: Stefan -> Stefan-zilla" |
| 497 | + ), |
| 498 | + service=self.service, |
| 499 | + input_schema=SubagentInput, |
| 500 | + name="NicknameGeneratorAgent", |
| 501 | + description="Generates nicknames for people. Pass a name and get a nickname", |
| 502 | + ) as subagent, |
| 503 | + Agent( |
| 504 | + model=(await self.model()), |
| 505 | + system_prompt="You are a supervisor agent that MUST use other agents", |
| 506 | + agents=[subagent], |
| 507 | + service=self.service, |
| 508 | + middleware=[_subagent_call_middleware, _model_call_middleware], |
| 509 | + ) as supervisor, |
| 510 | + ): |
| 511 | + await supervisor.invoke( |
| 512 | + [ |
| 513 | + HumanMessage( |
| 514 | + content="Hi, my name is Chris. Generate a nickname for me", |
| 515 | + ) |
| 516 | + ] |
| 517 | + ) |
| 518 | + |
| 519 | + assert after_subagent_call, "subagent was not called" |
0 commit comments