Skip to content

Commit 9ef8333

Browse files
committed
Make id field of tool/subagent/output calls non-None
Also while here, move the id field to be first.
1 parent 9e45daf commit 9ef8333

File tree

2 files changed

+46
-8
lines changed

2 files changed

+46
-8
lines changed

splunklib/ai/engines/langchain.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -481,10 +481,45 @@ def unpack_tool_call(self, call: LC_ToolCall) -> LC_ToolCall:
481481

482482
return call
483483

484+
class _CheckCallIDMiddleware(LC_AgentMiddleware):
485+
def _check_has_call_id(self, msg: LC_AIMessage) -> None:
486+
for call in msg.tool_calls:
487+
if call["id"] is None:
488+
# If we ever hit this with real model, just generate a random call_id here.
489+
raise Exception("LLM returned a Tool Call without a call_id")
490+
491+
@override
492+
async def awrap_model_call(
493+
self,
494+
request: LC_ModelRequest,
495+
handler: Callable[[LC_ModelRequest], Awaitable[LC_ModelCallResult]],
496+
) -> LC_ModelCallResult:
497+
try:
498+
resp = await handler(request)
499+
ai_message = resp
500+
if isinstance(ai_message, LC_ExtendedModelResponse):
501+
ai_message = ai_message.model_response
502+
if isinstance(ai_message, LC_ModelResponse):
503+
ai_message = next(
504+
(
505+
m
506+
for m in ai_message.result
507+
if isinstance(m, LC_AIMessage)
508+
),
509+
None,
510+
)
511+
assert ai_message, "AIMessage not found found in response"
512+
self._check_has_call_id(ai_message)
513+
return resp
514+
except LC_StructuredOutputError as e:
515+
self._check_has_call_id(e.ai_message)
516+
raise
517+
484518
lc_middleware.append(_ToolFailureArtifact())
485519
if len(conversational_subagents) > 0:
486520
lc_middleware.append(_ThreadIDMiddleware())
487521
lc_middleware.append(_SubagentArgumentPacker())
522+
lc_middleware.append(_CheckCallIDMiddleware())
488523

489524
class _DEBUGMiddleware(LC_AgentMiddleware):
490525
@override
@@ -780,6 +815,9 @@ async def next(r: SubagentRequest) -> SubagentResponse:
780815

781816
return invoke
782817

818+
def _raise_on_missing_call_id(self, msg: AIMessage) -> None:
819+
pass
820+
783821
@override
784822
async def awrap_model_call(
785823
self,
@@ -1254,7 +1292,7 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe
12541292
StructuredOutputCall(
12551293
name=tc["name"].removeprefix(TOOL_STRATEGY_TOOL_PREFIX),
12561294
args=tc["args"],
1257-
id=tc["id"],
1295+
id=tc["id"] or "",
12581296
)
12591297
for tc in ai_message.tool_calls
12601298
if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX)
@@ -1529,7 +1567,7 @@ def _map_tool_call_from_langchain(tool_call: LC_ToolCall) -> ToolCall | Subagent
15291567
name=_denormalize_agent_name(name),
15301568
args=SubagentLCArgs(**tool_call["args"]).args,
15311569
thread_id=SubagentLCArgs(**tool_call["args"]).thread_id,
1532-
id=tool_call["id"],
1570+
id=tool_call["id"] or "",
15331571
)
15341572

15351573
tool_type: ToolType = (
@@ -1538,7 +1576,7 @@ def _map_tool_call_from_langchain(tool_call: LC_ToolCall) -> ToolCall | Subagent
15381576
return ToolCall(
15391577
name=_denormalize_tool_name(name),
15401578
args=tool_call["args"],
1541-
id=tool_call["id"],
1579+
id=tool_call["id"] or "",
15421580
type=tool_type,
15431581
)
15441582

@@ -1567,9 +1605,9 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
15671605
],
15681606
structured_output_calls=[
15691607
StructuredOutputCall(
1608+
tc["id"] or "",
15701609
tc["name"].removeprefix(TOOL_STRATEGY_TOOL_PREFIX),
15711610
tc["args"],
1572-
tc["id"],
15731611
)
15741612
for tc in message.tool_calls
15751613
if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX)

splunklib/ai/messages.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,25 @@
2323

2424
@dataclass(frozen=True)
2525
class ToolCall:
26+
id: str
2627
name: str
27-
args: dict[str, Any]
28-
id: str | None # TODO: can be None?
2928
type: ToolType
29+
args: dict[str, Any]
3030

3131

3232
@dataclass(frozen=True)
3333
class SubagentCall:
34+
id: str
3435
name: str
3536
args: str | dict[str, Any]
36-
id: str | None # TODO: can be None?
3737
thread_id: str | None
3838

3939

4040
@dataclass(frozen=True)
4141
class StructuredOutputCall:
42+
id: str
4243
name: str
4344
args: dict[str, Any]
44-
id: str | None # TODO: can be None?
4545

4646

4747
@dataclass(frozen=True)

0 commit comments

Comments
 (0)