|
23 | 23 | import logging |
24 | 24 | import os |
25 | 25 | import types |
26 | | -from typing import Iterator, List, Optional, Sequence, Tuple, Type, TypeVar, Union |
| 26 | +from typing import Any, Iterator, List, Optional, Sequence, Tuple, Type, TypeVar, Union |
27 | 27 |
|
28 | 28 | from google.api_core import client_options |
29 | 29 | from google.api_core import gapic_v1 |
|
46 | 46 | encryption_spec_v1beta1 as gca_encryption_spec_v1beta1, |
47 | 47 | ) |
48 | 48 |
|
| 49 | +try: |
| 50 | + import google.auth.aio |
| 51 | + |
| 52 | + AsyncCredentials = google.auth.aio.credentials.Credentials |
| 53 | + _HAS_ASYNC_CRED_DEPS = True |
| 54 | +except ImportError: |
| 55 | + AsyncCredentials = Any |
| 56 | + _HAS_ASYNC_CRED_DEPS = False |
| 57 | + |
49 | 58 | _TVertexAiServiceClientWithOverride = TypeVar( |
50 | 59 | "_TVertexAiServiceClientWithOverride", |
51 | 60 | bound=utils.VertexAiServiceClientWithOverride, |
@@ -121,6 +130,7 @@ def __init__(self): |
121 | 130 | self._api_transport = None |
122 | 131 | self._request_metadata = None |
123 | 132 | self._resource_type = None |
| 133 | + self._async_rest_credentials = None |
124 | 134 |
|
125 | 135 | def init( |
126 | 136 | self, |
@@ -590,15 +600,24 @@ def create_client( |
590 | 600 | } |
591 | 601 |
|
592 | 602 | # Do not pass "grpc", rely on gapic defaults unless "rest" is specified |
593 | | - if self._api_transport == "rest": |
594 | | - if "Async" in client_class.__name__: |
595 | | - # Warn user that "rest" is not supported and use grpc instead |
| 603 | + if self._api_transport == "rest" and "Async" in client_class.__name__: |
| 604 | + # User requests async rest |
| 605 | + if self._async_rest_credentials: |
| 606 | + # Rest async recieves credentials from _async_rest_credentials |
| 607 | + kwargs["credentials"] = self._async_rest_credentials |
| 608 | + kwargs["transport"] = "rest_asyncio" |
| 609 | + else: |
| 610 | + # Rest async was specified, but no async credentials were set. |
| 611 | + # Fallback to gRPC instead. |
596 | 612 | logging.warning( |
597 | | - "REST is not supported for async clients, " |
598 | | - + "falling back to grpc." |
| 613 | + "REST async clients requires async credentials set using " |
| 614 | + + "aiplatform.initializer._set_async_rest_credentials().\n" |
| 615 | + + "Falling back to grpc since no async rest credentials " |
| 616 | + + "were detected." |
599 | 617 | ) |
600 | | - else: |
601 | | - kwargs["transport"] = self._api_transport |
| 618 | + elif self._api_transport == "rest": |
| 619 | + # User requests sync REST |
| 620 | + kwargs["transport"] = self._api_transport |
602 | 621 |
|
603 | 622 | client = client_class(**kwargs) |
604 | 623 | # We only wrap the client if the request_metadata is set at the creation time. |
@@ -672,6 +691,29 @@ def __call__(self, *args, **kwargs): |
672 | 691 | ) |
673 | 692 |
|
674 | 693 |
|
| 694 | +def _set_async_rest_credentials(credentials: AsyncCredentials): |
| 695 | + """Private method to set async REST credentials.""" |
| 696 | + if global_config._api_transport != "rest": |
| 697 | + raise ValueError( |
| 698 | + "Async REST credentials can only be set when using REST transport." |
| 699 | + ) |
| 700 | + elif not _HAS_ASYNC_CRED_DEPS or not isinstance(credentials, AsyncCredentials): |
| 701 | + raise ValueError( |
| 702 | + "Async REST transport requires async credentials of type" |
| 703 | + + f"{AsyncCredentials} which is only supported in " |
| 704 | + + "google-auth >= 2.35.0.\n\n" |
| 705 | + + "Install the following dependencies:\n" |
| 706 | + + "pip install google-api-core[grpc, async_rest] >= 2.21.0\n" |
| 707 | + + "pip install google-auth[aiohttp] >= 2.35.0\n\n" |
| 708 | + + "Example usage:\n" |
| 709 | + + "from google.auth.aio.credentials import StaticCredentials\n" |
| 710 | + + "async_credentials = StaticCredentials(token=YOUR_TOKEN_HERE)\n" |
| 711 | + + "aiplatform.initializer._set_async_rest_credentials(" |
| 712 | + + "credentials=async_credentials)" |
| 713 | + ) |
| 714 | + global_config._async_rest_credentials = credentials |
| 715 | + |
| 716 | + |
675 | 717 | def _get_function_name_from_stack_frame(frame) -> str: |
676 | 718 | """Gates fully qualified function or method name. |
677 | 719 |
|
|
0 commit comments