Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,83 @@ async def test_metrics_collection(mock_async_openai, full_config):
total_counter.increment.assert_any_call(value=150)

latency_histogram.observe.assert_called_once()


@pytest.mark.asyncio
async def test_generate_response_strips_leading_think_block(
mock_async_openai,
minimal_config,
):
"""generate_response strips a leading <think>...</think> block from output_text."""
mock_response = MagicMock()
mock_response.output_text = "<think>internal reasoning</think>Hello, world!"
mock_response.output = None
mock_response.usage = None

mock_client = mock_async_openai.return_value
mock_client.responses.create.return_value = mock_response

lm = OpenAIResponsesLanguageModel(minimal_config)
content, _ = await lm.generate_response(
system_prompt="sys",
user_prompt="usr",
)

assert content == "Hello, world!"


@pytest.mark.asyncio
async def test_generate_response_strips_multiline_think_block(
mock_async_openai,
minimal_config,
):
"""generate_response strips a multi-line <think>...</think> block from output_text."""
mock_response = MagicMock()
mock_response.output_text = '<think>\nline one\nline two\n</think>\n{"answer": 42}'
mock_response.output = None
mock_response.usage = None

mock_client = mock_async_openai.return_value
mock_client.responses.create.return_value = mock_response

lm = OpenAIResponsesLanguageModel(minimal_config)
content, _ = await lm.generate_response(
system_prompt="sys",
user_prompt="usr",
)

assert content == '{"answer": 42}'


@pytest.mark.asyncio
async def test_generate_parsed_response_fallback_strips_think_and_parses(
mock_async_openai,
minimal_config,
):
"""generate_parsed_response falls back to stripping think blocks and parsing
output_text when output_parsed is None."""
from pydantic import BaseModel as PydanticBaseModel

class MyModel(PydanticBaseModel):
value: int

mock_response = MagicMock()
mock_response.output_parsed = None
mock_response.output_text = '<think>reasoning here</think>{"value": 7}'
mock_response.usage = None

mock_client = mock_async_openai.return_value
# generate_parsed_response uses responses.parse, not responses.create
mock_client.with_options.return_value.responses.parse = AsyncMock(
return_value=mock_response
)

lm = OpenAIResponsesLanguageModel(minimal_config)
result = await lm.generate_parsed_response(
output_format=MyModel,
system_prompt="sys",
user_prompt="usr",
)

assert result is not None
assert result.value == 7
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np
import pytest
import pytest_asyncio
from pydantic import BaseModel as PydanticBaseModel
from pydantic import ValidationError

