Skip to content

Commit 06959b9

Browse files
DeanChensjcopybara-github
authored andcommitted
fix(sessions): Prevent MissingGreenlet after append_event with asyncpg
Merges #5814 Co-authored-by: Shangjie Chen <deanchen@google.com> PiperOrigin-RevId: 933406737
1 parent 8f85260 commit 06959b9

2 files changed

Lines changed: 88 additions & 5 deletions

File tree

src/google/adk/sessions/database_session_service.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -760,13 +760,15 @@ async def append_event(self, session: Session, event: Event) -> Event:
760760
storage_session.update_time = update_time
761761
sql_session.add(schema.StorageEvent.from_event(session, event))
762762

763+
# Read revision fields before commit. Post-commit ORM attribute access
764+
# can lazy-load expired columns and trigger MissingGreenlet with asyncpg
765+
# when pool_pre_ping is enabled.
766+
last_update_time = storage_session.get_update_timestamp(is_sqlite)
767+
storage_update_marker = storage_session.get_update_marker()
763768
await sql_session.commit()
764769

765-
# Update timestamp with commit time
766-
session.last_update_time = storage_session.get_update_timestamp(
767-
is_sqlite
768-
)
769-
session._storage_update_marker = storage_session.get_update_marker()
770+
session.last_update_time = last_update_time
771+
session._storage_update_marker = storage_update_marker
770772

771773
# Also update the in-memory session
772774
await super().append_event(session=session, event=event)

tests/unittests/sessions/test_session_service.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1254,6 +1254,87 @@ def _spy_factory():
12541254
await service.close()
12551255

12561256

1257+
class _CommitOrderSpySession:
1258+
"""SQLAlchemy session spy that marks when commit() has completed."""
1259+
1260+
def __init__(self, real_session, on_committed):
1261+
self._real = real_session
1262+
self._on_committed = on_committed
1263+
1264+
async def __aenter__(self):
1265+
self._real = await self._real.__aenter__()
1266+
return self
1267+
1268+
async def __aexit__(self, *args):
1269+
return await self._real.__aexit__(*args)
1270+
1271+
async def commit(self):
1272+
result = await self._real.commit()
1273+
self._on_committed()
1274+
return result
1275+
1276+
def __getattr__(self, name):
1277+
return getattr(self._real, name)
1278+
1279+
1280+
@pytest.mark.asyncio
1281+
async def test_append_event_reads_storage_revision_before_commit():
1282+
"""append_event captures session revision before commit completes."""
1283+
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
1284+
await service._prepare_tables()
1285+
schema = service._get_schema_classes()
1286+
original_get_update_timestamp = schema.StorageSession.get_update_timestamp
1287+
original_get_update_marker = schema.StorageSession.get_update_marker
1288+
revision_read_state = {'committed': False, 'post_commit_reads': 0}
1289+
1290+
def _track_revision_read(original):
1291+
def wrapper(self, *args, **kwargs):
1292+
if revision_read_state['committed']:
1293+
revision_read_state['post_commit_reads'] += 1
1294+
return original(self, *args, **kwargs)
1295+
1296+
return wrapper
1297+
1298+
schema.StorageSession.get_update_timestamp = _track_revision_read(
1299+
original_get_update_timestamp
1300+
)
1301+
schema.StorageSession.get_update_marker = _track_revision_read(
1302+
original_get_update_marker
1303+
)
1304+
1305+
try:
1306+
session = await service.create_session(
1307+
app_name='app', user_id='user', session_id='s1'
1308+
)
1309+
event_timestamp = session.last_update_time + 10
1310+
event = Event(
1311+
invocation_id='inv1',
1312+
author='user',
1313+
timestamp=event_timestamp,
1314+
)
1315+
1316+
original_factory = service.database_session_factory
1317+
1318+
def _spy_factory():
1319+
return _CommitOrderSpySession(
1320+
original_factory(),
1321+
on_committed=lambda: revision_read_state.update({'committed': True}),
1322+
)
1323+
1324+
service.database_session_factory = _spy_factory
1325+
1326+
await service.append_event(session, event)
1327+
1328+
assert revision_read_state['post_commit_reads'] == 0
1329+
assert session.last_update_time == pytest.approx(event_timestamp, abs=1e-6)
1330+
assert session._storage_update_marker is not None
1331+
finally:
1332+
schema.StorageSession.get_update_timestamp = original_get_update_timestamp
1333+
schema.StorageSession.get_update_marker = original_get_update_marker
1334+
1335+
await service.close()
1336+
1337+
12571338
@pytest.mark.asyncio
12581339
async def test_delete_session_calls_rollback_on_commit_failure():
12591340
"""Verifies that a commit failure during delete_session triggers an explicit

0 commit comments

Comments
 (0)