Skip to content

Commit bd4983b

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Add SDK workaround for double-wrapped Any response in async_retrieve_contexts.
PiperOrigin-RevId: 896740495
1 parent a196cda commit bd4983b

2 files changed

Lines changed: 38 additions & 2 deletions

File tree

vertexai/preview/rag/rag_retrieval.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from vertexai.preview.rag.utils import _gapic_utils
2626
from vertexai.preview.rag.utils import resources
2727

28+
from google.protobuf import any_pb2
29+
2830

2931
def retrieval_query(
3032
text: str,
@@ -528,7 +530,22 @@ async def async_retrieve_contexts(
528530
response_lro = await client.async_retrieve_contexts(
529531
request=request, timeout=timeout
530532
)
531-
response = await response_lro.result()
533+
try:
534+
response = await response_lro.result(timeout=timeout)
535+
except Exception as e:
536+
if response_lro.done():
537+
raw_op = response_lro.operation
538+
if raw_op.WhichOneof("result") == "response":
539+
any_response = raw_op.response
540+
inner_any = any_pb2.Any()
541+
if any_response.Unpack(inner_any):
542+
inner_any.type_url = "type.googleapis.com/google.cloud.aiplatform.v1beta1.RagContexts"
543+
rag_contexts = aiplatform_v1beta1.RagContexts()
544+
if inner_any.Unpack(rag_contexts._pb):
545+
return aiplatform_v1beta1.AsyncRetrieveContextsResponse(
546+
contexts=rag_contexts
547+
)
548+
raise e
532549
except Exception as e:
533550
raise RuntimeError(
534551
"Failed in retrieving contexts asynchronously due to: ", e

vertexai/rag/rag_retrieval.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from vertexai.rag.utils import _gapic_utils
2525
from vertexai.rag.utils import resources
2626

27+
from google.protobuf import any_pb2
28+
2729

2830
def retrieval_query(
2931
text: str,
@@ -325,7 +327,24 @@ async def async_retrieve_contexts(
325327
response_lro = await client.async_retrieve_contexts(
326328
request=request, timeout=timeout
327329
)
328-
response = await response_lro.result()
330+
try:
331+
response = await response_lro.result(timeout=timeout)
332+
except Exception as e:
333+
if response_lro.done():
334+
raw_op = response_lro.operation
335+
if raw_op.WhichOneof("result") == "response":
336+
any_response = raw_op.response
337+
inner_any = any_pb2.Any()
338+
if any_response.Unpack(inner_any):
339+
inner_any.type_url = (
340+
"type.googleapis.com/google.cloud.aiplatform.v1.RagContexts"
341+
)
342+
rag_contexts = aiplatform_v1.RagContexts()
343+
if inner_any.Unpack(rag_contexts._pb):
344+
return aiplatform_v1.AsyncRetrieveContextsResponse(
345+
contexts=rag_contexts
346+
)
347+
raise e
329348
except Exception as e:
330349
raise RuntimeError(
331350
"Failed in retrieving contexts asynchronously due to: ", e

0 commit comments

Comments
 (0)