Skip to content

Commit 884a508

Browse files
authored
Drop default values of message types (#89)
1 parent 284d8b8 commit 884a508

File tree

8 files changed

+45
-36
lines changed

8 files changed

+45
-36
lines changed

examples/ai_custom_alert_app/bin/threat_level_assessment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ async def invoke_agent(
100100
) as agent:
101101
logger.info(f"Invoking {agent.model=}")
102102
logger.debug(f"{user_prompt=}")
103-
result = await agent.invoke([HumanMessage(role="user", content=user_prompt)])
103+
result = await agent.invoke([HumanMessage(content=user_prompt)])
104104
return result.structured_output
105105

106106

examples/ai_custom_search_app/bin/agentic_reporting_csc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ async def invoke_agent(self, prompt: str) -> AgentOutput:
153153
output_schema=AgentOutput,
154154
) as agent:
155155
logger.info(f"Invoking {LLM_MODEL.model} at {LLM_MODEL.base_url}")
156-
result = await agent.invoke([HumanMessage(role="user", content=prompt)])
156+
result = await agent.invoke([HumanMessage(content=prompt)])
157157
return result.structured_output
158158

159159

examples/ai_modinput_app/bin/agentic_weather.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ async def invoke_agent(self, data_json: str) -> str:
127127
f"Parse {data_json=} into a into a short, human-readable sentence. "
128128
+ "Was it a good day to go outside if you're human?"
129129
)
130-
response = await agent.invoke([HumanMessage(role="user", content=prompt)])
130+
response = await agent.invoke([HumanMessage(content=prompt)])
131131
logger.debug(f"{response=}")
132132
return response.messages[-1].content
133133

splunklib/ai/messages.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ class SubagentCall:
3838

3939
@dataclass(frozen=True)
4040
class BaseMessage:
41-
role: str = ""
42-
content: str = field(default="")
41+
role: str = field(init=False)
42+
content: str = field(init=False)
4343

4444
def __post_init__(self) -> None:
4545
if type(self) is BaseMessage:
@@ -58,7 +58,8 @@ class HumanMessage(BaseMessage):
5858
conversation.
5959
"""
6060

61-
role: Literal["user"] = "user"
61+
role: Literal["user"] = field(default="user", init=False)
62+
content: str
6263

6364

6465
@dataclass(frozen=True)
@@ -71,23 +72,23 @@ class AIMessage(BaseMessage):
7172
requesting the Agent to execute.
7273
"""
7374

74-
role: Literal["assistant"] = "assistant"
75-
calls: Sequence[ToolCall | SubagentCall] = field(
76-
default_factory=list[ToolCall | SubagentCall]
77-
)
75+
role: Literal["assistant"] = field(default="assistant", init=False)
76+
content: str
77+
78+
calls: Sequence[ToolCall | SubagentCall]
7879

7980

8081
@dataclass(frozen=True)
8182
class ToolMessage(BaseMessage):
8283
"""ToolMessage represents a response of a tool call"""
8384

84-
# TODO: See if we can remove the defaults - they should always be populated manually
85+
role: Literal["tool"] = field(default="tool", init=False)
86+
content: str
8587

86-
role: Literal["tool"] = "tool"
87-
name: str = field(default="")
88-
call_id: str = field(default="")
89-
status: Literal["success", "error"] = "success"
90-
type: ToolType = ToolType.LOCAL
88+
name: str
89+
type: ToolType
90+
call_id: str
91+
status: Literal["success", "error"]
9192

9293

9394
@dataclass(frozen=True)
@@ -96,7 +97,8 @@ class SystemMessage(BaseMessage):
9697
A message used to prime or control agent behavior.
9798
"""
9899

99-
role: Literal["system"] = "system"
100+
role: Literal["system"] = field(default="system", init=False)
101+
content: str
100102

101103

102104
@dataclass(frozen=True)
@@ -105,10 +107,12 @@ class SubagentMessage(BaseMessage):
105107
SubagentMessage represents a response of an agent invocation
106108
"""
107109

108-
role: Literal["subagent"] = "subagent"
109-
name: str = field(default="")
110-
call_id: str = field(default="")
111-
status: Literal["success", "error"] = "success"
110+
role: Literal["subagent"] = field(default="subagent", init=False)
111+
content: str
112+
113+
name: str
114+
call_id: str
115+
status: Literal["success", "error"]
112116

113117

114118
OutputT = TypeVar("OutputT", default=None, covariant=True, bound=BaseModel | None)
@@ -119,4 +123,4 @@ class AgentResponse(Generic[OutputT]):
119123
# in case output_schema is provided, this will hold the parsed structured output
120124
structured_output: OutputT
121125
# Holds the full message history including tool calls and final response
122-
messages: list[BaseMessage] = field(default_factory=list)
126+
messages: list[BaseMessage]

