diff --git a/packages/server/server_tests/memmachine_server/common/language_model/test_openai_responses_language_model.py b/packages/server/server_tests/memmachine_server/common/language_model/test_openai_responses_language_model.py index 0f8ca7a31..5e46682ed 100644 --- a/packages/server/server_tests/memmachine_server/common/language_model/test_openai_responses_language_model.py +++ b/packages/server/server_tests/memmachine_server/common/language_model/test_openai_responses_language_model.py @@ -448,3 +448,83 @@ async def test_metrics_collection(mock_async_openai, full_config): total_counter.increment.assert_any_call(value=150) latency_histogram.observe.assert_called_once() + + +@pytest.mark.asyncio +async def test_generate_response_strips_leading_think_block( + mock_async_openai, + minimal_config, +): + """generate_response strips a leading ... block from output_text.""" + mock_response = MagicMock() + mock_response.output_text = "internal reasoningHello, world!" + mock_response.output = None + mock_response.usage = None + + mock_client = mock_async_openai.return_value + mock_client.responses.create.return_value = mock_response + + lm = OpenAIResponsesLanguageModel(minimal_config) + content, _ = await lm.generate_response( + system_prompt="sys", + user_prompt="usr", + ) + + assert content == "Hello, world!" + + +@pytest.mark.asyncio +async def test_generate_response_strips_multiline_think_block( + mock_async_openai, + minimal_config, +): + """generate_response strips a multi-line ... block from output_text.""" + mock_response = MagicMock() + mock_response.output_text = '\nline one\nline two\n\n{"answer": 42}' + mock_response.output = None + mock_response.usage = None + + mock_client = mock_async_openai.return_value + mock_client.responses.create.return_value = mock_response + + lm = OpenAIResponsesLanguageModel(minimal_config) + content, _ = await lm.generate_response( + system_prompt="sys", + user_prompt="usr", + ) + + assert content == '{"answer": 42}' + + +@pytest.mark.asyncio +async def test_generate_parsed_response_fallback_strips_think_and_parses( + mock_async_openai, + minimal_config, +): + """generate_parsed_response falls back to stripping think blocks and parsing + output_text when output_parsed is None.""" + from pydantic import BaseModel as PydanticBaseModel + + class MyModel(PydanticBaseModel): + value: int + + mock_response = MagicMock() + mock_response.output_parsed = None + mock_response.output_text = 'reasoning here{"value": 7}' + mock_response.usage = None + + mock_client = mock_async_openai.return_value + # generate_parsed_response uses responses.parse, not responses.create + mock_client.with_options.return_value.responses.parse = AsyncMock( + return_value=mock_response + ) + + lm = OpenAIResponsesLanguageModel(minimal_config) + result = await lm.generate_parsed_response( + output_format=MyModel, + system_prompt="sys", + user_prompt="usr", + ) + + assert result is not None + assert result.value == 7 diff --git a/packages/server/server_tests/memmachine_server/semantic_memory/test_semantic_ingestion.py b/packages/server/server_tests/memmachine_server/semantic_memory/test_semantic_ingestion.py index 0deea017a..7fe0b2b93 100644 --- a/packages/server/server_tests/memmachine_server/semantic_memory/test_semantic_ingestion.py +++ b/packages/server/server_tests/memmachine_server/semantic_memory/test_semantic_ingestion.py @@ -6,6 +6,8 @@ import numpy as np import pytest import pytest_asyncio +from pydantic import BaseModel as PydanticBaseModel +from pydantic import ValidationError from memmachine_server.common.data_types import ExternalServiceAPIError from memmachine_server.common.episode_store import ( @@ -1003,3 +1005,113 @@ async def test_user_tags_preserved_after_ingestion_and_consolidation( # "bugfix" and "decision" survive (each had < 2 entries) assert "bugfix" in remaining_tags assert "decision" in remaining_tags + + +def _make_validation_error() -> ValidationError: + """Construct a real pydantic ValidationError by feeding bad data to a trivial model.""" + + class _StrictModel(PydanticBaseModel): + x: int + + try: + _StrictModel.model_validate({"x": "not-an-int"}) + except ValidationError as exc: + return exc + raise AssertionError("Expected ValidationError was not raised") # pragma: no cover + + +@pytest.mark.asyncio +async def test_validation_error_marks_message_as_ingested( + ingestion_service: IngestionService, + semantic_storage: SemanticStorage, + episode_storage: EpisodeStorage, + monkeypatch, +): + """A pydantic.ValidationError from llm_feature_update causes the message to be + marked as ingested so it is not re-queued on the next cycle.""" + message_id = await add_history( + episode_storage, + content="message that triggers a ValidationError", + ) + await semantic_storage.add_history_to_set( + set_id="user-valderr", + history_id=message_id, + ) + + validation_error = _make_validation_error() + + async def mock_validation_error(*args, **kwargs): + raise validation_error + + monkeypatch.setattr( + "memmachine_server.semantic_memory.semantic_ingestion.llm_feature_update", + mock_validation_error, + ) + + await ingestion_service._process_single_set("user-valderr") + + # Message must NOT appear in the not-yet-ingested list + pending = await _collect( + semantic_storage.get_history_messages( + set_ids=["user-valderr"], + is_ingested=False, + ) + ) + assert message_id not in pending + + # Message MUST appear in the ingested list + ingested = await _collect( + semantic_storage.get_history_messages( + set_ids=["user-valderr"], + is_ingested=True, + ) + ) + assert message_id in ingested + + +@pytest.mark.asyncio +async def test_generic_exception_does_not_mark_ingested( + semantic_storage: SemanticStorage, + episode_storage: EpisodeStorage, + resource_retriever: MockResourceRetriever, + monkeypatch, +): + """A generic RuntimeError from llm_feature_update must NOT mark the message as + ingested — guards against future over-broad widening of the ValidationError catch.""" + message_id = await add_history( + episode_storage, + content="message that triggers a generic RuntimeError", + ) + await semantic_storage.add_history_to_set( + set_id="user-rterr", + history_id=message_id, + ) + + async def mock_runtime_error(*args, **kwargs): + raise RuntimeError("unexpected failure") + + monkeypatch.setattr( + "memmachine_server.semantic_memory.semantic_ingestion.llm_feature_update", + mock_runtime_error, + ) + + ingestion_service = IngestionService( + IngestionService.Params( + semantic_storage=semantic_storage, + history_store=episode_storage, + resource_retriever=resource_retriever.get_resources, + consolidated_threshold=2, + debug_fail_loudly=False, + ) + ) + + await ingestion_service._process_single_set("user-rterr") + + # Message must remain in the not-yet-ingested list (was NOT marked ingested) + pending = await _collect( + semantic_storage.get_history_messages( + set_ids=["user-rterr"], + is_ingested=False, + ) + ) + assert message_id in pending diff --git a/packages/server/src/memmachine_server/common/language_model/openai_responses_language_model.py b/packages/server/src/memmachine_server/common/language_model/openai_responses_language_model.py index 51727cbcb..41e7c7ef8 100644 --- a/packages/server/src/memmachine_server/common/language_model/openai_responses_language_model.py +++ b/packages/server/src/memmachine_server/common/language_model/openai_responses_language_model.py @@ -2,6 +2,7 @@ import asyncio import logging +import re from typing import Any, TypeVar, cast from uuid import uuid4 @@ -13,7 +14,7 @@ ResponseInputParam, ToolParam, ) -from pydantic import BaseModel, Field, InstanceOf +from pydantic import BaseModel, Field, InstanceOf, TypeAdapter from memmachine_server.common.data_types import ExternalServiceAPIError from memmachine_server.common.metrics_factory import MetricsFactory, OperationTracker @@ -24,6 +25,12 @@ logger = logging.getLogger(__name__) +_THINK_BLOCK_RE = re.compile(r".*?", re.DOTALL) + + +def _strip_think_blocks(text: str) -> str: + return _THINK_BLOCK_RE.sub("", text).strip() + class OpenAIResponsesLanguageModelParams(BaseModel): """ @@ -173,7 +180,18 @@ async def generate_parsed_response( self._collect_usage_metrics(response) - return response.output_parsed + if response.output_parsed is not None: + return response.output_parsed + + raw_text = response.output_text or "" + if raw_text: + raw = _strip_think_blocks(raw_text) + if raw: + return TypeAdapter(output_format).validate_python( + json_repair.loads(raw) + ) + + return None async def generate_response( self, @@ -310,7 +328,7 @@ async def _generate_response( # noqa: C901 self._collect_usage_metrics(response) if response.output is None: - return (response.output_text or "", [], 0, 0) + return (_strip_think_blocks(response.output_text or ""), [], 0, 0) function_calls_arguments: list[dict[str, Any]] = [] try: @@ -333,7 +351,7 @@ async def _generate_response( # noqa: C901 ) from e return ( - response.output_text or "", + _strip_think_blocks(response.output_text or ""), function_calls_arguments, response.usage.input_tokens if response.usage else 0, response.usage.output_tokens if response.usage else 0, diff --git a/packages/server/src/memmachine_server/semantic_memory/semantic_ingestion.py b/packages/server/src/memmachine_server/semantic_memory/semantic_ingestion.py index f49f491fe..697871fb8 100644 --- a/packages/server/src/memmachine_server/semantic_memory/semantic_ingestion.py +++ b/packages/server/src/memmachine_server/semantic_memory/semantic_ingestion.py @@ -7,7 +7,7 @@ from itertools import chain import numpy as np -from pydantic import BaseModel, Field, InstanceOf, TypeAdapter +from pydantic import BaseModel, Field, InstanceOf, TypeAdapter, ValidationError from memmachine_server.common.embedder import Embedder from memmachine_server.common.episode_store import Episode, EpisodeIdT, EpisodeStorage @@ -216,6 +216,21 @@ async def process_semantic_type( mark_messages.append(message.uid) continue + if isinstance(err, ValidationError): + logger.warning( + "Skipping message %s for semantic type %s due to " + "non-retryable parse error (%s): %s", + message.uid, + semantic_category.name, + type(err).__name__, + err, + ) + if self._debug_fail_loudly: + raise + if message.uid not in mark_messages: + mark_messages.append(message.uid) + continue + logger.exception( "Failed to process message %s for semantic type %s", message.uid,