Skip to content

Commit ae1f158

Browse files
committed
feat: wire clustering into ingestion pipeline and API
1 parent 1a5e28b commit ae1f158

28 files changed

Lines changed: 2074 additions & 502 deletions

File tree

.github/workflows/installation-test.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ jobs:
7373
shell: bash
7474
run: |
7575
set -eo pipefail
76+
python -m pip install *common*.whl
7677
whl_name=$(ls *server*.whl)
7778
python -m pip install --find-links . "$whl_name"
7879
@@ -105,6 +106,7 @@ jobs:
105106
run: |
106107
set -eo pipefail
107108
export PYTHONUTF8=1
109+
python -m pip install *common*.whl
108110
whl_name=$(ls *server*.whl)
109111
python -m pip install --find-links . "$whl_name"
110112
@@ -124,6 +126,7 @@ jobs:
124126
shell: bash
125127
run: |
126128
set -eo pipefail
129+
python -m pip install *common*.whl
127130
whl_name=$(ls *server*.whl)
128131
python -m pip install --find-links . "$whl_name"
129132

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,4 @@ site/
223223

224224
# Ignore documentation generated by extensions
225225
.spelling
226+
.worktrees/

packages/client/client_tests/test_config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,8 @@ def test_get_semantic_memory_config(self, config, mock_client):
412412
"database": "postgres-db",
413413
"llm_model": "gpt-4",
414414
"embedding_model": "openai-embedder",
415+
"cluster_similarity_threshold": 0.4,
416+
"cluster_max_time_gap_seconds": 3600,
415417
}
416418
)
417419
result = config.get_semantic_memory_config()
@@ -420,6 +422,8 @@ def test_get_semantic_memory_config(self, config, mock_client):
420422
assert result.database == "postgres-db"
421423
assert result.llm_model == "gpt-4"
422424
assert result.embedding_model == "openai-embedder"
425+
assert result.cluster_similarity_threshold == 0.4
426+
assert result.cluster_max_time_gap_seconds == 3600
423427
mock_client.request.assert_called_once_with(
424428
"GET",
425429
"http://localhost:8080/api/v2/config/memory/semantic",
@@ -442,6 +446,9 @@ def test_update_semantic_memory_config(self, config, mock_client):
442446
embedding_model="openai-embedder",
443447
ingestion_trigger_messages=10,
444448
ingestion_trigger_age_seconds=3600,
449+
cluster_idle_ttl_seconds=1800,
450+
cluster_similarity_threshold=0.45,
451+
cluster_max_time_gap_seconds=7200,
445452
)
446453
assert isinstance(result, UpdateMemoryConfigResponse)
447454
assert result.success is True
@@ -457,6 +464,9 @@ def test_update_semantic_memory_config(self, config, mock_client):
457464
assert body["embedding_model"] == "openai-embedder"
458465
assert body["ingestion_trigger_messages"] == 10
459466
assert body["ingestion_trigger_age_seconds"] == 3600
467+
assert body["cluster_idle_ttl_seconds"] == 1800
468+
assert body["cluster_similarity_threshold"] == 0.45
469+
assert body["cluster_max_time_gap_seconds"] == 7200
460470

