Skip to content

Commit 505b724

Browse files
authored
Validate messages before and after the agent loop (#142)
1 parent e345584 commit 505b724

File tree

3 files changed

+801
-8
lines changed

3 files changed

+801
-8
lines changed

splunklib/ai/engines/langchain.py

Lines changed: 140 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -605,8 +605,11 @@ async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]:
605605
# Prepend messages from conversation store.
606606
if self._sdk_agent.conversation_store:
607607
msgs = await self._sdk_agent.conversation_store.get_messages(thread_id)
608-
langchain_msgs.extend([_map_message_to_langchain(m) for m in msgs])
608+
if len(msgs) > 0:
609+
_validate_messages(msgs, False)
610+
langchain_msgs.extend([_map_message_to_langchain(m) for m in msgs])
609611

612+
_validate_messages(req.messages, False)
610613
langchain_msgs.extend([_map_message_to_langchain(m) for m in req.messages])
611614

612615
while True:
@@ -629,6 +632,9 @@ async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]:
629632

630633
sdk_msgs = [_map_message_from_langchain(m) for m in result["messages"]]
631634

635+
# Serves as an assertion, if this is hit, it likely means a bug in the agentic loop.
636+
_validate_messages(sdk_msgs, True)
637+
632638
# NOTE: Agent responses will always conform to output schema. Verifying
633639
# if an LLM made any mistakes or not is _always_ up to the developer.
634640

@@ -645,8 +651,6 @@ async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]:
645651
else:
646652
resp = AgentResponse(structured_output=None, messages=sdk_msgs)
647653

648-
resp.final_message # serves as an assertion
649-
650654
return resp
651655

652656
result = await self._with_agent_middleware(invoke_agent)(
@@ -659,16 +663,15 @@ async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]:
659663
# not after all were executed?
660664

