Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
test: add regression coverage for pop_record truncation
  • Loading branch information
CompilError-bts committed Mar 31, 2026
commit a3aed9b2a813d3f54ea0dea762d041f07dccd50a
29 changes: 28 additions & 1 deletion astrbot/core/provider/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ async def text_chat_stream(
raise NotImplementedError()

async def pop_record(self, context: list) -> None:
"""Pop earliest non-system records while preserving tool-call pairing."""
"""弹出最早的非 system 记录,同时保持 tool_calls 与 tool 配对完整。"""
Comment thread
sourcery-ai[bot] marked this conversation as resolved.

def _has_tool_calls(message: dict) -> bool:
return bool(message.get("tool_calls"))
Expand Down Expand Up @@ -199,8 +199,35 @@ def _pop_earliest_unit() -> int:
del context[start_idx : end_idx + 1]
return removed_count

def _peek_earliest_unit_count() -> int:
start_idx = _first_non_system_index()
if start_idx is None:
return 0

record = context[start_idx]
role = record.get("role")
end_idx = start_idx
if role == "assistant" and _has_tool_calls(record):
while end_idx + 1 < len(context) and (
context[end_idx + 1].get("role") == "tool"
):
end_idx += 1
elif role == "tool":
while end_idx + 1 < len(context) and (
context[end_idx + 1].get("role") == "tool"
):
end_idx += 1
return end_idx - start_idx + 1

removed = 0
while removed < 2:
next_unit_count = _peek_earliest_unit_count()
if next_unit_count == 0:
break
# Keep behavior close to the old "pop around 2 records" strategy,
# while still preserving tool-call atomicity.
if removed > 0 and removed + next_unit_count > 3:
break
removed_now = _pop_earliest_unit()
if removed_now == 0:
break
Expand Down
62 changes: 62 additions & 0 deletions tests/test_openai_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,68 @@ async def test_pop_record_removes_leading_orphan_tool_messages():
await provider.terminate()


@pytest.mark.asyncio
async def test_pop_record_normal_messages_no_regression():
provider = _make_provider()
try:
context = [
{"role": "system", "content": "system"},
{"role": "user", "content": "user1"},
{"role": "assistant", "content": "assistant1"},
{"role": "user", "content": "user2"},
{"role": "assistant", "content": "assistant2"},
]

await provider.pop_record(context)

assert context == [
{"role": "system", "content": "system"},
{"role": "user", "content": "user2"},
{"role": "assistant", "content": "assistant2"},
]
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_pop_record_assistant_with_multiple_tool_calls():
provider = _make_provider()
try:
context = [
{"role": "system", "content": "system"},
{
"role": "assistant",
"tool_calls": [{"id": "call_1"}, {"id": "call_2"}],
"content": None,
},
{"role": "tool", "tool_call_id": "call_1", "content": "result1"},
{"role": "tool", "tool_call_id": "call_2", "content": "result2"},
{"role": "user", "content": "keep me"},
]

await provider.pop_record(context)

assert context == [
{"role": "system", "content": "system"},
{"role": "user", "content": "keep me"},
]
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_pop_record_only_system_messages():
provider = _make_provider()
try:
context = [{"role": "system", "content": "system"}]

await provider.pop_record(context)

assert context == [{"role": "system", "content": "system"}]
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_groq_payload_drops_reasoning_content_from_assistant_history():
provider = _make_groq_provider()
Expand Down