Skip to content

Commit 806da0b

Browse files
authored
Make middleware-types read only. (#106)
In the README, we modified the request within a middleware, which is a bad practice. Middleware should treat the request as immutable, since middleware earlier in the chain may depend on the original request state. And to avoid such mistakes, lets make all middleware-related types read-only.
1 parent 6d0c28a commit 806da0b

File tree

2 files changed

+23
-11
lines changed

2 files changed

+23
-11
lines changed

splunklib/ai/README.md

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -476,8 +476,14 @@ class ExampleMiddleware(AgentMiddleware):
476476
async def model_middleware(
477477
self, request: ModelRequest, handler: ModelMiddlewareHandler
478478
) -> ModelResponse:
479-
request.system_message = request.system_message.replace("SECRET", "[REDACTED]")
480-
return await handler(request)
479+
return await handler(
480+
ModelRequest(
481+
system_message=request.system_message.replace(
482+
"SECRET", "[REDACTED]"
483+
),
484+
state=request.state,
485+
)
486+
)
481487

482488
@override
483489
async def tool_middleware(
@@ -535,8 +541,14 @@ from splunklib.ai.middleware import (
535541
async def redact_system_prompt(
536542
request: ModelRequest, handler: ModelMiddlewareHandler
537543
) -> ModelResponse:
538-
request.system_message = request.system_message.replace("SECRET", "[REDACTED]")
539-
return await handler(request)
544+
return await handler(
545+
ModelRequest(
546+
system_message=request.system_message.replace(
547+
"SECRET", "[REDACTED]"
548+
),
549+
state=request.state,
550+
)
551+
)
540552
```
541553

542554
Example tool middleware:

splunklib/ai/middleware.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,27 +42,27 @@ class AgentState:
4242
token_count: float
4343

4444

45-
@dataclass
45+
@dataclass(frozen=True)
4646
class ToolRequest:
4747
call: ToolCall
4848
state: AgentState
4949

5050

51-
@dataclass
51+
@dataclass(frozen=True)
5252
class ToolResponse:
5353
result: ToolResult | ToolFailureResult
5454

5555

5656
ToolMiddlewareHandler = Callable[[ToolRequest], Awaitable[ToolResponse]]
5757

5858

59-
@dataclass
59+
@dataclass(frozen=True)
6060
class SubagentRequest:
6161
call: SubagentCall
6262
state: AgentState
6363

6464

65-
@dataclass
65+
@dataclass(frozen=True)
6666
class SubagentResponse:
6767
result: SubagentStructuredResult | SubagentTextResult | SubagentFailureResult
6868

@@ -73,13 +73,13 @@ class SubagentResponse:
7373
]
7474

7575

76-
@dataclass
76+
@dataclass(frozen=True)
7777
class ModelRequest:
7878
system_message: str
7979
state: AgentState
8080

8181

82-
@dataclass
82+
@dataclass(frozen=True)
8383
class ModelResponse:
8484
message: AIMessage
8585
structured_output: Any | None = None
@@ -88,7 +88,7 @@ class ModelResponse:
8888
ModelMiddlewareHandler = Callable[[ModelRequest], Awaitable[ModelResponse]]
8989

9090

91-
@dataclass
91+
@dataclass(frozen=True)
9292
class AgentRequest:
9393
messages: list[BaseMessage]
9494

0 commit comments

Comments
 (0)