461471
def test_update_semantic_memory_config_partial(self, config, mock_client):
462472
mock_client.request.return_value = _mock_response(

packages/client/src/memmachine_client/config.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@
3131
logger = logging.getLogger(__name__)
3232

3333

34+
# Store `value` under `key` in `target` unless `value` is None.
# `target` is mutated in place and nothing is returned; a None value leaves
# `target` untouched, so the key is simply absent rather than mapped to None.
def _set_if_not_none(target: dict[str, Any], key: str, value: object) -> None:
35+
if value is not None:
36+
target[key] = value
37+
38+
3439
class Config:
3540
"""
3641
Configuration interface for managing MemMachine server settings.
@@ -438,8 +443,12 @@ def update_semantic_memory_config(
438443
database: str | None = None,
439444
llm_model: str | None = None,
440445
embedding_model: str | None = None,
446+
cluster_split_reranker: str | None = None,
441447
ingestion_trigger_messages: int | None = None,
442448
ingestion_trigger_age_seconds: int | None = None,
449+
cluster_idle_ttl_seconds: int | None = None,
450+
cluster_similarity_threshold: float | None = None,
451+
cluster_max_time_gap_seconds: int | None = None,
443452
timeout: int | None = None,
444453
) -> UpdateMemoryConfigResponse:
445454
"""
@@ -450,8 +459,12 @@ def update_semantic_memory_config(
450459
database: Name of the database to use for semantic memory
451460
llm_model: Name of the language model to use for feature extraction
452461
embedding_model: Name of the embedder to use for semantic similarity
462+
cluster_split_reranker: Reranker ID used for cluster split scoring
453463
ingestion_trigger_messages: Number of messages before triggering ingestion
454464
ingestion_trigger_age_seconds: Age threshold in seconds for triggering ingestion
465+
cluster_idle_ttl_seconds: Idle TTL in seconds for empty cluster GC
466+
cluster_similarity_threshold: Cosine similarity threshold for clustering
467+
cluster_max_time_gap_seconds: Maximum time gap in seconds for clustering
455468
timeout: Request timeout in seconds (uses client default if not provided)
456469
457470
Returns:
@@ -463,14 +476,43 @@ def update_semantic_memory_config(
463476
464477
"""
465478
self._check_closed()
466-
spec = UpdateSemanticMemorySpec(
467-
enabled=enabled,
468-
database=database,
469-
llm_model=llm_model,
470-
embedding_model=embedding_model,
471-
ingestion_trigger_messages=ingestion_trigger_messages,
472-
ingestion_trigger_age_seconds=ingestion_trigger_age_seconds,
479+
spec_data: dict[str, Any] = {}
480+
_set_if_not_none(spec_data, "enabled", enabled)
481+
_set_if_not_none(spec_data, "database", database)
482+
_set_if_not_none(spec_data, "llm_model", llm_model)
483+
_set_if_not_none(spec_data, "embedding_model", embedding_model)
484+
_set_if_not_none(
485+
spec_data,
486+
"cluster_split_reranker",
487+
cluster_split_reranker,
488+
)
489+
_set_if_not_none(
490+
spec_data,
491+
"ingestion_trigger_messages",
492+
ingestion_trigger_messages,
473493
)
494+
_set_if_not_none(
495+
spec_data,
496+
"ingestion_trigger_age_seconds",
497+
ingestion_trigger_age_seconds,
498+
)
499+
_set_if_not_none(
500+
spec_data,
501+
"cluster_idle_ttl_seconds",
502+
cluster_idle_ttl_seconds,
503+
)
504+
_set_if_not_none(
505+
spec_data,
506+
"cluster_similarity_threshold",
507+
cluster_similarity_threshold,
508+
)
509+
_set_if_not_none(
510+
spec_data,
511+
"cluster_max_time_gap_seconds",
512+
cluster_max_time_gap_seconds,
513+
)
514+
515+
spec = UpdateSemanticMemorySpec(**spec_data)
474516
payload = spec.model_dump(exclude_none=True)
475517
try:
476518
response = self.client.request(

packages/common/src/memmachine_common/api/config_spec.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,30 @@ class SemanticMemoryConfigResponse(BaseModel):
137137
str | None,
138138
Field(default=None, description=SpecDoc.SEMANTIC_EMBEDDING_MODEL),
139139
]
140+
cluster_split_reranker: Annotated[
141+
str | None,
142+
Field(default=None, description=SpecDoc.SEMANTIC_CLUSTER_SPLIT_RERANKER),
143+
]
144+
cluster_similarity_threshold: Annotated[
145+
float | None,
146+
Field(default=None, description=SpecDoc.SEMANTIC_CLUSTER_SIMILARITY_THRESHOLD),
147+
]
148+
cluster_max_time_gap_seconds: Annotated[
149+
int | None,
150+
Field(default=None, description=SpecDoc.SEMANTIC_CLUSTER_MAX_TIME_GAP),
151+
]
152+
ingestion_trigger_messages: Annotated[
153+
int | None,
154+
Field(default=None, description=SpecDoc.SEMANTIC_INGESTION_MESSAGES),
155+
]
156+
ingestion_trigger_age_seconds: Annotated[
157+
int | None,
158+
Field(default=None, description=SpecDoc.SEMANTIC_INGESTION_AGE),
159+
]
160+
cluster_idle_ttl_seconds: Annotated[
161+
int | None,
162+
Field(default=None, description=SpecDoc.SEMANTIC_CLUSTER_IDLE_TTL),
163+
]
140164

141165

142166
class GetConfigResponse(BaseModel):
@@ -439,6 +463,10 @@ class UpdateSemanticMemorySpec(BaseModel):
439463
str | None,
440464
Field(default=None, description=SpecDoc.SEMANTIC_EMBEDDING_MODEL),
441465
]
466+
cluster_split_reranker: Annotated[
467+
str | None,
468+
Field(default=None, description=SpecDoc.SEMANTIC_CLUSTER_SPLIT_RERANKER),
469+
]
442470
ingestion_trigger_messages: Annotated[
443471
int | None,
444472
Field(default=None, gt=0, description=SpecDoc.SEMANTIC_INGESTION_MESSAGES),
@@ -447,6 +475,23 @@ class UpdateSemanticMemorySpec(BaseModel):
447475
int | None,
448476
Field(default=None, gt=0, description=SpecDoc.SEMANTIC_INGESTION_AGE),
449477
]
478+
cluster_idle_ttl_seconds: Annotated[
479+
int | None,
480+
Field(default=None, gt=0, description=SpecDoc.SEMANTIC_CLUSTER_IDLE_TTL),
481+
]
482+
cluster_similarity_threshold: Annotated[
483+
float | None,
484+
Field(
485+
default=None,
486+
ge=0.0,
487+
le=1.0,
488+
description=SpecDoc.SEMANTIC_CLUSTER_SIMILARITY_THRESHOLD,
489+
),
490+
]
491+
cluster_max_time_gap_seconds: Annotated[
492+
int | None,
493+
Field(default=None, gt=0, description=SpecDoc.SEMANTIC_CLUSTER_MAX_TIME_GAP),
494+
]
450495

451496

452497
class UpdateMemoryConfigSpec(BaseModel):

packages/common/src/memmachine_common/api/doc.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -456,11 +456,28 @@ class SpecDoc:
456456
Must reference an embedder configured in the resources section."""
457457

458458
SEMANTIC_INGESTION_MESSAGES = """
459-
The number of uningested messages that triggers an ingestion cycle."""
459+
The number of pending messages in a semantic cluster that triggers ingestion
460+
for that cluster."""
460461

461462
SEMANTIC_INGESTION_AGE = """
462-
The maximum age (in seconds) of uningested messages before
463-
triggering an ingestion cycle."""
463+
The maximum age (in seconds) of the oldest pending message in a semantic
464+
cluster before triggering ingestion for that cluster."""
465+
466+
SEMANTIC_CLUSTER_IDLE_TTL = """
467+
The idle time-to-live (in seconds) for empty semantic clusters before they are
468+
garbage collected. Set to null to disable cluster GC."""
469+
470+
SEMANTIC_CLUSTER_SIMILARITY_THRESHOLD = """
471+
Cosine similarity threshold for grouping messages into semantic clusters.
472+
Higher values produce tighter clusters; lower values merge more messages."""
473+
474+
SEMANTIC_CLUSTER_MAX_TIME_GAP = """
475+
Maximum time gap (in seconds) allowed between messages in the same cluster.
476+
Set to null to disable time-based cluster splitting."""
477+
478+
SEMANTIC_CLUSTER_SPLIT_RERANKER = """
479+
Reranker ID used to score cluster split boundaries during ingestion.
480+
Defaults to the long-term memory reranker when unset."""
464481

465482
UPDATE_EPISODIC_MEMORY = """
466483
Partial update for episodic memory configuration. Only supplied
@@ -1118,8 +1135,9 @@ class RouterDoc:
11181135
- database: The database resource to use for storing semantic memories
11191136
- llm_model: The language model to use for feature extraction
11201137
- embedding_model: The embedder to use for semantic similarity
1121-
- ingestion_trigger_messages: Number of messages before triggering ingestion
1122-
- ingestion_trigger_age_seconds: Age threshold for triggering ingestion
1138+
- ingestion_trigger_messages: Pending messages in a cluster before ingestion
1139+
- ingestion_trigger_age_seconds: Oldest pending message age before ingestion
1140+
- cluster_idle_ttl_seconds: Idle TTL before empty clusters are garbage collected
11231141
"""
11241142

11251143
# --- Semantic Set Type API Router Docs ---

packages/server/server_tests/memmachine_server/common/configuration/test_semantic_conf.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,23 @@ def test_semantic_config_timedelta_float():
3131
conf = SemanticMemoryConf(**raw_conf)
3232
assert conf.ingestion_trigger_messages == 24
3333
assert conf.ingestion_trigger_age == timedelta(minutes=2, milliseconds=500)
34+
35+
36+
# Checks that SemanticMemoryConf accepts the cluster settings and exposes them
# as typed values: the threshold as a float and the two durations as timedeltas.
def test_semantic_config_cluster_settings():
37+
raw_conf: dict[str, Any] = {
38+
"database": "database",
39+
"llm_model": "llm",
40+
"embedding_model": "embedding",
41+
"ingestion_trigger_messages": 5,
42+
"ingestion_trigger_age": "PT1M",
43+
"cluster_similarity_threshold": 0.45,
44+
# Plain number: asserted below to come back as timedelta(seconds=300).
"cluster_max_time_gap": 300,
45+
# Duration string: asserted below to come back as timedelta(hours=2).
"cluster_idle_ttl": "PT2H",
46+
"config_database": "database",
47+
}
48+
49+
conf = SemanticMemoryConf(**raw_conf)
50+
51+
assert conf.cluster_similarity_threshold == 0.45
52+
assert conf.cluster_max_time_gap == timedelta(seconds=300)
53+
assert conf.cluster_idle_ttl == timedelta(hours=2)

packages/server/server_tests/memmachine_server/common/test_utils.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from typing import Any, cast
23

34
import pytest
@@ -7,6 +8,7 @@
78
chunk_text_balanced,
89
cluster_texts,
910
extract_sentences,
11+
merge_async_iterators,
1012
unflatten_like,
1113
)
1214

@@ -229,3 +231,88 @@ def test_extract_sentences():
229231
}
230232
for sentence in expected_sentences:
231233
assert any(sentence in s for s in sentences)
234+
235+
236+
# Test helper: async generator that yields each entry of `values` only after
# the event paired with it has been set, letting a test control emission order.
async def _evented_iterator(values: list[str], events: list[asyncio.Event]):
237+
# strict=True makes zip raise ValueError if the two lists differ in length.
for value, event in zip(values, events, strict=True):
238+
await event.wait()
239+
yield value
240+
241+
242+
# Test helper: async generator that waits for `event` and then raises `error`
# without ever yielding a value.
async def _raising_iterator(event: asyncio.Event, error: Exception):
243+
await event.wait()
244+
raise error
245+
# Never executed; the presence of `yield` is what makes this function an
# async generator instead of a plain coroutine.
if False:
246+
yield None
247+
248+
249+
# Test helper: async generator that blocks on its first step and records, via
# `cancel_event`, that its cleanup code ran.
async def _blocking_iterator(cancel_event: asyncio.Event):
250+
try:
251+
# A fresh Event that nothing else references is never set, so this wait
# only ends when the generator is cancelled or closed.
await asyncio.Event().wait()
252+
yield None
253+
finally:
254+
# Signals to the test that this generator was torn down.
cancel_event.set()
255+
256+
257+
# Merging an empty list of iterators is expected to produce no items.
@pytest.mark.asyncio
258+
async def test_merge_async_iterators_empty():
259+
results = [item async for item in merge_async_iterators([])]
260+
assert results == []
261+
262+
263+
# Expects the merged iterator to emit items in the order the sources become
# ready (b1, a1, b2, a2), not in the order the sources were passed in.
@pytest.mark.asyncio
264+
async def test_merge_async_iterators_interleaves():
265+
events_a = [asyncio.Event(), asyncio.Event()]
266+
events_b = [asyncio.Event(), asyncio.Event()]
267+
merged = merge_async_iterators(
268+
[
269+
_evented_iterator(["a1", "a2"], events_a),
270+
_evented_iterator(["b1", "b2"], events_b),
271+
]
272+
)
273+
274+
# Release one item at a time; each wait_for fails the test after 1 second
# instead of hanging if the expected item never arrives.
events_b[0].set()
275+
assert await asyncio.wait_for(anext(merged), timeout=1) == "b1"
276+
events_a[0].set()
277+
assert await asyncio.wait_for(anext(merged), timeout=1) == "a1"
278+
events_b[1].set()
279+
assert await asyncio.wait_for(anext(merged), timeout=1) == "b2"
280+
events_a[1].set()
281+
assert await asyncio.wait_for(anext(merged), timeout=1) == "a2"
282+
283+
# With both sources exhausted, the merged iterator must end as well.
with pytest.raises(StopAsyncIteration):
284+
await asyncio.wait_for(anext(merged), timeout=1)
285+
286+
287+
# Expects an exception raised inside one source to surface to the consumer of
# the merged iterator, after items from the other source were delivered.
@pytest.mark.asyncio
288+
async def test_merge_async_iterators_propagates_error():
289+
ok_events = [asyncio.Event()]
290+
error_event = asyncio.Event()
291+
merged = merge_async_iterators(
292+
[
293+
_evented_iterator(["ok"], ok_events),
294+
_raising_iterator(error_event, ValueError("boom")),
295+
]
296+
)
297+
298+
# The healthy source delivers its item first.
ok_events[0].set()
299+
assert await asyncio.wait_for(anext(merged), timeout=1) == "ok"
300+
301+
# Only now is the failing source allowed to raise.
error_event.set()
302+
with pytest.raises(ValueError, match="boom"):
303+
await asyncio.wait_for(anext(merged), timeout=1)
304+
305+
306+
# Expects that closing the merged iterator early tears down sources that are
# still blocked, observed through `_blocking_iterator` setting `cancel_event`.
@pytest.mark.asyncio
307+
async def test_merge_async_iterators_cancels_producers():
308+
cancel_event = asyncio.Event()
309+
start_event = asyncio.Event()
310+
merged = merge_async_iterators(
311+
[_evented_iterator(["first"], [start_event]), _blocking_iterator(cancel_event)]
312+
)
313+
314+
start_event.set()
315+
assert await asyncio.wait_for(anext(merged), timeout=1) == "first"
316+
# Close while the blocking source is still waiting.
await merged.aclose()
317+
318+
# Fails with a timeout if the blocked source's cleanup never ran.
await asyncio.wait_for(cancel_event.wait(), timeout=1)

0 commit comments

Comments
 (0)