Skip to content

Commit df2fb62

Browse files
authored
Merge pull request #82 from stackhawk/fix/triage-tool-routing-and-type-safety
fix: triage tool routing, type safety, and CI hardening
2 parents 1e1ac54 + f463a45 commit df2fb62

11 files changed

Lines changed: 526 additions & 251 deletions

File tree

.github/workflows/ci.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ jobs:
2626
python -m pip install --upgrade pip
2727
pip install -r requirements.txt
2828
pip install .
29+
- name: Run mypy
30+
run: |
31+
mypy stackhawk_mcp/
2932
- name: Run tests
3033
run: |
3134
pytest --maxfail=1 --disable-warnings

CLAUDE.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,17 @@ export STACKHAWK_API_KEY="your-api-key-here"
8686
- Individual test files for each major feature area
8787
- API integration tests and schema validation tests included
8888

89+
### Tool Routing Tests (REQUIRED)
90+
91+
**Every MCP tool must have a routing test in `tests/test_findings_triage_routing.py`.**
92+
93+
These tests mock the target method and call the tool through the MCP handler, verifying:
94+
1. The handler resolves to a real method (not an `AttributeError` at runtime)
95+
2. The method is called on the correct object (`self` vs `self.client`)
96+
3. Arguments are forwarded correctly
97+
98+
When adding a new tool to `handle_call_tool`, add a corresponding test following the existing pattern. No API key needed — these are pure wiring tests.
99+
89100
## Development Notes
90101

91102
### Python Requirements

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ dev = [
2323
"pytest>=7.0.0",
2424
"pytest-asyncio>=0.21.0",
2525
"black>=23.0.0",
26-
"mypy>=1.0.0"
26+
"mypy>=1.0.0",
27+
"types-PyYAML>=6.0",
28+
"types-jsonschema>=4.0.0"
2729
]
2830

2931
[project.scripts]

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ pytest>=7.0.0
1212
pytest-asyncio>=0.21.0
1313
black>=23.0.0
1414
mypy>=1.0.0
15+
types-PyYAML>=6.0
16+
types-jsonschema>=4.0.0
1517
bumpver>=2023.1129
1618

1719
# FastAPI dependencies

stackhawk_mcp/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@
77

88
__version__ = "1.2.4"
99
__author__ = "StackHawk MCP Team"
10-
__email__ = "support@stackhawk.com"
10+
__email__ = "support@stackhawk.com"

stackhawk_mcp/__main__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import asyncio
22
from .server import main
33

4-
def cli():
4+
5+
def cli() -> None:
56
asyncio.run(main())
67

8+
79
if __name__ == "__main__":
8-
cli()
10+
cli()

stackhawk_mcp/http_server.py

Lines changed: 72 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -30,125 +30,108 @@
3030
# Store active SSE connections
3131
active_connections = {}
3232

33-
def create_jsonrpc_response(id_value, result=None, error=None):
33+
34+
def create_jsonrpc_response(id_value: object, result: object = None, error: object = None) -> dict:
3435
"""Create a proper JSON-RPC 2.0 response"""
35-
response = {
36-
"jsonrpc": "2.0",
37-
"id": id_value
38-
}
36+
response = {"jsonrpc": "2.0", "id": id_value}
3937
if error:
4038
response["error"] = error
4139
else:
4240
response["result"] = result
4341
return response
4442

45-
def fix_tool_schema(tool):
43+
44+
def fix_tool_schema(tool: dict | object) -> dict | object:
4645
"""Fix tool schema to ensure outputSchema and annotations are objects"""
4746
if isinstance(tool, dict):
4847
# Ensure outputSchema is an object with proper type
4948
if tool.get("outputSchema") is None:
5049
tool["outputSchema"] = {"type": "object"}
5150
elif isinstance(tool["outputSchema"], dict) and tool["outputSchema"].get("type") is None:
5251
tool["outputSchema"]["type"] = "object"
53-
52+
5453
# Ensure annotations is an object
5554
if tool.get("annotations") is None:
5655
tool["annotations"] = {}
57-
56+
5857
# Ensure meta is an object
5958
if tool.get("meta") is None:
6059
tool["meta"] = {}
61-
60+
6261
return tool
6362