661665
try:
662-
result.final_message
663-
except AssertionError as e:
664-
raise AssertionError(
665-
f"AgentMiddleware modified AgentResponse.messages and made it invalid: {e}"
666+
_validate_messages(result.messages, True)
667+
except _InvalidMessagesException as e:
668+
raise _InvalidMessagesException(
669+
f"Agent middleware modified messages and made it invalid: {e}"
666670
)
667671

668672
if self._sdk_agent.output_schema:
669673
if result.structured_output is None:
670674
raise AssertionError("Agent middleware discarded a structured output")
671-
672675
if type(result.structured_output) is not self._sdk_agent.output_schema:
673676
raise AssertionError(
674677
f"Agent middleware returned an invalid structured_output type: {type(result.structured_output)}, want: {self._sdk_agent.output_schema}"
@@ -1686,3 +1689,132 @@ def _create_langchain_model(model: PredefinedModel) -> BaseChatModel:
16861689
raise InvalidModelError(
16871690
"Cannot create langchain model - invalid SDK model provided"
16881691
)
1692+
1693+
1694+
class _InvalidMessagesException(Exception):
1695+
pass
1696+
1697+
1698+
def _validate_messages(messages: Sequence[BaseMessage], agent_loop_end: bool) -> None:
1699+
if len(messages) == 0:
1700+
raise _InvalidMessagesException("messages list is empty")
1701+
1702+
pending_structured_calls: dict[str, str] = {}
1703+
pending_tool_calls: dict[str, str] = {}
1704+
pending_subagent_calls: dict[str, str] = {}
1705+
1706+
def check_no_pending_calls() -> None:
1707+
if len(pending_structured_calls) != 0:
1708+
raise _InvalidMessagesException(
1709+
f"StructuredToolCall does not have a corresponding StructuredOutputMessage; ids={list(pending_structured_calls.keys())}"
1710+
)
1711+
if len(pending_tool_calls) != 0:
1712+
raise _InvalidMessagesException(
1713+
f"ToolCall does not have a corresponding ToolMessage; ids={list(pending_tool_calls.keys())}"
1714+
)
1715+
if len(pending_subagent_calls) != 0:
1716+
raise _InvalidMessagesException(
1717+
f"SubagentCall does not have a corresponding SubagentMessage; ids={list(pending_subagent_calls.keys())}"
1718+
)
1719+
1720+
used_call_ids: set[str] = set()
1721+
1722+
def check_call_id(type: str, id: str) -> None:
1723+
if id == "":
1724+
raise _InvalidMessagesException(f"Empty {type} call_id: {id=}")
1725+
if id in used_call_ids:
1726+
raise _InvalidMessagesException(f"Duplicated {type} call_id: {id}")
1727+
1728+
used_call_ids.add(id)
1729+
1730+
def check_tool_name(type: str, name: str) -> None:
1731+
if name == "":
1732+
raise _InvalidMessagesException(f"Empty {type} name: {name=}")
1733+
1734+
# We use `type() is X` instead of `isinstance`/match statement
1735+
# to make sure that users do not subclass our types, since we do
1736+
# type conversions between LC and SDK types in the backend and
1737+
# the subclassed types that users provide would be lost
1738+
# (since we re-create these back as our types).
1739+
1740+
last_ai_message: AIMessage | None = None
1741+
for message in messages:
1742+
if type(message) is HumanMessage:
1743+
check_no_pending_calls()
1744+
elif type(message) is SystemMessage:
1745+
check_no_pending_calls()
1746+
elif type(message) is AIMessage:
1747+
last_ai_message = message
1748+
1749+
check_no_pending_calls()
1750+
for call in message.calls:
1751+
if type(call) is ToolCall:
1752+
assert call.id is not None
1753+
check_call_id("tool", call.id)
1754+
check_tool_name("tool", call.name)
1755+
pending_tool_calls[call.id] = call.name
1756+
elif type(call) is SubagentCall:
1757+
assert call.id is not None
1758+
check_call_id("subagent", call.id)
1759+
check_tool_name("subagent", call.name)
1760+
pending_subagent_calls[call.id] = call.name
1761+
else:
1762+
raise _InvalidMessagesException(
1763+
f"AIMessage contains invalid call type: {type(call)}"
1764+
)
1765+
for call in message.structured_output_calls:
1766+
if type(call) is StructuredOutputCall:
1767+
assert call.id is not None
1768+
check_call_id("structured output tool", call.id)
1769+
check_tool_name("structured output tool", call.name)
1770+
pending_structured_calls[call.id] = call.name
1771+
else:
1772+
raise _InvalidMessagesException(
1773+
f"AIMessage contains invalid call type: {type(call)}"
1774+
)
1775+
1776+
elif type(message) is ToolMessage:
1777+
name = pending_tool_calls.get(message.call_id)
1778+
if name is None:
1779+
raise _InvalidMessagesException(
1780+
f"ToolMessage does not have a corresponding ToolCall; id={message.call_id}"
1781+
)
1782+
if name != message.name:
1783+
raise _InvalidMessagesException(
1784+
f"ToolMessage.name = {message.name}, but the corresponding ToolCall.name = {name}"
1785+
)
1786+
del pending_tool_calls[message.call_id]
1787+
elif type(message) is SubagentMessage:
1788+
name = pending_subagent_calls.get(message.call_id)
1789+
if name is None:
1790+
raise _InvalidMessagesException(
1791+
f"SubagentMessage does not have a corresponding SubagentCall; id={message.call_id}"
1792+
)
1793+
if name != message.name:
1794+
raise _InvalidMessagesException(
1795+
f"SubagentMessage.name = {message.name}, but the corresponding SubagentCall.name = {name}"
1796+
)
1797+
del pending_subagent_calls[message.call_id]
1798+
elif type(message) is StructuredOutputMessage:
1799+
name = pending_structured_calls.get(message.call_id)
1800+
if name is None:
1801+
raise _InvalidMessagesException(
1802+
f"StructuredOutputMessage does not have a corresponding StructuredOutputCall; id={message.call_id}"
1803+
)
1804+
if name != message.name:
1805+
raise _InvalidMessagesException(
1806+
f"StructuredOutputMessage.name = {message.name}, but the corresponding StructuredOutputCall.name = {name}"
1807+
)
1808+
del pending_structured_calls[message.call_id]
1809+
else:
1810+
raise _InvalidMessagesException(
1811+
f"Messages contains invalid message type: {type(message)}"
1812+
)
1813+
1814+
check_no_pending_calls()
1815+
1816+
if agent_loop_end:
1817+
if last_ai_message is None:
1818+
raise _InvalidMessagesException("messages does not have an AIMessage")
1819+
if len(last_ai_message.calls) != 0:
1820+
raise _InvalidMessagesException("last AIMessage has tool calls")

splunklib/ai/messages.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ class ToolMessage(BaseMessage):
159159
result: ToolResult | ToolFailureResult
160160

161161

162+
# TODO: do we have a test that uses this?
162163
@dataclass(frozen=True)
163164
class SystemMessage(BaseMessage):
164165
"""

0 commit comments

Comments
 (0)