Skip to content

Commit 1458085

Browse files
authored
Don't require input schema in subagents (splunk#30)
1 parent 257fe31 commit 1458085

2 files changed

Lines changed: 58 additions & 4 deletions

File tree

splunklib/ai/engines/langchain.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,24 @@ def _denormalize_agent_name(name: str) -> str:
218218

219219

220220
def _agent_as_tool(agent: BaseAgent[OutputT]):
221-
assert agent.name, "Agent must have a name to be used by other Agents"
222-
assert agent.input_schema, (
223-
"Agent must have an input schema to be used by other Agents"
224-
)
221+
if not agent.name:
222+
raise AssertionError("Agent must have a name to be used by other Agents")
223+
224+
# TODO: we should enforce uniqueness of subagent names.
225+
226+
if agent.input_schema is None:
227+
228+
async def _run(content: str) -> str:
229+
result = await agent.invoke([HumanMessage(content=content)])
230+
assert agent.output_schema is None
231+
return result.messages[-1].content
232+
233+
return StructuredTool.from_function(
234+
coroutine=_run,
235+
name=_normalize_agent_name(agent.name),
236+
description=agent.description,
237+
infer_schema=True,
238+
)
225239

226240
InputSchema = agent.input_schema
227241

tests/integration/ai/test_agent.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,46 @@ class NicknameGeneratorInput(BaseModel):
262262
assert subagent_message, "No subagent message found in response"
263263
assert "Chris-zilla" in response, "Agent did generate valid nickname"
264264

265+
@pytest.mark.asyncio
266+
async def test_subagent_without_input_schema(self):
267+
pytest.importorskip("langchain_openai")
268+
model = OpenAIModel(
269+
model="ministral-3:8b",
270+
base_url=OPENAI_BASE_URL,
271+
api_key=OPENAI_API_KEY,
272+
temperature=0.0,
273+
)
274+
275+
async with (
276+
Agent(
277+
model=model,
278+
system_prompt=(
279+
"You are a helpful assistant that generates nicknames"
280+
"If prompted for nickname you MUST append '-zilla' to provided name to create nickname."
281+
"Remember the dash and lowercase zilla. Example: Stefan -> Stefan-zilla"
282+
),
283+
service=self.service,
284+
name="NicknameGeneratorAgent",
285+
description="Generates nicknames for people. Pass a name and get a nickname",
286+
) as subagent,
287+
Agent(
288+
model=model,
289+
system_prompt="You are a supervisor agent that MUST use other agents",
290+
agents=[subagent],
291+
service=self.service,
292+
) as supervisor,
293+
):
294+
result = await supervisor.invoke(
295+
[
296+
HumanMessage(
297+
content="hi, my name is Chris. Generate a nickname for me",
298+
)
299+
]
300+
)
301+
302+
response = result.messages[-1].content
303+
assert "Chris-zilla" in response, "Agent did generate valid nickname"
304+
265305
@pytest.mark.asyncio
266306
async def test_agent_understands_other_agents(self):
267307
pytest.importorskip("langchain_openai")

0 commit comments

Comments
 (0)