tests/integration/ai/test_agent_mcp_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -715,7 +715,7 @@ class ToolResults(BaseModel):
715715
assert len(agent.tools) == 2
716716

717717
content = "Call tools to populate output."
718-
response = await agent.invoke([HumanMessage("user", content)])
718+
response = await agent.invoke([HumanMessage(content)])
719719
print(response.structured_output)
720720
assert response.structured_output.remote_temperature == "31.5C"
721721
assert response.structured_output.local_temperature == "22.1C"

tests/integration/ai/test_middleware.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,9 @@ async def test_middleware(
632632
nonlocal middleware_called
633633
middleware_called = True
634634

635-
return ModelResponse(message=AIMessage(content="My response is made up"))
635+
return ModelResponse(
636+
message=AIMessage(content="My response is made up", calls=[])
637+
)
636638

637639
async with Agent(
638640
model=await self.model(),
@@ -741,9 +743,7 @@ async def test_middleware(
741743
_req: ModelRequest, _handler: ModelMiddlewareHandler
742744
) -> ModelResponse:
743745
return ModelResponse(
744-
message=AIMessage(
745-
content="Stefan",
746-
),
746+
message=AIMessage(content="Stefan", calls=[]),
747747
structured_output=Output(name="Stefan"),
748748
)
749749

@@ -803,7 +803,7 @@ async def agent_middleware(
803803
HumanMessage(
804804
content="What is the weather like today in Krakow?"
805805
),
806-
AIMessage(content="Cloudy"),
806+
AIMessage(content="Cloudy", calls=[]),
807807
],
808808
structured_output=None,
809809
)
@@ -854,7 +854,7 @@ async def test_middleware(
854854
return AgentResponse(
855855
messages=[
856856
HumanMessage(content="What is the weather like today in Krakow?"),
857-
AIMessage(content="Cloudy"),
857+
AIMessage(content="Cloudy", calls=[]),
858858
],
859859
structured_output=None,
860860
)
@@ -869,7 +869,7 @@ async def test_middleware(
869869
[HumanMessage(content="What is the weather like today in Krakow?")]
870870
)
871871
assert len(resp.messages) == 2
872-
assert resp.messages[1] == AIMessage(content="Cloudy")
872+
assert resp.messages[1] == AIMessage(content="Cloudy", calls=[])
873873

874874
@pytest.mark.asyncio
875875
async def test_agent_middleware_retry(self) -> None:
@@ -930,7 +930,7 @@ async def test2_middleware(
930930
return AgentResponse(
931931
messages=[
932932
HumanMessage(content="What is the weather like today in Krakow?"),
933-
AIMessage(content="Cloudy"),
933+
AIMessage(content="Cloudy", calls=[]),
934934
],
935935
structured_output=None,
936936
)
@@ -992,7 +992,7 @@ async def test_middleware(
992992
return AgentResponse(
993993
messages=[
994994
HumanMessage(content="What is your name?"),
995-
AIMessage(content="Stefan"),
995+
AIMessage(content="Stefan", calls=[]),
996996
],
997997
structured_output=None,
998998
)
@@ -1027,7 +1027,7 @@ async def test_middleware(
10271027
return AgentResponse[Any | None](
10281028
messages=[
10291029
HumanMessage(content="What is your name?"),
1030-
AIMessage(content="Stefan"),
1030+
AIMessage(content="Stefan", calls=[]),
10311031
],
10321032
structured_output=Output2(name="Stefan"),
10331033
)
@@ -1062,7 +1062,7 @@ async def test_middleware(
10621062
return AgentResponse[Any | None](
10631063
messages=[
10641064
HumanMessage(content="What is your name?"),
1065-
AIMessage(content="Stefan"),
1065+
AIMessage(content="Stefan", calls=[]),
10661066
],
10671067
structured_output=Output(name="Stefan"),
10681068
)

tests/system/test_apps/ai_agentic_test_local_tools_app/bin/agentic_app_tools_endpoint.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ async def run(self) -> None:
5050
result = await agent.invoke(
5151
[
5252
HumanMessage(
53-
role="user",
5453
content=(
5554
"What is the weather like today in Krakow? Use the provided tools to check the temperature. "
5655
"Return a short response, containing the tool response."

tests/unit/ai/engine/test_langchain_backend.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,13 @@ def test_map_message_to_langchain_tool_call_with_reserved_prefix(self) -> None:
200200
]
201201

202202
message = lc._map_message_to_langchain(
203-
ToolMessage(content="hi", name="__bad-tool", type=ToolType.REMOTE)
203+
ToolMessage(
204+
call_id="foo",
205+
status="success",
206+
content="hi",
207+
name="__bad-tool",
208+
type=ToolType.REMOTE,
209+
)
204210
)
205211
assert isinstance(message, LC_ToolMessage)
206212
assert message.name == "__tool-__bad-tool"

0 commit comments

Comments
 (0)