64-
async def handle_initialize_request(request_data):
63+
64+
async def handle_initialize_request(request_data: dict) -> dict:
6565
"""Handle MCP initialize request"""
6666
return create_jsonrpc_response(
6767
request_data.get("id"),
6868
{
6969
"protocolVersion": "2025-03-26",
70-
"capabilities": {
71-
"tools": {}
72-
},
73-
"serverInfo": {
74-
"name": "StackHawk MCP",
75-
"version": __version__
76-
}
77-
}
70+
"capabilities": {"tools": {}},
71+
"serverInfo": {"name": "StackHawk MCP", "version": __version__},
72+
},
7873
)
7974

80-
async def handle_list_tools_request(request_data):
75+
76+
async def handle_list_tools_request(request_data: dict) -> dict:
8177
"""Handle MCP list tools request"""
8278
try:
8379
tools = await mcp_server.list_tools()
8480
# Fix the tool schemas to ensure proper objects
85-
fixed_tools = [fix_tool_schema(t.dict() if hasattr(t, 'dict') else t) for t in tools]
86-
return create_jsonrpc_response(
87-
request_data.get("id"),
88-
{
89-
"tools": fixed_tools
90-
}
91-
)
81+
fixed_tools = [fix_tool_schema(t.dict() if hasattr(t, "dict") else t) for t in tools]
82+
return create_jsonrpc_response(request_data.get("id"), {"tools": fixed_tools})
9283
except Exception as e:
9384
return create_jsonrpc_response(
94-
request_data.get("id"),
95-
error={
96-
"code": -1,
97-
"message": str(e)
98-
}
85+
request_data.get("id"), error={"code": -1, "message": str(e)}
9986
)
10087

101-
async def handle_call_tool_request(request_data):
88+
89+
async def handle_call_tool_request(request_data: dict) -> dict:
10290
"""Handle MCP call tool request"""
10391
try:
10492
params = request_data.get("params", {})
10593
name = params.get("name")
10694
arguments = params.get("arguments", {})
107-
95+
10896
result = await mcp_server.call_tool(name, arguments)
10997
return create_jsonrpc_response(
11098
request_data.get("id"),
111-
{
112-
"content": [r.dict() if hasattr(r, 'dict') else r for r in result]
113-
}
99+
{"content": [r.dict() if hasattr(r, "dict") else r for r in result]},
114100
)
115101
except Exception as e:
116102
return create_jsonrpc_response(
117-
request_data.get("id"),
118-
error={
119-
"code": -1,
120-
"message": str(e)
121-
}
103+
request_data.get("id"), error={"code": -1, "message": str(e)}
122104
)
123105

106+
124107
@app.post("/mcp")
125-
async def mcp_endpoint(request: Request):
108+
async def mcp_endpoint(request: Request) -> Response:
126109
"""Main MCP endpoint that handles all JSON-RPC messages"""
127-
110+
128111
# Check Accept header
129112
accept_header = request.headers.get("accept", "")
130113
if "application/json" not in accept_header and "text/event-stream" not in accept_header:
131114
return JSONResponse(
132115
content={"error": "Accept header must include application/json or text/event-stream"},
133-
status_code=400
116+
status_code=400,
134117
)
135-
118+
136119
try:
137120
# Parse request body
138121
body = await request.body()
139122
if not body:
140123
return JSONResponse(content={"error": "Empty request body"}, status_code=400)
141-
124+
142125
data = json.loads(body)
143-
126+
144127
# Handle batched requests
145128
if isinstance(data, list):
146129
responses = []
147130
for item in data:
148131
response = await handle_jsonrpc_message(item)
149132
if response:
150133
responses.append(response)
151-
134+
152135
if len(responses) == 1:
153136
return JSONResponse(content=responses[0])
154137
else:
@@ -160,20 +143,21 @@ async def mcp_endpoint(request: Request):
160143
return JSONResponse(content=response)
161144
else:
162145
return Response(status_code=202) # Accepted with no body
163-
146+
164147
except json.JSONDecodeError:
165148
return JSONResponse(content={"error": "Invalid JSON"}, status_code=400)
166149
except Exception as e:
167150
return JSONResponse(content={"error": str(e)}, status_code=500)
168151

