1414
1515from memmachine_server .common .filter .filter_parser import Comparison
1616from memmachine_server .episodic_memory .extra_memory .data_types import (
17+ CitationContext ,
18+ MessageContext ,
1719 Segment ,
1820 Text ,
1921)
@@ -44,6 +46,7 @@ def _seg(
4446 offset : int = 0 ,
4547 ts_offset_seconds : int = 0 ,
4648 text : str = "hello" ,
49+ context : MessageContext | CitationContext | None = None ,
4750 properties : dict | None = None ,
4851) -> Segment :
4952 return Segment (
@@ -53,6 +56,7 @@ def _seg(
5356 offset = offset ,
5457 timestamp = BASE_TIME + timedelta (seconds = ts_offset_seconds ),
5558 block = Text (text = text ),
59+ context = context ,
5660 properties = properties or {},
5761 )
5862
@@ -150,6 +154,71 @@ async def test_register_with_properties(
150154 assert returned .properties == {"color" : "red" , "score" : 42 }
151155
152156
157+ @pytest .mark .asyncio
158+ async def test_register_with_message_context (
159+ partition : SQLAlchemySegmentLinkerPartition ,
160+ ) -> None :
161+ ctx = MessageContext (source = "User" )
162+ seg = _seg (context = ctx )
163+ deriv = uuid4 ()
164+ await partition .register_segments ({seg : [deriv ]})
165+
166+ result = await partition .get_segments_by_derivatives ([deriv ])
167+ returned = next (iter (result [deriv ]))
168+ assert returned .context == ctx
169+
170+
171+ @pytest .mark .asyncio
172+ async def test_register_with_citation_context (
173+ partition : SQLAlchemySegmentLinkerPartition ,
174+ ) -> None :
175+ ctx = CitationContext (source = "docs.txt" , source_type = "file" , location = "/tmp" )
176+ seg = _seg (context = ctx )
177+ deriv = uuid4 ()
178+ await partition .register_segments ({seg : [deriv ]})
179+
180+ result = await partition .get_segments_by_derivatives ([deriv ])
181+ returned = next (iter (result [deriv ]))
182+ assert returned .context == ctx
183+
184+
185+ @pytest .mark .asyncio
186+ async def test_register_with_no_context (
187+ partition : SQLAlchemySegmentLinkerPartition ,
188+ ) -> None :
189+ seg = _seg ()
190+ deriv = uuid4 ()
191+ await partition .register_segments ({seg : [deriv ]})
192+
193+ result = await partition .get_segments_by_derivatives ([deriv ])
194+ returned = next (iter (result [deriv ]))
195+ assert returned .context is None
196+
197+
198+ @pytest .mark .asyncio
199+ async def test_context_preserved_in_segment_contexts (
200+ partition : SQLAlchemySegmentLinkerPartition ,
201+ ) -> None :
202+ """Context is preserved when retrieving segment contexts (backward/forward)."""
203+ ep = uuid4 ()
204+ ctx_user = MessageContext (source = "User" )
205+ ctx_assistant = MessageContext (source = "Assistant" )
206+ s0 = _seg (episode_uuid = ep , offset = 0 , ts_offset_seconds = 0 , context = ctx_user )
207+ s1 = _seg (episode_uuid = ep , offset = 1 , ts_offset_seconds = 1 , context = ctx_assistant )
208+ s2 = _seg (episode_uuid = ep , offset = 2 , ts_offset_seconds = 2 , context = ctx_user )
209+ deriv = uuid4 ()
210+ await partition .register_segments ({s0 : [deriv ], s1 : [deriv ], s2 : [deriv ]})
211+
212+ result = await partition .get_segment_contexts (
213+ [s1 .uuid ], max_backward_segments = 1 , max_forward_segments = 1
214+ )
215+ ctx = list (result [s1 .uuid ])
216+ assert len (ctx ) == 3
217+ assert ctx [0 ].context == ctx_user
218+ assert ctx [1 ].context == ctx_assistant
219+ assert ctx [2 ].context == ctx_user
220+
221+
153222@pytest .mark .asyncio
154223async def test_register_active_validation (
155224 partition : SQLAlchemySegmentLinkerPartition ,
@@ -1109,3 +1178,68 @@ async def registerer() -> None:
11091178
11101179 await asyncio .gather (purger (), registerer ())
11111180 assert errors == []
1181+
1182+
1183+ @pytest .mark .integration
1184+ @pytest .mark .asyncio
1185+ async def test_pg_context_preserved_via_lateral_join (
1186+ pg_linker : SQLAlchemySegmentLinker ,
1187+ ) -> None :
1188+ """Context is preserved when retrieved via the LATERAL join path (multiple seeds)."""
1189+ partition = pg_linker .get_partition (PARTITION_KEY )
1190+ ep = uuid4 ()
1191+ ctx_user = MessageContext (source = "User" )
1192+ ctx_assistant = MessageContext (source = "Assistant" )
1193+ s0 = _seg (episode_uuid = ep , offset = 0 , ts_offset_seconds = 0 , context = ctx_user )
1194+ s1 = _seg (episode_uuid = ep , offset = 1 , ts_offset_seconds = 1 , context = ctx_assistant )
1195+ s2 = _seg (episode_uuid = ep , offset = 2 , ts_offset_seconds = 2 , context = ctx_user )
1196+ s3 = _seg (episode_uuid = ep , offset = 3 , ts_offset_seconds = 3 , context = ctx_assistant )
1197+ s4 = _seg (episode_uuid = ep , offset = 4 , ts_offset_seconds = 4 , context = ctx_user )
1198+ deriv = uuid4 ()
1199+ await partition .register_segments (
1200+ {s0 : [deriv ], s1 : [deriv ], s2 : [deriv ], s3 : [deriv ], s4 : [deriv ]}
1201+ )
1202+
1203+ # Two seeds exercises the LATERAL join code path.
1204+ result = await partition .get_segment_contexts (
1205+ [s1 .uuid , s3 .uuid ], max_backward_segments = 1 , max_forward_segments = 1
1206+ )
1207+
1208+ ctx_a = list (result [s1 .uuid ])
1209+ assert len (ctx_a ) == 3
1210+ assert ctx_a [0 ].context == ctx_user
1211+ assert ctx_a [1 ].context == ctx_assistant
1212+ assert ctx_a [2 ].context == ctx_user
1213+
1214+ ctx_b = list (result [s3 .uuid ])
1215+ assert len (ctx_b ) == 3
1216+ assert ctx_b [0 ].context == ctx_user
1217+ assert ctx_b [1 ].context == ctx_assistant
1218+ assert ctx_b [2 ].context == ctx_user
1219+
1220+
1221+ @pytest .mark .integration
1222+ @pytest .mark .asyncio
1223+ async def test_pg_mixed_context_types (
1224+ pg_linker : SQLAlchemySegmentLinker ,
1225+ ) -> None :
1226+ """Different context types (message, citation, None) round-trip correctly on PG."""
1227+ partition = pg_linker .get_partition (PARTITION_KEY )
1228+ ctx_msg = MessageContext (source = "User" )
1229+ ctx_cite = CitationContext (source = "paper.pdf" , source_type = "file" , location = "p.3" )
1230+
1231+ s_msg = _seg (ts_offset_seconds = 0 , context = ctx_msg )
1232+ s_cite = _seg (ts_offset_seconds = 1 , context = ctx_cite )
1233+ s_none = _seg (ts_offset_seconds = 2 )
1234+
1235+ d1 , d2 , d3 = uuid4 (), uuid4 (), uuid4 ()
1236+ await partition .register_segments ({s_msg : [d1 ], s_cite : [d2 ], s_none : [d3 ]})
1237+
1238+ r1 = await partition .get_segments_by_derivatives ([d1 ])
1239+ assert next (iter (r1 [d1 ])).context == ctx_msg
1240+
1241+ r2 = await partition .get_segments_by_derivatives ([d2 ])
1242+ assert next (iter (r2 [d2 ])).context == ctx_cite
1243+
1244+ r3 = await partition .get_segments_by_derivatives ([d3 ])
1245+ assert next (iter (r3 [d3 ])).context is None
0 commit comments