Skip to content

Commit 5d451f6

Browse files
committed
Persist segment context
Signed-off-by: Edwin Yu <edwinyyyu@gmail.com>
1 parent 5ce19cb commit 5d451f6

2 files changed

Lines changed: 142 additions & 0 deletions

File tree

packages/server/server_tests/memmachine_server/episodic_memory/extra_memory/segment_linker/test_sqlalchemy_segment_linker.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from memmachine_server.common.filter.filter_parser import Comparison
1616
from memmachine_server.episodic_memory.extra_memory.data_types import (
17+
CitationContext,
18+
MessageContext,
1719
Segment,
1820
Text,
1921
)
@@ -44,6 +46,7 @@ def _seg(
4446
offset: int = 0,
4547
ts_offset_seconds: int = 0,
4648
text: str = "hello",
49+
context: MessageContext | CitationContext | None = None,
4750
properties: dict | None = None,
4851
) -> Segment:
4952
return Segment(
@@ -53,6 +56,7 @@ def _seg(
5356
offset=offset,
5457
timestamp=BASE_TIME + timedelta(seconds=ts_offset_seconds),
5558
block=Text(text=text),
59+
context=context,
5660
properties=properties or {},
5761
)
5862

@@ -150,6 +154,71 @@ async def test_register_with_properties(
150154
assert returned.properties == {"color": "red", "score": 42}
151155

152156

157+
@pytest.mark.asyncio
158+
async def test_register_with_message_context(
159+
partition: SQLAlchemySegmentLinkerPartition,
160+
) -> None:
161+
ctx = MessageContext(source="User")
162+
seg = _seg(context=ctx)
163+
deriv = uuid4()
164+
await partition.register_segments({seg: [deriv]})
165+
166+
result = await partition.get_segments_by_derivatives([deriv])
167+
returned = next(iter(result[deriv]))
168+
assert returned.context == ctx
169+
170+
171+
@pytest.mark.asyncio
172+
async def test_register_with_citation_context(
173+
partition: SQLAlchemySegmentLinkerPartition,
174+
) -> None:
175+
ctx = CitationContext(source="docs.txt", source_type="file", location="/tmp")
176+
seg = _seg(context=ctx)
177+
deriv = uuid4()
178+
await partition.register_segments({seg: [deriv]})
179+
180+
result = await partition.get_segments_by_derivatives([deriv])
181+
returned = next(iter(result[deriv]))
182+
assert returned.context == ctx
183+
184+
185+
@pytest.mark.asyncio
186+
async def test_register_with_no_context(
187+
partition: SQLAlchemySegmentLinkerPartition,
188+
) -> None:
189+
seg = _seg()
190+
deriv = uuid4()
191+
await partition.register_segments({seg: [deriv]})
192+
193+
result = await partition.get_segments_by_derivatives([deriv])
194+
returned = next(iter(result[deriv]))
195+
assert returned.context is None
196+
197+
198+
@pytest.mark.asyncio
199+
async def test_context_preserved_in_segment_contexts(
200+
partition: SQLAlchemySegmentLinkerPartition,
201+
) -> None:
202+
"""Context is preserved when retrieving segment contexts (backward/forward)."""
203+
ep = uuid4()
204+
ctx_user = MessageContext(source="User")
205+
ctx_assistant = MessageContext(source="Assistant")
206+
s0 = _seg(episode_uuid=ep, offset=0, ts_offset_seconds=0, context=ctx_user)
207+
s1 = _seg(episode_uuid=ep, offset=1, ts_offset_seconds=1, context=ctx_assistant)
208+
s2 = _seg(episode_uuid=ep, offset=2, ts_offset_seconds=2, context=ctx_user)
209+
deriv = uuid4()
210+
await partition.register_segments({s0: [deriv], s1: [deriv], s2: [deriv]})
211+
212+
result = await partition.get_segment_contexts(
213+
[s1.uuid], max_backward_segments=1, max_forward_segments=1
214+
)
215+
ctx = list(result[s1.uuid])
216+
assert len(ctx) == 3
217+
assert ctx[0].context == ctx_user
218+
assert ctx[1].context == ctx_assistant
219+
assert ctx[2].context == ctx_user
220+
221+
153222
@pytest.mark.asyncio
154223
async def test_register_active_validation(
155224
partition: SQLAlchemySegmentLinkerPartition,
@@ -1109,3 +1178,68 @@ async def registerer() -> None:
11091178

11101179
await asyncio.gather(purger(), registerer())
11111180
assert errors == []
1181+
1182+
1183+
@pytest.mark.integration
1184+
@pytest.mark.asyncio
1185+
async def test_pg_context_preserved_via_lateral_join(
1186+
pg_linker: SQLAlchemySegmentLinker,
1187+
) -> None:
1188+
"""Context is preserved when retrieved via the LATERAL join path (multiple seeds)."""
1189+
partition = pg_linker.get_partition(PARTITION_KEY)
1190+
ep = uuid4()
1191+
ctx_user = MessageContext(source="User")
1192+
ctx_assistant = MessageContext(source="Assistant")
1193+
s0 = _seg(episode_uuid=ep, offset=0, ts_offset_seconds=0, context=ctx_user)
1194+
s1 = _seg(episode_uuid=ep, offset=1, ts_offset_seconds=1, context=ctx_assistant)
1195+
s2 = _seg(episode_uuid=ep, offset=2, ts_offset_seconds=2, context=ctx_user)
1196+
s3 = _seg(episode_uuid=ep, offset=3, ts_offset_seconds=3, context=ctx_assistant)
1197+
s4 = _seg(episode_uuid=ep, offset=4, ts_offset_seconds=4, context=ctx_user)
1198+
deriv = uuid4()
1199+
await partition.register_segments(
1200+
{s0: [deriv], s1: [deriv], s2: [deriv], s3: [deriv], s4: [deriv]}
1201+
)
1202+
1203+
# Two seeds exercises the LATERAL join code path.
1204+
result = await partition.get_segment_contexts(
1205+
[s1.uuid, s3.uuid], max_backward_segments=1, max_forward_segments=1
1206+
)
1207+
1208+
ctx_a = list(result[s1.uuid])
1209+
assert len(ctx_a) == 3
1210+
assert ctx_a[0].context == ctx_user
1211+
assert ctx_a[1].context == ctx_assistant
1212+
assert ctx_a[2].context == ctx_user
1213+
1214+
ctx_b = list(result[s3.uuid])
1215+
assert len(ctx_b) == 3
1216+
assert ctx_b[0].context == ctx_user
1217+
assert ctx_b[1].context == ctx_assistant
1218+
assert ctx_b[2].context == ctx_user
1219+
1220+
1221+
@pytest.mark.integration
1222+
@pytest.mark.asyncio
1223+
async def test_pg_mixed_context_types(
1224+
pg_linker: SQLAlchemySegmentLinker,
1225+
) -> None:
1226+
"""Different context types (message, citation, None) round-trip correctly on PG."""
1227+
partition = pg_linker.get_partition(PARTITION_KEY)
1228+
ctx_msg = MessageContext(source="User")
1229+
ctx_cite = CitationContext(source="paper.pdf", source_type="file", location="p.3")
1230+
1231+
s_msg = _seg(ts_offset_seconds=0, context=ctx_msg)
1232+
s_cite = _seg(ts_offset_seconds=1, context=ctx_cite)
1233+
s_none = _seg(ts_offset_seconds=2)
1234+
1235+
d1, d2, d3 = uuid4(), uuid4(), uuid4()
1236+
await partition.register_segments({s_msg: [d1], s_cite: [d2], s_none: [d3]})
1237+
1238+
r1 = await partition.get_segments_by_derivatives([d1])
1239+
assert next(iter(r1[d1])).context == ctx_msg
1240+
1241+
r2 = await partition.get_segments_by_derivatives([d2])
1242+
assert next(iter(r2[d2])).context == ctx_cite
1243+
1244+
r3 = await partition.get_segments_by_derivatives([d3])
1245+
assert next(iter(r3[d3])).context is None

packages/server/src/memmachine_server/episodic_memory/extra_memory/segment_linker/sqlalchemy_segment_linker.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
)
5353
from memmachine_server.episodic_memory.extra_memory.data_types import (
5454
Block,
55+
Context,
5556
Segment,
5657
)
5758
from memmachine_server.episodic_memory.extra_memory.segment_linker.segment_linker import (
@@ -63,6 +64,7 @@
6364
logger = logging.getLogger(__name__)
6465

6566
_JSON_AUTO = JSON().with_variant(JSONB, "postgresql")
67+
_ContextAdapter = TypeAdapter(Context | None)
6668
_BlockAdapter = TypeAdapter(Block)
6769

6870

@@ -94,6 +96,7 @@ class SegmentRow(BaseSegmentLinker):
9496
timestamp: MappedColumn[datetime] = mapped_column(
9597
DateTime(timezone=True), nullable=False
9698
)
99+
context = mapped_column(_JSON_AUTO, nullable=True)
97100
block: MappedColumn[Block] = mapped_column(_JSON_AUTO, nullable=False)
98101

99102
__table_args__ = (
@@ -275,6 +278,7 @@ async def _insert_links(
275278
"index": segment.index,
276279
"offset": segment.offset,
277280
"timestamp": segment.timestamp,
281+
"context": segment.context.model_dump() if segment.context else None,
278282
"block": segment.block.model_dump(),
279283
}
280284
for segment in links
@@ -534,6 +538,7 @@ async def _lateral_query(
534538
lateral_subquery.c.index,
535539
lateral_subquery.c.offset,
536540
lateral_subquery.c.timestamp,
541+
lateral_subquery.c.context,
537542
lateral_subquery.c.block,
538543
).select_from(seeds_subquery.join(lateral_subquery, true()))
539544

@@ -549,6 +554,7 @@ async def _lateral_query(
549554
index=row.index,
550555
offset=row.offset,
551556
timestamp=row.timestamp,
557+
context=row.context,
552558
block=row.block,
553559
)
554560
)
@@ -1046,13 +1052,15 @@ def _segment_from_segment_row(
10461052
properties: dict[str, PropertyValue],
10471053
) -> Segment:
10481054
"""Convert a SegmentRow and its properties into a Segment."""
1055+
context = _ContextAdapter.validate_python(row.context)
10491056
block = _BlockAdapter.validate_python(row.block)
10501057
return Segment(
10511058
uuid=row.uuid,
10521059
episode_uuid=row.episode_uuid,
10531060
index=row.index,
10541061
offset=row.offset,
10551062
timestamp=row.timestamp,
1063+
context=context,
10561064
block=block,
10571065
properties=properties,
10581066
)

0 commit comments

Comments
 (0)