169-
async def handle_jsonrpc_message(message):
152+
153+
async def handle_jsonrpc_message(message: object) -> dict | None:
170154
"""Handle individual JSON-RPC message"""
171155
if not isinstance(message, dict):
172156
return None
173-
157+
174158
method = message.get("method")
175159
message_id = message.get("id")
176-
160+
177161
if method == "initialize":
178162
return await handle_initialize_request(message)
179163
elif method == "tools/list":
@@ -185,30 +169,26 @@ async def handle_jsonrpc_message(message):
185169
return None
186170
else:
187171
return create_jsonrpc_response(
188-
message_id,
189-
error={
190-
"code": -32601,
191-
"message": f"Method not found: {method}"
192-
}
172+
message_id, error={"code": -32601, "message": f"Method not found: {method}"}
193173
)
194174

175+
195176
@app.get("/mcp")
196-
async def mcp_sse_endpoint(request: Request):
177+
async def mcp_sse_endpoint(request: Request) -> Response:
197178
"""SSE endpoint for streaming responses"""
198-
179+
199180
# Check Accept header
200181
accept_header = request.headers.get("accept", "")
201182
if "text/event-stream" not in accept_header:
202183
return JSONResponse(
203-
content={"error": "Accept header must include text/event-stream"},
204-
status_code=405
184+
content={"error": "Accept header must include text/event-stream"}, status_code=405
205185
)
206-
186+
207187
# Generate connection ID
208188
connection_id = str(uuid.uuid4())
209189
active_connections[connection_id] = True
210-
211-
async def event_stream():
190+
191+
async def event_stream(): # type: ignore[no-untyped-def]
212192
try:
213193
while active_connections.get(connection_id, False):
214194
# For now, just keep the connection alive
@@ -219,41 +199,48 @@ async def event_stream():
219199
finally:
220200
if connection_id in active_connections:
221201
del active_connections[connection_id]
222-
202+
223203
return StreamingResponse(
224204
event_stream(),
225205
media_type="text/event-stream",
226206
headers={
227207
"Cache-Control": "no-cache",
228208
"Connection": "keep-alive",
229209
"Content-Type": "text/event-stream",
230-
}
210+
},
231211
)
232212

213+
233214
# Legacy endpoints for manual testing (keep these for now)
234215
@app.get("/")
235-
async def root():
236-
return JSONResponse(content={
237-
"jsonrpc": "2.0",
238-
"id": 1,
239-
"result": {
240-
"serverName": "StackHawk MCP",
241-
"serverVersion": __version__,
242-
"protocolVersion": "v1"
216+
async def root() -> JSONResponse:
217+
return JSONResponse(
218+
content={
219+
"jsonrpc": "2.0",
220+
"id": 1,
221+
"result": {
222+
"serverName": "StackHawk MCP",
223+
"serverVersion": __version__,
224+
"protocolVersion": "v1",
225+
},
243226
}
244-
})
227+
)
228+
245229

246230
@app.post("/")
247-
async def root_post():
248-
return JSONResponse(content={
249-
"jsonrpc": "2.0",
250-
"id": 1,
251-
"result": {
252-
"serverName": "StackHawk MCP",
253-
"serverVersion": __version__,
254-
"protocolVersion": "v1"
231+
async def root_post() -> JSONResponse:
232+
return JSONResponse(
233+
content={
234+
"jsonrpc": "2.0",
235+
"id": 1,
236+
"result": {
237+
"serverName": "StackHawk MCP",
238+
"serverVersion": __version__,
239+
"protocolVersion": "v1",
240+
},
255241
}
256-
})
242+
)
243+
257244

258245
if __name__ == "__main__":
259-
uvicorn.run("stackhawk_mcp.http_server:app", host="0.0.0.0", port=8080, reload=True)
246+
uvicorn.run("stackhawk_mcp.http_server:app", host="0.0.0.0", port=8080, reload=True)

0 commit comments

Comments
 (0)