Skip to content

Commit e9f1c88

Browse files
Tongzhou-Jiangcopybara-github
authored andcommitted
fix: save artifact in streaming agent run with events when multiturn
PiperOrigin-RevId: 878043339
1 parent aa22f54 commit e9f1c88

2 files changed

Lines changed: 44 additions & 5 deletions

File tree

  • vertexai
    • agent_engines/templates
    • preview/reasoning_engines/templates

vertexai/agent_engines/templates/adk.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,18 @@ async def _init_session(
683683
if request.events:
684684
for event in request.events:
685685
await session_service.append_event(session, Event(**event))
686+
if request.artifacts:
687+
await self._save_artifacts(session.id, artifact_service, request)
688+
return session
689+
690+
async def _save_artifacts(
691+
self,
692+
session_id: str,
693+
artifact_service: "BaseArtifactService",
694+
request: _StreamRunRequest,
695+
):
696+
"""Saves the artifacts."""
697+
app = self._tmpl_attrs.get("app")
686698
if request.artifacts:
687699
for artifact in request.artifacts:
688700
artifact = _Artifact(**artifact)
@@ -693,7 +705,7 @@ async def _init_session(
693705
saved_version = await artifact_service.save_artifact(
694706
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
695707
user_id=request.user_id,
696-
session_id=session.id,
708+
session_id=session_id,
697709
filename=artifact.file_name,
698710
artifact=version_data.data,
699711
)
@@ -707,7 +719,6 @@ async def _init_session(
707719
saved_version,
708720
version_data.version,
709721
)
710-
return session
711722

712723
async def _convert_response_events(
713724
self,
@@ -1209,6 +1220,12 @@ async def streaming_agent_run_with_events(self, request_json: str):
12091220
user_id=request.user_id,
12101221
session_id=request.session_id,
12111222
)
1223+
if session:
1224+
await self._save_artifacts(
1225+
session_id=request.session_id,
1226+
artifact_service=artifact_service,
1227+
request=request,
1228+
)
12121229
except ClientError:
12131230
pass
12141231
if not session:

vertexai/preview/reasoning_engines/templates/adk.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,23 @@ async def _init_session(
616616
if request.events:
617617
for event in request.events:
618618
await session_service.append_event(session, Event(**event))
619+
if request.artifacts:
620+
await self._save_artifacts(
621+
session_id=session.id,
622+
artifact_service=artifact_service,
623+
request=request,
624+
)
625+
626+
return session
627+
628+
async def _save_artifacts(
629+
self,
630+
session_id: str,
631+
artifact_service: "BaseArtifactService",
632+
request: _StreamRunRequest,
633+
):
634+
"""Saves the artifacts."""
635+
app = self._tmpl_attrs.get("app")
619636
if request.artifacts:
620637
for artifact in request.artifacts:
621638
artifact = _Artifact(**artifact)
@@ -624,9 +641,9 @@ async def _init_session(
624641
):
625642
version_data = _ArtifactVersion(**version_data)
626643
saved_version = await artifact_service.save_artifact(
627-
app_name=self._tmpl_attrs.get("app_name"),
644+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
628645
user_id=request.user_id,
629-
session_id=session.id,
646+
session_id=session_id,
630647
filename=artifact.file_name,
631648
artifact=version_data.data,
632649
)
@@ -640,7 +657,6 @@ async def _init_session(
640657
saved_version,
641658
version_data.version,
642659
)
643-
return session
644660

645661
async def _convert_response_events(
646662
self,
@@ -1043,6 +1059,12 @@ async def _invoke_agent_async():
10431059
user_id=request.user_id,
10441060
session_id=request.session_id,
10451061
)
1062+
if session:
1063+
await self._save_artifacts(
1064+
session_id=request.session_id,
1065+
artifact_service=artifact_service,
1066+
request=request,
1067+
)
10461068
except ClientError:
10471069
pass
10481070
if not session:

0 commit comments

Comments
 (0)