from memmachine_server.common.data_types import ExternalServiceAPIError
from memmachine_server.common.episode_store import (
Expand Down Expand Up @@ -1003,3 +1005,113 @@ async def test_user_tags_preserved_after_ingestion_and_consolidation(
# "bugfix" and "decision" survive (each had < 2 entries)
assert "bugfix" in remaining_tags
assert "decision" in remaining_tags


def _make_validation_error() -> ValidationError:
"""Construct a real pydantic ValidationError by feeding bad data to a trivial model."""

class _StrictModel(PydanticBaseModel):
x: int

try:
_StrictModel.model_validate({"x": "not-an-int"})
except ValidationError as exc:
return exc
raise AssertionError("Expected ValidationError was not raised") # pragma: no cover


@pytest.mark.asyncio
async def test_validation_error_marks_message_as_ingested(
ingestion_service: IngestionService,
semantic_storage: SemanticStorage,
episode_storage: EpisodeStorage,
monkeypatch,
):
"""A pydantic.ValidationError from llm_feature_update causes the message to be
marked as ingested so it is not re-queued on the next cycle."""
message_id = await add_history(
episode_storage,
content="message that triggers a ValidationError",
)
await semantic_storage.add_history_to_set(
set_id="user-valderr",
history_id=message_id,
)

validation_error = _make_validation_error()

async def mock_validation_error(*args, **kwargs):
raise validation_error

monkeypatch.setattr(
"memmachine_server.semantic_memory.semantic_ingestion.llm_feature_update",
mock_validation_error,
)

await ingestion_service._process_single_set("user-valderr")

# Message must NOT appear in the not-yet-ingested list
pending = await _collect(
semantic_storage.get_history_messages(
set_ids=["user-valderr"],
is_ingested=False,
)
)
assert message_id not in pending

# Message MUST appear in the ingested list
ingested = await _collect(
semantic_storage.get_history_messages(
set_ids=["user-valderr"],
is_ingested=True,
)
)
assert message_id in ingested


@pytest.mark.asyncio
async def test_generic_exception_does_not_mark_ingested(
semantic_storage: SemanticStorage,
episode_storage: EpisodeStorage,
resource_retriever: MockResourceRetriever,
monkeypatch,
):
"""A generic RuntimeError from llm_feature_update must NOT mark the message as
ingested — guards against future over-broad widening of the ValidationError catch."""
message_id = await add_history(
episode_storage,
content="message that triggers a generic RuntimeError",
)
await semantic_storage.add_history_to_set(
set_id="user-rterr",
history_id=message_id,
)

async def mock_runtime_error(*args, **kwargs):
raise RuntimeError("unexpected failure")

monkeypatch.setattr(
"memmachine_server.semantic_memory.semantic_ingestion.llm_feature_update",
mock_runtime_error,
)

ingestion_service = IngestionService(
IngestionService.Params(
semantic_storage=semantic_storage,
history_store=episode_storage,
resource_retriever=resource_retriever.get_resources,
consolidated_threshold=2,
debug_fail_loudly=False,
)
)

await ingestion_service._process_single_set("user-rterr")

# Message must remain in the not-yet-ingested list (was NOT marked ingested)
pending = await _collect(
semantic_storage.get_history_messages(
set_ids=["user-rterr"],
is_ingested=False,
)
)
assert message_id in pending
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import logging
import re
from typing import Any, TypeVar, cast
from uuid import uuid4

Expand All @@ -13,7 +14,7 @@
ResponseInputParam,
ToolParam,
)
from pydantic import BaseModel, Field, InstanceOf
from pydantic import BaseModel, Field, InstanceOf, TypeAdapter

from memmachine_server.common.data_types import ExternalServiceAPIError
from memmachine_server.common.metrics_factory import MetricsFactory, OperationTracker
Expand All @@ -24,6 +25,12 @@

logger = logging.getLogger(__name__)

_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the response API is only supported by openAI, which does not return the "think" block.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the issue is using third-party models with the OpenAI API. OpenAI does a good job of avoiding this problem.

By baking the solution into the API wrapper instead of a sanitization layer, it can ruin intended outputs from a model that follows instructions well, and we have to play whack-a-mole on every single API which a model that does not follow instructions well may use.



def _strip_think_blocks(text: str) -> str:
return _THINK_BLOCK_RE.sub("", text).strip()
Comment thread
sscargal marked this conversation as resolved.


class OpenAIResponsesLanguageModelParams(BaseModel):
"""
Expand Down Expand Up @@ -173,7 +180,18 @@ async def generate_parsed_response(

self._collect_usage_metrics(response)

return response.output_parsed
if response.output_parsed is not None:
return response.output_parsed

raw_text = response.output_text or ""
if raw_text:
raw = _strip_think_blocks(raw_text)
if raw:
return TypeAdapter(output_format).validate_python(
json_repair.loads(raw)
)

return None

async def generate_response(
self,
Expand Down Expand Up @@ -310,7 +328,7 @@ async def _generate_response( # noqa: C901
self._collect_usage_metrics(response)

if response.output is None:
return (response.output_text or "", [], 0, 0)
return (_strip_think_blocks(response.output_text or ""), [], 0, 0)

function_calls_arguments: list[dict[str, Any]] = []
try:
Expand All @@ -333,7 +351,7 @@ async def _generate_response( # noqa: C901
) from e

return (
response.output_text or "",
_strip_think_blocks(response.output_text or ""),
function_calls_arguments,
response.usage.input_tokens if response.usage else 0,
response.usage.output_tokens if response.usage else 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from itertools import chain

import numpy as np
from pydantic import BaseModel, Field, InstanceOf, TypeAdapter
from pydantic import BaseModel, Field, InstanceOf, TypeAdapter, ValidationError

from memmachine_server.common.embedder import Embedder
from memmachine_server.common.episode_store import Episode, EpisodeIdT, EpisodeStorage
Expand Down Expand Up @@ -216,6 +216,21 @@ async def process_semantic_type(
mark_messages.append(message.uid)
continue

if isinstance(err, ValidationError):
logger.warning(
"Skipping message %s for semantic type %s due to "
"non-retryable parse error (%s): %s",
message.uid,
semantic_category.name,
type(err).__name__,
err,
)
Comment thread
sscargal marked this conversation as resolved.
if self._debug_fail_loudly:
raise
if message.uid not in mark_messages:
mark_messages.append(message.uid)
continue

logger.exception(
"Failed to process message %s for semantic type %s",
message.uid,
Expand Down
Loading