From 6d408eb60f070da0b356f5dfdf2a24487b815a6b Mon Sep 17 00:00:00 2001
From: "gcf-owl-bot[bot]" <78513119+gcf-owl-bot[bot]@users.noreply.github.com>
Date: Thu, 7 Oct 2021 19:10:45 +0000
Subject: [PATCH 1/6] chore(python): fix formatting issue in noxfile.py.j2
(#50)
---
.github/.OwlBot.lock.yaml | 2 +-
CONTRIBUTING.rst | 6 ++++--
noxfile.py | 2 +-
3 files changed, 6 insertions(+), 4 deletions(-)
diff --git a/.github/.OwlBot.lock.yaml b/.github/.OwlBot.lock.yaml
index ee94722..76d0baa 100644
--- a/.github/.OwlBot.lock.yaml
+++ b/.github/.OwlBot.lock.yaml
@@ -1,3 +1,3 @@
docker:
image: gcr.io/cloud-devrel-public-resources/owlbot-python:latest
- digest: sha256:6e7328583be8edd3ba8f35311c76a1ecbc823010279ccb6ab46b7a76e25eafcc
+ digest: sha256:4370ced27a324687ede5da07132dcdc5381993502a5e8a3e31e16dc631d026f0
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index 6b51d1e..3151450 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -22,7 +22,7 @@ In order to add a feature:
documentation.
- The feature must work fully on the following CPython versions:
- 3.6, 3.7, 3.8 and 3.9 on both UNIX and Windows.
+ 3.6, 3.7, 3.8, 3.9 and 3.10 on both UNIX and Windows.
- The feature must not add unnecessary dependencies (where
"unnecessary" is of course subjective, but new dependencies should
@@ -72,7 +72,7 @@ We use `nox `__ to instrument our tests.
- To run a single unit test::
- $ nox -s unit-3.9 -- -k
+ $ nox -s unit-3.10 -- -k
.. note::
@@ -225,11 +225,13 @@ We support:
- `Python 3.7`_
- `Python 3.8`_
- `Python 3.9`_
+- `Python 3.10`_
.. _Python 3.6: https://docs.python.org/3.6/
.. _Python 3.7: https://docs.python.org/3.7/
.. _Python 3.8: https://docs.python.org/3.8/
.. _Python 3.9: https://docs.python.org/3.9/
+.. _Python 3.10: https://docs.python.org/3.10/
Supported versions can be found in our ``noxfile.py`` `config`_.
diff --git a/noxfile.py b/noxfile.py
index 935a924..2bb4cf7 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -29,7 +29,7 @@
DEFAULT_PYTHON_VERSION = "3.8"
SYSTEM_TEST_PYTHON_VERSIONS = ["3.8"]
-UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"]
+UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9", "3.10"]
CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute()
From 9b45d0eac22e78de5e522317b74140e5d261eee3 Mon Sep 17 00:00:00 2001
From: "gcf-owl-bot[bot]" <78513119+gcf-owl-bot[bot]@users.noreply.github.com>
Date: Fri, 8 Oct 2021 18:10:25 +0000
Subject: [PATCH 2/6] chore(python): Add kokoro configs for python 3.10 samples
testing (#51)
---
.github/.OwlBot.lock.yaml | 2 +-
.kokoro/samples/python3.10/common.cfg | 40 ++++++++++++++++++++
.kokoro/samples/python3.10/continuous.cfg | 6 +++
.kokoro/samples/python3.10/periodic-head.cfg | 11 ++++++
.kokoro/samples/python3.10/periodic.cfg | 6 +++
.kokoro/samples/python3.10/presubmit.cfg | 6 +++
6 files changed, 70 insertions(+), 1 deletion(-)
create mode 100644 .kokoro/samples/python3.10/common.cfg
create mode 100644 .kokoro/samples/python3.10/continuous.cfg
create mode 100644 .kokoro/samples/python3.10/periodic-head.cfg
create mode 100644 .kokoro/samples/python3.10/periodic.cfg
create mode 100644 .kokoro/samples/python3.10/presubmit.cfg
diff --git a/.github/.OwlBot.lock.yaml b/.github/.OwlBot.lock.yaml
index 76d0baa..7d98291 100644
--- a/.github/.OwlBot.lock.yaml
+++ b/.github/.OwlBot.lock.yaml
@@ -1,3 +1,3 @@
docker:
image: gcr.io/cloud-devrel-public-resources/owlbot-python:latest
- digest: sha256:4370ced27a324687ede5da07132dcdc5381993502a5e8a3e31e16dc631d026f0
+ digest: sha256:58f73ba196b5414782605236dd0712a73541b44ff2ff4d3a36ec41092dd6fa5b
diff --git a/.kokoro/samples/python3.10/common.cfg b/.kokoro/samples/python3.10/common.cfg
new file mode 100644
index 0000000..4343882
--- /dev/null
+++ b/.kokoro/samples/python3.10/common.cfg
@@ -0,0 +1,40 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+# Build logs will be here
+action {
+ define_artifacts {
+ regex: "**/*sponge_log.xml"
+ }
+}
+
+# Specify which tests to run
+env_vars: {
+ key: "RUN_TESTS_SESSION"
+ value: "py-3.10"
+}
+
+# Declare build specific Cloud project.
+env_vars: {
+ key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
+ value: "python-docs-samples-tests-310"
+}
+
+env_vars: {
+ key: "TRAMPOLINE_BUILD_FILE"
+ value: "github/python-tpu/.kokoro/test-samples.sh"
+}
+
+# Configure the docker image for kokoro-trampoline.
+env_vars: {
+ key: "TRAMPOLINE_IMAGE"
+ value: "gcr.io/cloud-devrel-kokoro-resources/python-samples-testing-docker"
+}
+
+# Download secrets for samples
+gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/python-docs-samples"
+
+# Download trampoline resources.
+gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/trampoline"
+
+# Use the trampoline script to run in docker.
+build_file: "python-tpu/.kokoro/trampoline_v2.sh"
\ No newline at end of file
diff --git a/.kokoro/samples/python3.10/continuous.cfg b/.kokoro/samples/python3.10/continuous.cfg
new file mode 100644
index 0000000..a1c8d97
--- /dev/null
+++ b/.kokoro/samples/python3.10/continuous.cfg
@@ -0,0 +1,6 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+env_vars: {
+ key: "INSTALL_LIBRARY_FROM_SOURCE"
+ value: "True"
+}
\ No newline at end of file
diff --git a/.kokoro/samples/python3.10/periodic-head.cfg b/.kokoro/samples/python3.10/periodic-head.cfg
new file mode 100644
index 0000000..b200dbd
--- /dev/null
+++ b/.kokoro/samples/python3.10/periodic-head.cfg
@@ -0,0 +1,11 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+env_vars: {
+ key: "INSTALL_LIBRARY_FROM_SOURCE"
+ value: "True"
+}
+
+env_vars: {
+ key: "TRAMPOLINE_BUILD_FILE"
+ value: "github/python-tpu/.kokoro/test-samples-against-head.sh"
+}
diff --git a/.kokoro/samples/python3.10/periodic.cfg b/.kokoro/samples/python3.10/periodic.cfg
new file mode 100644
index 0000000..71cd1e5
--- /dev/null
+++ b/.kokoro/samples/python3.10/periodic.cfg
@@ -0,0 +1,6 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+env_vars: {
+ key: "INSTALL_LIBRARY_FROM_SOURCE"
+ value: "False"
+}
diff --git a/.kokoro/samples/python3.10/presubmit.cfg b/.kokoro/samples/python3.10/presubmit.cfg
new file mode 100644
index 0000000..a1c8d97
--- /dev/null
+++ b/.kokoro/samples/python3.10/presubmit.cfg
@@ -0,0 +1,6 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+env_vars: {
+ key: "INSTALL_LIBRARY_FROM_SOURCE"
+ value: "True"
+}
\ No newline at end of file
From 18b9ee0cff03b4f97071ef6c7a2bc3e613a01242 Mon Sep 17 00:00:00 2001
From: Anthonios Partheniou
Date: Thu, 14 Oct 2021 19:04:09 -0400
Subject: [PATCH 3/6] feat: add support for python 3.10 (#52)
---
setup.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/setup.py b/setup.py
index 0435207..1276af2 100644
--- a/setup.py
+++ b/setup.py
@@ -69,6 +69,8 @@
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
"Operating System :: OS Independent",
"Topic :: Internet",
],
From 47f27eb83bce2624987a7fec9492e9e906781166 Mon Sep 17 00:00:00 2001
From: Anthonios Partheniou
Date: Thu, 14 Oct 2021 20:10:11 -0400
Subject: [PATCH 4/6] chore: delete owlbot.py (#53)
Now that googleapis/synthtool#1244 is merged, owlbot.py is no longer required in this repo. We can add owlbot.py back in the future if repository specific customizations are needed.
---
.github/.OwlBot.lock.yaml | 2 +-
docs/index.rst | 3 ++-
owlbot.py | 43 ---------------------------------------
3 files changed, 3 insertions(+), 45 deletions(-)
delete mode 100644 owlbot.py
diff --git a/.github/.OwlBot.lock.yaml b/.github/.OwlBot.lock.yaml
index 7d98291..ba7b2f7 100644
--- a/.github/.OwlBot.lock.yaml
+++ b/.github/.OwlBot.lock.yaml
@@ -1,3 +1,3 @@
docker:
image: gcr.io/cloud-devrel-public-resources/owlbot-python:latest
- digest: sha256:58f73ba196b5414782605236dd0712a73541b44ff2ff4d3a36ec41092dd6fa5b
+ digest: sha256:3728d8fd14daa46a96d04ce61c6451a3ac864dc48fb71eecbb4411f4a95618d4
diff --git a/docs/index.rst b/docs/index.rst
index 9e6f309..85bc9d5 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -2,6 +2,7 @@
.. include:: multiprocessing.rst
+
API Reference
-------------
.. toctree::
@@ -18,4 +19,4 @@ For a list of all ``google-cloud-tpu`` releases:
.. toctree::
:maxdepth: 2
- changelog
\ No newline at end of file
+ changelog
diff --git a/owlbot.py b/owlbot.py
deleted file mode 100644
index 43820d6..0000000
--- a/owlbot.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# Copyright 2021 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import synthtool as s
-import synthtool.gcp as gcp
-from synthtool.languages import python
-
-# ----------------------------------------------------------------------------
-# Copy the generated client from the owl-bot staging directory
-# ----------------------------------------------------------------------------
-
-default_version = "v1"
-
-for library in s.get_staging_dirs(default_version):
- s.move(library, excludes=["setup.py", "README.rst", "docs/index.rst"])
-s.remove_staging_dirs()
-
-# ----------------------------------------------------------------------------
-# Add templated files
-# ----------------------------------------------------------------------------
-
-templated_files = gcp.CommonTemplates().py_library(microgenerator=True)
-
-s.move(templated_files, excludes=[".coveragerc"]) # the microgenerator has a good coveragerc file
-
-python.py_samples(skip_readmes=True)
-
-# ----------------------------------------------------------------------------
-# Run blacken session
-# ----------------------------------------------------------------------------
-
-s.shell.run(["nox", "-s", "blacken"], hide_output=False)
From 72e3e8b955690b5f180af89a0a15a8870fd556a8 Mon Sep 17 00:00:00 2001
From: "gcf-owl-bot[bot]" <78513119+gcf-owl-bot[bot]@users.noreply.github.com>
Date: Fri, 15 Oct 2021 18:08:11 +0000
Subject: [PATCH 5/6] feat: add TPU v2alpha1 (#55)
- [ ] Regenerate this pull request now.
Committer: @rosbo
PiperOrigin-RevId: 403400668
Source-Link: https://github.com/googleapis/googleapis/commit/8f48b9778f9f875ac1931acccb1ff7bada71a372
Source-Link: https://github.com/googleapis/googleapis-gen/commit/f966fd02bd0d248e9adeffe3bb2daa53bf44a252
Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiZjk2NmZkMDJiZDBkMjQ4ZTlhZGVmZmUzYmIyZGFhNTNiZjQ0YTI1MiJ9
---
docs/index.rst | 11 +
docs/tpu_v2alpha1/services.rst | 6 +
docs/tpu_v2alpha1/tpu.rst | 10 +
docs/tpu_v2alpha1/types.rst | 7 +
google/cloud/tpu_v2alpha1/__init__.py | 90 +
google/cloud/tpu_v2alpha1/gapic_metadata.json | 153 +
google/cloud/tpu_v2alpha1/py.typed | 2 +
.../cloud/tpu_v2alpha1/services/__init__.py | 15 +
.../tpu_v2alpha1/services/tpu/__init__.py | 22 +
.../tpu_v2alpha1/services/tpu/async_client.py | 1108 +++++
.../cloud/tpu_v2alpha1/services/tpu/client.py | 1345 ++++++
.../cloud/tpu_v2alpha1/services/tpu/pagers.py | 411 ++
.../services/tpu/transports/__init__.py | 33 +
.../services/tpu/transports/base.py | 351 ++
.../services/tpu/transports/grpc.py | 597 +++
.../services/tpu/transports/grpc_asyncio.py | 611 +++
google/cloud/tpu_v2alpha1/types/__init__.py | 86 +
google/cloud/tpu_v2alpha1/types/cloud_tpu.py | 766 ++++
scripts/fixup_tpu_v2alpha1_keywords.py | 188 +
tests/unit/gapic/tpu_v2alpha1/__init__.py | 15 +
tests/unit/gapic/tpu_v2alpha1/test_tpu.py | 4026 +++++++++++++++++
21 files changed, 9853 insertions(+)
create mode 100644 docs/tpu_v2alpha1/services.rst
create mode 100644 docs/tpu_v2alpha1/tpu.rst
create mode 100644 docs/tpu_v2alpha1/types.rst
create mode 100644 google/cloud/tpu_v2alpha1/__init__.py
create mode 100644 google/cloud/tpu_v2alpha1/gapic_metadata.json
create mode 100644 google/cloud/tpu_v2alpha1/py.typed
create mode 100644 google/cloud/tpu_v2alpha1/services/__init__.py
create mode 100644 google/cloud/tpu_v2alpha1/services/tpu/__init__.py
create mode 100644 google/cloud/tpu_v2alpha1/services/tpu/async_client.py
create mode 100644 google/cloud/tpu_v2alpha1/services/tpu/client.py
create mode 100644 google/cloud/tpu_v2alpha1/services/tpu/pagers.py
create mode 100644 google/cloud/tpu_v2alpha1/services/tpu/transports/__init__.py
create mode 100644 google/cloud/tpu_v2alpha1/services/tpu/transports/base.py
create mode 100644 google/cloud/tpu_v2alpha1/services/tpu/transports/grpc.py
create mode 100644 google/cloud/tpu_v2alpha1/services/tpu/transports/grpc_asyncio.py
create mode 100644 google/cloud/tpu_v2alpha1/types/__init__.py
create mode 100644 google/cloud/tpu_v2alpha1/types/cloud_tpu.py
create mode 100644 scripts/fixup_tpu_v2alpha1_keywords.py
create mode 100644 tests/unit/gapic/tpu_v2alpha1/__init__.py
create mode 100644 tests/unit/gapic/tpu_v2alpha1/test_tpu.py
diff --git a/docs/index.rst b/docs/index.rst
index 85bc9d5..aff8054 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -2,6 +2,9 @@
.. include:: multiprocessing.rst
+This package includes clients for multiple versions of Cloud TPU.
+By default, you will get version ``tpu_v1``.
+
API Reference
-------------
@@ -11,6 +14,14 @@ API Reference
tpu_v1/services
tpu_v1/types
+API Reference
+-------------
+.. toctree::
+ :maxdepth: 2
+
+ tpu_v2alpha1/services
+ tpu_v2alpha1/types
+
Changelog
---------
diff --git a/docs/tpu_v2alpha1/services.rst b/docs/tpu_v2alpha1/services.rst
new file mode 100644
index 0000000..74c3c78
--- /dev/null
+++ b/docs/tpu_v2alpha1/services.rst
@@ -0,0 +1,6 @@
+Services for Google Cloud Tpu v2alpha1 API
+==========================================
+.. toctree::
+ :maxdepth: 2
+
+ tpu
diff --git a/docs/tpu_v2alpha1/tpu.rst b/docs/tpu_v2alpha1/tpu.rst
new file mode 100644
index 0000000..9b3906b
--- /dev/null
+++ b/docs/tpu_v2alpha1/tpu.rst
@@ -0,0 +1,10 @@
+Tpu
+---------------------
+
+.. automodule:: google.cloud.tpu_v2alpha1.services.tpu
+ :members:
+ :inherited-members:
+
+.. automodule:: google.cloud.tpu_v2alpha1.services.tpu.pagers
+ :members:
+ :inherited-members:
diff --git a/docs/tpu_v2alpha1/types.rst b/docs/tpu_v2alpha1/types.rst
new file mode 100644
index 0000000..6c1d0f3
--- /dev/null
+++ b/docs/tpu_v2alpha1/types.rst
@@ -0,0 +1,7 @@
+Types for Google Cloud Tpu v2alpha1 API
+=======================================
+
+.. automodule:: google.cloud.tpu_v2alpha1.types
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/google/cloud/tpu_v2alpha1/__init__.py b/google/cloud/tpu_v2alpha1/__init__.py
new file mode 100644
index 0000000..9ffc130
--- /dev/null
+++ b/google/cloud/tpu_v2alpha1/__init__.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .services.tpu import TpuClient
+from .services.tpu import TpuAsyncClient
+
+from .types.cloud_tpu import AcceleratorType
+from .types.cloud_tpu import AccessConfig
+from .types.cloud_tpu import AttachedDisk
+from .types.cloud_tpu import CreateNodeRequest
+from .types.cloud_tpu import DeleteNodeRequest
+from .types.cloud_tpu import GenerateServiceIdentityRequest
+from .types.cloud_tpu import GenerateServiceIdentityResponse
+from .types.cloud_tpu import GetAcceleratorTypeRequest
+from .types.cloud_tpu import GetGuestAttributesRequest
+from .types.cloud_tpu import GetGuestAttributesResponse
+from .types.cloud_tpu import GetNodeRequest
+from .types.cloud_tpu import GetRuntimeVersionRequest
+from .types.cloud_tpu import GuestAttributes
+from .types.cloud_tpu import GuestAttributesEntry
+from .types.cloud_tpu import GuestAttributesValue
+from .types.cloud_tpu import ListAcceleratorTypesRequest
+from .types.cloud_tpu import ListAcceleratorTypesResponse
+from .types.cloud_tpu import ListNodesRequest
+from .types.cloud_tpu import ListNodesResponse
+from .types.cloud_tpu import ListRuntimeVersionsRequest
+from .types.cloud_tpu import ListRuntimeVersionsResponse
+from .types.cloud_tpu import NetworkConfig
+from .types.cloud_tpu import NetworkEndpoint
+from .types.cloud_tpu import Node
+from .types.cloud_tpu import OperationMetadata
+from .types.cloud_tpu import RuntimeVersion
+from .types.cloud_tpu import SchedulingConfig
+from .types.cloud_tpu import ServiceAccount
+from .types.cloud_tpu import ServiceIdentity
+from .types.cloud_tpu import StartNodeRequest
+from .types.cloud_tpu import StopNodeRequest
+from .types.cloud_tpu import Symptom
+from .types.cloud_tpu import UpdateNodeRequest
+
+__all__ = (
+ "TpuAsyncClient",
+ "AcceleratorType",
+ "AccessConfig",
+ "AttachedDisk",
+ "CreateNodeRequest",
+ "DeleteNodeRequest",
+ "GenerateServiceIdentityRequest",
+ "GenerateServiceIdentityResponse",
+ "GetAcceleratorTypeRequest",
+ "GetGuestAttributesRequest",
+ "GetGuestAttributesResponse",
+ "GetNodeRequest",
+ "GetRuntimeVersionRequest",
+ "GuestAttributes",
+ "GuestAttributesEntry",
+ "GuestAttributesValue",
+ "ListAcceleratorTypesRequest",
+ "ListAcceleratorTypesResponse",
+ "ListNodesRequest",
+ "ListNodesResponse",
+ "ListRuntimeVersionsRequest",
+ "ListRuntimeVersionsResponse",
+ "NetworkConfig",
+ "NetworkEndpoint",
+ "Node",
+ "OperationMetadata",
+ "RuntimeVersion",
+ "SchedulingConfig",
+ "ServiceAccount",
+ "ServiceIdentity",
+ "StartNodeRequest",
+ "StopNodeRequest",
+ "Symptom",
+ "TpuClient",
+ "UpdateNodeRequest",
+)
diff --git a/google/cloud/tpu_v2alpha1/gapic_metadata.json b/google/cloud/tpu_v2alpha1/gapic_metadata.json
new file mode 100644
index 0000000..0d306ce
--- /dev/null
+++ b/google/cloud/tpu_v2alpha1/gapic_metadata.json
@@ -0,0 +1,153 @@
+ {
+ "comment": "This file maps proto services/RPCs to the corresponding library clients/methods",
+ "language": "python",
+ "libraryPackage": "google.cloud.tpu_v2alpha1",
+ "protoPackage": "google.cloud.tpu.v2alpha1",
+ "schema": "1.0",
+ "services": {
+ "Tpu": {
+ "clients": {
+ "grpc": {
+ "libraryClient": "TpuClient",
+ "rpcs": {
+ "CreateNode": {
+ "methods": [
+ "create_node"
+ ]
+ },
+ "DeleteNode": {
+ "methods": [
+ "delete_node"
+ ]
+ },
+ "GenerateServiceIdentity": {
+ "methods": [
+ "generate_service_identity"
+ ]
+ },
+ "GetAcceleratorType": {
+ "methods": [
+ "get_accelerator_type"
+ ]
+ },
+ "GetGuestAttributes": {
+ "methods": [
+ "get_guest_attributes"
+ ]
+ },
+ "GetNode": {
+ "methods": [
+ "get_node"
+ ]
+ },
+ "GetRuntimeVersion": {
+ "methods": [
+ "get_runtime_version"
+ ]
+ },
+ "ListAcceleratorTypes": {
+ "methods": [
+ "list_accelerator_types"
+ ]
+ },
+ "ListNodes": {
+ "methods": [
+ "list_nodes"
+ ]
+ },
+ "ListRuntimeVersions": {
+ "methods": [
+ "list_runtime_versions"
+ ]
+ },
+ "StartNode": {
+ "methods": [
+ "start_node"
+ ]
+ },
+ "StopNode": {
+ "methods": [
+ "stop_node"
+ ]
+ },
+ "UpdateNode": {
+ "methods": [
+ "update_node"
+ ]
+ }
+ }
+ },
+ "grpc-async": {
+ "libraryClient": "TpuAsyncClient",
+ "rpcs": {
+ "CreateNode": {
+ "methods": [
+ "create_node"
+ ]
+ },
+ "DeleteNode": {
+ "methods": [
+ "delete_node"
+ ]
+ },
+ "GenerateServiceIdentity": {
+ "methods": [
+ "generate_service_identity"
+ ]
+ },
+ "GetAcceleratorType": {
+ "methods": [
+ "get_accelerator_type"
+ ]
+ },
+ "GetGuestAttributes": {
+ "methods": [
+ "get_guest_attributes"
+ ]
+ },
+ "GetNode": {
+ "methods": [
+ "get_node"
+ ]
+ },
+ "GetRuntimeVersion": {
+ "methods": [
+ "get_runtime_version"
+ ]
+ },
+ "ListAcceleratorTypes": {
+ "methods": [
+ "list_accelerator_types"
+ ]
+ },
+ "ListNodes": {
+ "methods": [
+ "list_nodes"
+ ]
+ },
+ "ListRuntimeVersions": {
+ "methods": [
+ "list_runtime_versions"
+ ]
+ },
+ "StartNode": {
+ "methods": [
+ "start_node"
+ ]
+ },
+ "StopNode": {
+ "methods": [
+ "stop_node"
+ ]
+ },
+ "UpdateNode": {
+ "methods": [
+ "update_node"
+ ]
+ }
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/google/cloud/tpu_v2alpha1/py.typed b/google/cloud/tpu_v2alpha1/py.typed
new file mode 100644
index 0000000..e122051
--- /dev/null
+++ b/google/cloud/tpu_v2alpha1/py.typed
@@ -0,0 +1,2 @@
+# Marker file for PEP 561.
+# The google-cloud-tpu package uses inline types.
diff --git a/google/cloud/tpu_v2alpha1/services/__init__.py b/google/cloud/tpu_v2alpha1/services/__init__.py
new file mode 100644
index 0000000..4de6597
--- /dev/null
+++ b/google/cloud/tpu_v2alpha1/services/__init__.py
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/google/cloud/tpu_v2alpha1/services/tpu/__init__.py b/google/cloud/tpu_v2alpha1/services/tpu/__init__.py
new file mode 100644
index 0000000..d9a7a94
--- /dev/null
+++ b/google/cloud/tpu_v2alpha1/services/tpu/__init__.py
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from .client import TpuClient
+from .async_client import TpuAsyncClient
+
+__all__ = (
+ "TpuClient",
+ "TpuAsyncClient",
+)
diff --git a/google/cloud/tpu_v2alpha1/services/tpu/async_client.py b/google/cloud/tpu_v2alpha1/services/tpu/async_client.py
new file mode 100644
index 0000000..fdd1076
--- /dev/null
+++ b/google/cloud/tpu_v2alpha1/services/tpu/async_client.py
@@ -0,0 +1,1108 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from collections import OrderedDict
+import functools
+import re
+from typing import Dict, Sequence, Tuple, Type, Union
+import pkg_resources
+
+import google.api_core.client_options as ClientOptions # type: ignore
+from google.api_core import exceptions as core_exceptions # type: ignore
+from google.api_core import gapic_v1 # type: ignore
+from google.api_core import retry as retries # type: ignore
+from google.auth import credentials as ga_credentials # type: ignore
+from google.oauth2 import service_account # type: ignore
+
+from google.api_core import operation # type: ignore
+from google.api_core import operation_async # type: ignore
+from google.cloud.tpu_v2alpha1.services.tpu import pagers
+from google.cloud.tpu_v2alpha1.types import cloud_tpu
+from google.protobuf import field_mask_pb2 # type: ignore
+from google.protobuf import timestamp_pb2 # type: ignore
+from .transports.base import TpuTransport, DEFAULT_CLIENT_INFO
+from .transports.grpc_asyncio import TpuGrpcAsyncIOTransport
+from .client import TpuClient
+
+
+class TpuAsyncClient:
+ """Manages TPU nodes and other resources
+ TPU API v2alpha1
+ """
+
+ _client: TpuClient
+
+ DEFAULT_ENDPOINT = TpuClient.DEFAULT_ENDPOINT
+ DEFAULT_MTLS_ENDPOINT = TpuClient.DEFAULT_MTLS_ENDPOINT
+
+ accelerator_type_path = staticmethod(TpuClient.accelerator_type_path)
+ parse_accelerator_type_path = staticmethod(TpuClient.parse_accelerator_type_path)
+ node_path = staticmethod(TpuClient.node_path)
+ parse_node_path = staticmethod(TpuClient.parse_node_path)
+ runtime_version_path = staticmethod(TpuClient.runtime_version_path)
+ parse_runtime_version_path = staticmethod(TpuClient.parse_runtime_version_path)
+ common_billing_account_path = staticmethod(TpuClient.common_billing_account_path)
+ parse_common_billing_account_path = staticmethod(
+ TpuClient.parse_common_billing_account_path
+ )
+ common_folder_path = staticmethod(TpuClient.common_folder_path)
+ parse_common_folder_path = staticmethod(TpuClient.parse_common_folder_path)
+ common_organization_path = staticmethod(TpuClient.common_organization_path)
+ parse_common_organization_path = staticmethod(
+ TpuClient.parse_common_organization_path
+ )
+ common_project_path = staticmethod(TpuClient.common_project_path)
+ parse_common_project_path = staticmethod(TpuClient.parse_common_project_path)
+ common_location_path = staticmethod(TpuClient.common_location_path)
+ parse_common_location_path = staticmethod(TpuClient.parse_common_location_path)
+
+ @classmethod
+ def from_service_account_info(cls, info: dict, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials
+ info.
+
+ Args:
+ info (dict): The service account private key info.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ TpuAsyncClient: The constructed client.
+ """
+ return TpuClient.from_service_account_info.__func__(TpuAsyncClient, info, *args, **kwargs) # type: ignore
+
+ @classmethod
+ def from_service_account_file(cls, filename: str, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials
+ file.
+
+ Args:
+ filename (str): The path to the service account private key json
+ file.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ TpuAsyncClient: The constructed client.
+ """
+ return TpuClient.from_service_account_file.__func__(TpuAsyncClient, filename, *args, **kwargs) # type: ignore
+
+ from_service_account_json = from_service_account_file
+
+ @property
+ def transport(self) -> TpuTransport:
+ """Returns the transport used by the client instance.
+
+ Returns:
+ TpuTransport: The transport used by the client instance.
+ """
+ return self._client.transport
+
+ get_transport_class = functools.partial(
+ type(TpuClient).get_transport_class, type(TpuClient)
+ )
+
+ def __init__(
+ self,
+ *,
+ credentials: ga_credentials.Credentials = None,
+ transport: Union[str, TpuTransport] = "grpc_asyncio",
+ client_options: ClientOptions = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ ) -> None:
+ """Instantiates the tpu client.
+
+ Args:
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ transport (Union[str, ~.TpuTransport]): The
+ transport to use. If set to None, a transport is chosen
+ automatically.
+ client_options (ClientOptions): Custom options for the client. It
+ won't take effect if a ``transport`` instance is provided.
+ (1) The ``api_endpoint`` property can be used to override the
+ default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+ environment variable can also be used to override the endpoint:
+ "always" (always use the default mTLS endpoint), "never" (always
+ use the default regular endpoint) and "auto" (auto switch to the
+ default mTLS endpoint if client certificate is present, this is
+ the default value). However, the ``api_endpoint`` property takes
+ precedence if provided.
+ (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+ is "true", then the ``client_cert_source`` property can be used
+ to provide client certificate for mutual TLS transport. If
+ not provided, the default SSL client certificate will be used if
+ present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+ set, no client certificate will be used.
+
+ Raises:
+ google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
+ creation failed for any reason.
+ """
+ self._client = TpuClient(
+ credentials=credentials,
+ transport=transport,
+ client_options=client_options,
+ client_info=client_info,
+ )
+
+ async def list_nodes(
+ self,
+ request: cloud_tpu.ListNodesRequest = None,
+ *,
+ parent: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> pagers.ListNodesAsyncPager:
+ r"""Lists nodes.
+
+ Args:
+ request (:class:`google.cloud.tpu_v2alpha1.types.ListNodesRequest`):
+ The request object. Request for
+ [ListNodes][google.cloud.tpu.v2alpha1.Tpu.ListNodes].
+ parent (:class:`str`):
+ Required. The parent resource name.
+ This corresponds to the ``parent`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.tpu_v2alpha1.services.tpu.pagers.ListNodesAsyncPager:
+ Response for
+ [ListNodes][google.cloud.tpu.v2alpha1.Tpu.ListNodes].
+
+ Iterating over this object will yield results and
+ resolve additional pages automatically.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([parent])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = cloud_tpu.ListNodesRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if parent is not None:
+ request.parent = parent
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.list_nodes,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+ )
+
+ # Send the request.
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # This method is paged; wrap the response in a pager, which provides
+ # an `__aiter__` convenience method.
+ response = pagers.ListNodesAsyncPager(
+ method=rpc, request=request, response=response, metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def get_node(
+ self,
+ request: cloud_tpu.GetNodeRequest = None,
+ *,
+ name: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> cloud_tpu.Node:
+ r"""Gets the details of a node.
+
+ Args:
+ request (:class:`google.cloud.tpu_v2alpha1.types.GetNodeRequest`):
+ The request object. Request for
+ [GetNode][google.cloud.tpu.v2alpha1.Tpu.GetNode].
+ name (:class:`str`):
+ Required. The resource name.
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.tpu_v2alpha1.types.Node:
+ A TPU instance.
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = cloud_tpu.GetNodeRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if name is not None:
+ request.name = name
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.get_node,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Done; return the response.
+ return response
+
+ async def create_node(
+ self,
+ request: cloud_tpu.CreateNodeRequest = None,
+ *,
+ parent: str = None,
+ node: cloud_tpu.Node = None,
+ node_id: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation_async.AsyncOperation:
+ r"""Creates a node.
+
+ Args:
+ request (:class:`google.cloud.tpu_v2alpha1.types.CreateNodeRequest`):
+ The request object. Request for
+ [CreateNode][google.cloud.tpu.v2alpha1.Tpu.CreateNode].
+ parent (:class:`str`):
+ Required. The parent resource name.
+ This corresponds to the ``parent`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ node (:class:`google.cloud.tpu_v2alpha1.types.Node`):
+ Required. The node.
+ This corresponds to the ``node`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ node_id (:class:`str`):
+ The unqualified resource name.
+ This corresponds to the ``node_id`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be
+ :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU
+ instance.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([parent, node, node_id])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = cloud_tpu.CreateNodeRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if parent is not None:
+ request.parent = parent
+ if node is not None:
+ request.node = node
+ if node_id is not None:
+ request.node_id = node_id
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.create_node,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+ )
+
+ # Send the request.
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ cloud_tpu.Node,
+ metadata_type=cloud_tpu.OperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def delete_node(
+ self,
+ request: cloud_tpu.DeleteNodeRequest = None,
+ *,
+ name: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation_async.AsyncOperation:
+ r"""Deletes a node.
+
+ Args:
+ request (:class:`google.cloud.tpu_v2alpha1.types.DeleteNodeRequest`):
+ The request object. Request for
+ [DeleteNode][google.cloud.tpu.v2alpha1.Tpu.DeleteNode].
+ name (:class:`str`):
+ Required. The resource name.
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be
+ :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU
+ instance.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = cloud_tpu.DeleteNodeRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if name is not None:
+ request.name = name
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.delete_node,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ cloud_tpu.Node,
+ metadata_type=cloud_tpu.OperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def stop_node(
+ self,
+ request: cloud_tpu.StopNodeRequest = None,
+ *,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation_async.AsyncOperation:
+ r"""Stops a node. This operation is only available with
+ single TPU nodes.
+
+ Args:
+ request (:class:`google.cloud.tpu_v2alpha1.types.StopNodeRequest`):
+ The request object. Request for
+ [StopNode][google.cloud.tpu.v2alpha1.Tpu.StopNode].
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be
+ :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU
+ instance.
+
+ """
+ # Create or coerce a protobuf request object.
+ request = cloud_tpu.StopNodeRequest(request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.stop_node,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ cloud_tpu.Node,
+ metadata_type=cloud_tpu.OperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def start_node(
+ self,
+ request: cloud_tpu.StartNodeRequest = None,
+ *,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation_async.AsyncOperation:
+ r"""Starts a node.
+
+ Args:
+ request (:class:`google.cloud.tpu_v2alpha1.types.StartNodeRequest`):
+ The request object. Request for
+ [StartNode][google.cloud.tpu.v2alpha1.Tpu.StartNode].
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be
+ :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU
+ instance.
+
+ """
+ # Create or coerce a protobuf request object.
+ request = cloud_tpu.StartNodeRequest(request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.start_node,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ cloud_tpu.Node,
+ metadata_type=cloud_tpu.OperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def update_node(
+ self,
+ request: cloud_tpu.UpdateNodeRequest = None,
+ *,
+ node: cloud_tpu.Node = None,
+ update_mask: field_mask_pb2.FieldMask = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation_async.AsyncOperation:
+ r"""Updates the configurations of a node.
+
+ Args:
+ request (:class:`google.cloud.tpu_v2alpha1.types.UpdateNodeRequest`):
+ The request object. Request for
+ [UpdateNode][google.cloud.tpu.v2alpha1.Tpu.UpdateNode].
+ node (:class:`google.cloud.tpu_v2alpha1.types.Node`):
+ Required. The node. Only fields specified in update_mask
+ are updated.
+
+ This corresponds to the ``node`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`):
+ Required. Mask of fields from [Node][Tpu.Node] to
+ update. Supported fields: None.
+
+ This corresponds to the ``update_mask`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be
+ :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU
+ instance.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([node, update_mask])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = cloud_tpu.UpdateNodeRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if node is not None:
+ request.node = node
+ if update_mask is not None:
+ request.update_mask = update_mask
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.update_node,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata(
+ (("node.name", request.node.name),)
+ ),
+ )
+
+ # Send the request.
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ cloud_tpu.Node,
+ metadata_type=cloud_tpu.OperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def generate_service_identity(
+ self,
+ request: cloud_tpu.GenerateServiceIdentityRequest = None,
+ *,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> cloud_tpu.GenerateServiceIdentityResponse:
+ r"""Generates the Cloud TPU service identity for the
+ project.
+
+ Args:
+ request (:class:`google.cloud.tpu_v2alpha1.types.GenerateServiceIdentityRequest`):
+ The request object. Request for
+ [GenerateServiceIdentity][google.cloud.tpu.v2alpha1.Tpu.GenerateServiceIdentity].
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.tpu_v2alpha1.types.GenerateServiceIdentityResponse:
+ Response for
+ [GenerateServiceIdentity][google.cloud.tpu.v2alpha1.Tpu.GenerateServiceIdentity].
+
+ """
+ # Create or coerce a protobuf request object.
+ request = cloud_tpu.GenerateServiceIdentityRequest(request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.generate_service_identity,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+ )
+
+ # Send the request.
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Done; return the response.
+ return response
+
+ async def list_accelerator_types(
+ self,
+ request: cloud_tpu.ListAcceleratorTypesRequest = None,
+ *,
+ parent: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> pagers.ListAcceleratorTypesAsyncPager:
+ r"""Lists accelerator types supported by this API.
+
+ Args:
+ request (:class:`google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesRequest`):
+ The request object. Request for
+ [ListAcceleratorTypes][google.cloud.tpu.v2alpha1.Tpu.ListAcceleratorTypes].
+ parent (:class:`str`):
+ Required. The parent resource name.
+ This corresponds to the ``parent`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.tpu_v2alpha1.services.tpu.pagers.ListAcceleratorTypesAsyncPager:
+ Response for
+ [ListAcceleratorTypes][google.cloud.tpu.v2alpha1.Tpu.ListAcceleratorTypes].
+
+ Iterating over this object will yield results and
+ resolve additional pages automatically.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([parent])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = cloud_tpu.ListAcceleratorTypesRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if parent is not None:
+ request.parent = parent
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.list_accelerator_types,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+ )
+
+ # Send the request.
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # This method is paged; wrap the response in a pager, which provides
+ # an `__aiter__` convenience method.
+ response = pagers.ListAcceleratorTypesAsyncPager(
+ method=rpc, request=request, response=response, metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def get_accelerator_type(
+ self,
+ request: cloud_tpu.GetAcceleratorTypeRequest = None,
+ *,
+ name: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> cloud_tpu.AcceleratorType:
+ r"""Gets AcceleratorType.
+
+ Args:
+ request (:class:`google.cloud.tpu_v2alpha1.types.GetAcceleratorTypeRequest`):
+ The request object. Request for
+ [GetAcceleratorType][google.cloud.tpu.v2alpha1.Tpu.GetAcceleratorType].
+ name (:class:`str`):
+ Required. The resource name.
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.tpu_v2alpha1.types.AcceleratorType:
+ A accelerator type that a Node can be
+ configured with.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = cloud_tpu.GetAcceleratorTypeRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if name is not None:
+ request.name = name
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.get_accelerator_type,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Done; return the response.
+ return response
+
+ async def list_runtime_versions(
+ self,
+ request: cloud_tpu.ListRuntimeVersionsRequest = None,
+ *,
+ parent: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> pagers.ListRuntimeVersionsAsyncPager:
+ r"""Lists runtime versions supported by this API.
+
+ Args:
+ request (:class:`google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsRequest`):
+ The request object. Request for
+ [ListRuntimeVersions][google.cloud.tpu.v2alpha1.Tpu.ListRuntimeVersions].
+ parent (:class:`str`):
+ Required. The parent resource name.
+ This corresponds to the ``parent`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.tpu_v2alpha1.services.tpu.pagers.ListRuntimeVersionsAsyncPager:
+ Response for
+ [ListRuntimeVersions][google.cloud.tpu.v2alpha1.Tpu.ListRuntimeVersions].
+
+ Iterating over this object will yield results and
+ resolve additional pages automatically.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([parent])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = cloud_tpu.ListRuntimeVersionsRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if parent is not None:
+ request.parent = parent
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.list_runtime_versions,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+ )
+
+ # Send the request.
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # This method is paged; wrap the response in a pager, which provides
+ # an `__aiter__` convenience method.
+ response = pagers.ListRuntimeVersionsAsyncPager(
+ method=rpc, request=request, response=response, metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def get_runtime_version(
+ self,
+ request: cloud_tpu.GetRuntimeVersionRequest = None,
+ *,
+ name: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> cloud_tpu.RuntimeVersion:
+ r"""Gets a runtime version.
+
+ Args:
+ request (:class:`google.cloud.tpu_v2alpha1.types.GetRuntimeVersionRequest`):
+ The request object. Request for
+ [GetRuntimeVersion][google.cloud.tpu.v2alpha1.Tpu.GetRuntimeVersion].
+ name (:class:`str`):
+ Required. The resource name.
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.tpu_v2alpha1.types.RuntimeVersion:
+ A runtime version that a Node can be
+ configured with.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = cloud_tpu.GetRuntimeVersionRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if name is not None:
+ request.name = name
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.get_runtime_version,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Done; return the response.
+ return response
+
+ async def get_guest_attributes(
+ self,
+ request: cloud_tpu.GetGuestAttributesRequest = None,
+ *,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> cloud_tpu.GetGuestAttributesResponse:
+ r"""Retrieves the guest attributes for the node.
+
+ Args:
+ request (:class:`google.cloud.tpu_v2alpha1.types.GetGuestAttributesRequest`):
+ The request object. Request for
+ [GetGuestAttributes][google.cloud.tpu.v2alpha1.Tpu.GetGuestAttributes].
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.tpu_v2alpha1.types.GetGuestAttributesResponse:
+ Response for
+ [GetGuestAttributes][google.cloud.tpu.v2alpha1.Tpu.GetGuestAttributes].
+
+ """
+ # Create or coerce a protobuf request object.
+ request = cloud_tpu.GetGuestAttributesRequest(request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.get_guest_attributes,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Done; return the response.
+ return response
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ await self.transport.close()
+
+
+try:
+ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+ gapic_version=pkg_resources.get_distribution("google-cloud-tpu",).version,
+ )
+except pkg_resources.DistributionNotFound:
+ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+
+__all__ = ("TpuAsyncClient",)
diff --git a/google/cloud/tpu_v2alpha1/services/tpu/client.py b/google/cloud/tpu_v2alpha1/services/tpu/client.py
new file mode 100644
index 0000000..1c76efc
--- /dev/null
+++ b/google/cloud/tpu_v2alpha1/services/tpu/client.py
@@ -0,0 +1,1345 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from collections import OrderedDict
+from distutils import util
+import os
+import re
+from typing import Dict, Optional, Sequence, Tuple, Type, Union
+import pkg_resources
+
+from google.api_core import client_options as client_options_lib # type: ignore
+from google.api_core import exceptions as core_exceptions # type: ignore
+from google.api_core import gapic_v1 # type: ignore
+from google.api_core import retry as retries # type: ignore
+from google.auth import credentials as ga_credentials # type: ignore
+from google.auth.transport import mtls # type: ignore
+from google.auth.transport.grpc import SslCredentials # type: ignore
+from google.auth.exceptions import MutualTLSChannelError # type: ignore
+from google.oauth2 import service_account # type: ignore
+
+from google.api_core import operation # type: ignore
+from google.api_core import operation_async # type: ignore
+from google.cloud.tpu_v2alpha1.services.tpu import pagers
+from google.cloud.tpu_v2alpha1.types import cloud_tpu
+from google.protobuf import field_mask_pb2 # type: ignore
+from google.protobuf import timestamp_pb2 # type: ignore
+from .transports.base import TpuTransport, DEFAULT_CLIENT_INFO
+from .transports.grpc import TpuGrpcTransport
+from .transports.grpc_asyncio import TpuGrpcAsyncIOTransport
+
+
+class TpuClientMeta(type):
+ """Metaclass for the Tpu client.
+
+ This provides class-level methods for building and retrieving
+ support objects (e.g. transport) without polluting the client instance
+ objects.
+ """
+
+ _transport_registry = OrderedDict() # type: Dict[str, Type[TpuTransport]]
+ _transport_registry["grpc"] = TpuGrpcTransport
+ _transport_registry["grpc_asyncio"] = TpuGrpcAsyncIOTransport
+
+ def get_transport_class(cls, label: str = None,) -> Type[TpuTransport]:
+ """Returns an appropriate transport class.
+
+ Args:
+ label: The name of the desired transport. If none is
+ provided, then the first transport in the registry is used.
+
+ Returns:
+ The transport class to use.
+ """
+ # If a specific transport is requested, return that one.
+ if label:
+ return cls._transport_registry[label]
+
+ # No transport is requested; return the default (that is, the first one
+ # in the dictionary).
+ return next(iter(cls._transport_registry.values()))
+
+
+class TpuClient(metaclass=TpuClientMeta):
+ """Manages TPU nodes and other resources
+ TPU API v2alpha1
+ """
+
+ @staticmethod
+ def _get_default_mtls_endpoint(api_endpoint):
+ """Converts api endpoint to mTLS endpoint.
+
+ Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+ "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+ Args:
+ api_endpoint (Optional[str]): the api endpoint to convert.
+ Returns:
+ str: converted mTLS api endpoint.
+ """
+ if not api_endpoint:
+ return api_endpoint
+
+ mtls_endpoint_re = re.compile(
+ r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?"
+ )
+
+ m = mtls_endpoint_re.match(api_endpoint)
+ name, mtls, sandbox, googledomain = m.groups()
+ if mtls or not googledomain:
+ return api_endpoint
+
+ if sandbox:
+ return api_endpoint.replace(
+ "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
+ )
+
+ return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+
+ DEFAULT_ENDPOINT = "tpu.googleapis.com"
+ DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
+ DEFAULT_ENDPOINT
+ )
+
+ @classmethod
+ def from_service_account_info(cls, info: dict, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials
+ info.
+
+ Args:
+ info (dict): The service account private key info.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ TpuClient: The constructed client.
+ """
+ credentials = service_account.Credentials.from_service_account_info(info)
+ kwargs["credentials"] = credentials
+ return cls(*args, **kwargs)
+
+ @classmethod
+ def from_service_account_file(cls, filename: str, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials
+ file.
+
+ Args:
+ filename (str): The path to the service account private key json
+ file.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ TpuClient: The constructed client.
+ """
+ credentials = service_account.Credentials.from_service_account_file(filename)
+ kwargs["credentials"] = credentials
+ return cls(*args, **kwargs)
+
+ from_service_account_json = from_service_account_file
+
+ @property
+ def transport(self) -> TpuTransport:
+ """Returns the transport used by the client instance.
+
+ Returns:
+ TpuTransport: The transport used by the client
+ instance.
+ """
+ return self._transport
+
+ @staticmethod
+ def accelerator_type_path(
+ project: str, location: str, accelerator_type: str,
+ ) -> str:
+ """Returns a fully-qualified accelerator_type string."""
+ return "projects/{project}/locations/{location}/acceleratorTypes/{accelerator_type}".format(
+ project=project, location=location, accelerator_type=accelerator_type,
+ )
+
+ @staticmethod
+ def parse_accelerator_type_path(path: str) -> Dict[str, str]:
+ """Parses a accelerator_type path into its component segments."""
+ m = re.match(
+ r"^projects/(?P.+?)/locations/(?P.+?)/acceleratorTypes/(?P.+?)$",
+ path,
+ )
+ return m.groupdict() if m else {}
+
+ @staticmethod
+ def node_path(project: str, location: str, node: str,) -> str:
+ """Returns a fully-qualified node string."""
+ return "projects/{project}/locations/{location}/nodes/{node}".format(
+ project=project, location=location, node=node,
+ )
+
+ @staticmethod
+ def parse_node_path(path: str) -> Dict[str, str]:
+ """Parses a node path into its component segments."""
+ m = re.match(
+ r"^projects/(?P.+?)/locations/(?P.+?)/nodes/(?P.+?)$",
+ path,
+ )
+ return m.groupdict() if m else {}
+
+ @staticmethod
+ def runtime_version_path(project: str, location: str, runtime_version: str,) -> str:
+ """Returns a fully-qualified runtime_version string."""
+ return "projects/{project}/locations/{location}/runtimeVersions/{runtime_version}".format(
+ project=project, location=location, runtime_version=runtime_version,
+ )
+
+ @staticmethod
+ def parse_runtime_version_path(path: str) -> Dict[str, str]:
+ """Parses a runtime_version path into its component segments."""
+ m = re.match(
+ r"^projects/(?P.+?)/locations/(?P.+?)/runtimeVersions/(?P.+?)$",
+ path,
+ )
+ return m.groupdict() if m else {}
+
+ @staticmethod
+ def common_billing_account_path(billing_account: str,) -> str:
+ """Returns a fully-qualified billing_account string."""
+ return "billingAccounts/{billing_account}".format(
+ billing_account=billing_account,
+ )
+
+ @staticmethod
+ def parse_common_billing_account_path(path: str) -> Dict[str, str]:
+ """Parse a billing_account path into its component segments."""
+ m = re.match(r"^billingAccounts/(?P.+?)$", path)
+ return m.groupdict() if m else {}
+
+ @staticmethod
+ def common_folder_path(folder: str,) -> str:
+ """Returns a fully-qualified folder string."""
+ return "folders/{folder}".format(folder=folder,)
+
+ @staticmethod
+ def parse_common_folder_path(path: str) -> Dict[str, str]:
+ """Parse a folder path into its component segments."""
+ m = re.match(r"^folders/(?P.+?)$", path)
+ return m.groupdict() if m else {}
+
+ @staticmethod
+ def common_organization_path(organization: str,) -> str:
+ """Returns a fully-qualified organization string."""
+ return "organizations/{organization}".format(organization=organization,)
+
+ @staticmethod
+ def parse_common_organization_path(path: str) -> Dict[str, str]:
+ """Parse a organization path into its component segments."""
+ m = re.match(r"^organizations/(?P.+?)$", path)
+ return m.groupdict() if m else {}
+
+ @staticmethod
+ def common_project_path(project: str,) -> str:
+ """Returns a fully-qualified project string."""
+ return "projects/{project}".format(project=project,)
+
+ @staticmethod
+ def parse_common_project_path(path: str) -> Dict[str, str]:
+ """Parse a project path into its component segments."""
+ m = re.match(r"^projects/(?P.+?)$", path)
+ return m.groupdict() if m else {}
+
+ @staticmethod
+ def common_location_path(project: str, location: str,) -> str:
+ """Returns a fully-qualified location string."""
+ return "projects/{project}/locations/{location}".format(
+ project=project, location=location,
+ )
+
+ @staticmethod
+ def parse_common_location_path(path: str) -> Dict[str, str]:
+ """Parse a location path into its component segments."""
+ m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path)
+ return m.groupdict() if m else {}
+
+ def __init__(
+ self,
+ *,
+ credentials: Optional[ga_credentials.Credentials] = None,
+ transport: Union[str, TpuTransport, None] = None,
+ client_options: Optional[client_options_lib.ClientOptions] = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ ) -> None:
+ """Instantiates the tpu client.
+
+ Args:
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ transport (Union[str, TpuTransport]): The
+ transport to use. If set to None, a transport is chosen
+ automatically.
+ client_options (google.api_core.client_options.ClientOptions): Custom options for the
+ client. It won't take effect if a ``transport`` instance is provided.
+ (1) The ``api_endpoint`` property can be used to override the
+ default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+ environment variable can also be used to override the endpoint:
+ "always" (always use the default mTLS endpoint), "never" (always
+ use the default regular endpoint) and "auto" (auto switch to the
+ default mTLS endpoint if client certificate is present, this is
+ the default value). However, the ``api_endpoint`` property takes
+ precedence if provided.
+ (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+ is "true", then the ``client_cert_source`` property can be used
+ to provide client certificate for mutual TLS transport. If
+ not provided, the default SSL client certificate will be used if
+ present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+ set, no client certificate will be used.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+ creation failed for any reason.
+ """
+ if isinstance(client_options, dict):
+ client_options = client_options_lib.from_dict(client_options)
+ if client_options is None:
+ client_options = client_options_lib.ClientOptions()
+
+ # Create SSL credentials for mutual TLS if needed.
+ use_client_cert = bool(
+ util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))
+ )
+
+ client_cert_source_func = None
+ is_mtls = False
+ if use_client_cert:
+ if client_options.client_cert_source:
+ is_mtls = True
+ client_cert_source_func = client_options.client_cert_source
+ else:
+ is_mtls = mtls.has_default_client_cert_source()
+ if is_mtls:
+ client_cert_source_func = mtls.default_client_cert_source()
+ else:
+ client_cert_source_func = None
+
+ # Figure out which api endpoint to use.
+ if client_options.api_endpoint is not None:
+ api_endpoint = client_options.api_endpoint
+ else:
+ use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
+ if use_mtls_env == "never":
+ api_endpoint = self.DEFAULT_ENDPOINT
+ elif use_mtls_env == "always":
+ api_endpoint = self.DEFAULT_MTLS_ENDPOINT
+ elif use_mtls_env == "auto":
+ if is_mtls:
+ api_endpoint = self.DEFAULT_MTLS_ENDPOINT
+ else:
+ api_endpoint = self.DEFAULT_ENDPOINT
+ else:
+ raise MutualTLSChannelError(
+ "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted "
+ "values: never, auto, always"
+ )
+
+ # Save or instantiate the transport.
+ # Ordinarily, we provide the transport, but allowing a custom transport
+ # instance provides an extensibility point for unusual situations.
+ if isinstance(transport, TpuTransport):
+ # transport is a TpuTransport instance.
+ if credentials or client_options.credentials_file:
+ raise ValueError(
+ "When providing a transport instance, "
+ "provide its credentials directly."
+ )
+ if client_options.scopes:
+ raise ValueError(
+ "When providing a transport instance, provide its scopes "
+ "directly."
+ )
+ self._transport = transport
+ else:
+ Transport = type(self).get_transport_class(transport)
+ self._transport = Transport(
+ credentials=credentials,
+ credentials_file=client_options.credentials_file,
+ host=api_endpoint,
+ scopes=client_options.scopes,
+ client_cert_source_for_mtls=client_cert_source_func,
+ quota_project_id=client_options.quota_project_id,
+ client_info=client_info,
+ always_use_jwt_access=True,
+ )
+
+ def list_nodes(
+ self,
+ request: Union[cloud_tpu.ListNodesRequest, dict] = None,
+ *,
+ parent: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> pagers.ListNodesPager:
+ r"""Lists nodes.
+
+ Args:
+ request (Union[google.cloud.tpu_v2alpha1.types.ListNodesRequest, dict]):
+ The request object. Request for
+ [ListNodes][google.cloud.tpu.v2alpha1.Tpu.ListNodes].
+ parent (str):
+ Required. The parent resource name.
+ This corresponds to the ``parent`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.tpu_v2alpha1.services.tpu.pagers.ListNodesPager:
+ Response for
+ [ListNodes][google.cloud.tpu.v2alpha1.Tpu.ListNodes].
+
+ Iterating over this object will yield results and
+ resolve additional pages automatically.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([parent])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ # Minor optimization to avoid making a copy if the user passes
+ # in a cloud_tpu.ListNodesRequest.
+ # There's no risk of modifying the input as we've already verified
+ # there are no flattened fields.
+ if not isinstance(request, cloud_tpu.ListNodesRequest):
+ request = cloud_tpu.ListNodesRequest(request)
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if parent is not None:
+ request.parent = parent
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[self._transport.list_nodes]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+ )
+
+ # Send the request.
+ response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # This method is paged; wrap the response in a pager, which provides
+ # an `__iter__` convenience method.
+ response = pagers.ListNodesPager(
+ method=rpc, request=request, response=response, metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def get_node(
+ self,
+ request: Union[cloud_tpu.GetNodeRequest, dict] = None,
+ *,
+ name: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> cloud_tpu.Node:
+ r"""Gets the details of a node.
+
+ Args:
+ request (Union[google.cloud.tpu_v2alpha1.types.GetNodeRequest, dict]):
+ The request object. Request for
+ [GetNode][google.cloud.tpu.v2alpha1.Tpu.GetNode].
+ name (str):
+ Required. The resource name.
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.tpu_v2alpha1.types.Node:
+ A TPU instance.
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ # Minor optimization to avoid making a copy if the user passes
+ # in a cloud_tpu.GetNodeRequest.
+ # There's no risk of modifying the input as we've already verified
+ # there are no flattened fields.
+ if not isinstance(request, cloud_tpu.GetNodeRequest):
+ request = cloud_tpu.GetNodeRequest(request)
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if name is not None:
+ request.name = name
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[self._transport.get_node]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Done; return the response.
+ return response
+
+ def create_node(
+ self,
+ request: Union[cloud_tpu.CreateNodeRequest, dict] = None,
+ *,
+ parent: str = None,
+ node: cloud_tpu.Node = None,
+ node_id: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation.Operation:
+ r"""Creates a node.
+
+ Args:
+ request (Union[google.cloud.tpu_v2alpha1.types.CreateNodeRequest, dict]):
+ The request object. Request for
+ [CreateNode][google.cloud.tpu.v2alpha1.Tpu.CreateNode].
+ parent (str):
+ Required. The parent resource name.
+ This corresponds to the ``parent`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ node (google.cloud.tpu_v2alpha1.types.Node):
+ Required. The node.
+ This corresponds to the ``node`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ node_id (str):
+ The unqualified resource name.
+ This corresponds to the ``node_id`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation.Operation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be
+ :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU
+ instance.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([parent, node, node_id])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ # Minor optimization to avoid making a copy if the user passes
+ # in a cloud_tpu.CreateNodeRequest.
+ # There's no risk of modifying the input as we've already verified
+ # there are no flattened fields.
+ if not isinstance(request, cloud_tpu.CreateNodeRequest):
+ request = cloud_tpu.CreateNodeRequest(request)
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if parent is not None:
+ request.parent = parent
+ if node is not None:
+ request.node = node
+ if node_id is not None:
+ request.node_id = node_id
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[self._transport.create_node]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+ )
+
+ # Send the request.
+ response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Wrap the response in an operation future.
+ response = operation.from_gapic(
+ response,
+ self._transport.operations_client,
+ cloud_tpu.Node,
+ metadata_type=cloud_tpu.OperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def delete_node(
+ self,
+ request: Union[cloud_tpu.DeleteNodeRequest, dict] = None,
+ *,
+ name: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation.Operation:
+ r"""Deletes a node.
+
+ Args:
+ request (Union[google.cloud.tpu_v2alpha1.types.DeleteNodeRequest, dict]):
+ The request object. Request for
+ [DeleteNode][google.cloud.tpu.v2alpha1.Tpu.DeleteNode].
+ name (str):
+ Required. The resource name.
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation.Operation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be
+ :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU
+ instance.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ # Minor optimization to avoid making a copy if the user passes
+ # in a cloud_tpu.DeleteNodeRequest.
+ # There's no risk of modifying the input as we've already verified
+ # there are no flattened fields.
+ if not isinstance(request, cloud_tpu.DeleteNodeRequest):
+ request = cloud_tpu.DeleteNodeRequest(request)
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if name is not None:
+ request.name = name
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[self._transport.delete_node]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Wrap the response in an operation future.
+ response = operation.from_gapic(
+ response,
+ self._transport.operations_client,
+ cloud_tpu.Node,
+ metadata_type=cloud_tpu.OperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def stop_node(
+ self,
+ request: Union[cloud_tpu.StopNodeRequest, dict] = None,
+ *,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation.Operation:
+ r"""Stops a node. This operation is only available with
+ single TPU nodes.
+
+ Args:
+ request (Union[google.cloud.tpu_v2alpha1.types.StopNodeRequest, dict]):
+ The request object. Request for
+ [StopNode][google.cloud.tpu.v2alpha1.Tpu.StopNode].
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation.Operation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be
+ :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU
+ instance.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Minor optimization to avoid making a copy if the user passes
+ # in a cloud_tpu.StopNodeRequest.
+ # There's no risk of modifying the input as we've already verified
+ # there are no flattened fields.
+ if not isinstance(request, cloud_tpu.StopNodeRequest):
+ request = cloud_tpu.StopNodeRequest(request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[self._transport.stop_node]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Wrap the response in an operation future.
+ response = operation.from_gapic(
+ response,
+ self._transport.operations_client,
+ cloud_tpu.Node,
+ metadata_type=cloud_tpu.OperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def start_node(
+ self,
+ request: Union[cloud_tpu.StartNodeRequest, dict] = None,
+ *,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation.Operation:
+ r"""Starts a node.
+
+ Args:
+ request (Union[google.cloud.tpu_v2alpha1.types.StartNodeRequest, dict]):
+ The request object. Request for
+ [StartNode][google.cloud.tpu.v2alpha1.Tpu.StartNode].
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation.Operation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be
+ :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU
+ instance.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Minor optimization to avoid making a copy if the user passes
+ # in a cloud_tpu.StartNodeRequest.
+ # There's no risk of modifying the input as we've already verified
+ # there are no flattened fields.
+ if not isinstance(request, cloud_tpu.StartNodeRequest):
+ request = cloud_tpu.StartNodeRequest(request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[self._transport.start_node]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Wrap the response in an operation future.
+ response = operation.from_gapic(
+ response,
+ self._transport.operations_client,
+ cloud_tpu.Node,
+ metadata_type=cloud_tpu.OperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def update_node(
+ self,
+ request: Union[cloud_tpu.UpdateNodeRequest, dict] = None,
+ *,
+ node: cloud_tpu.Node = None,
+ update_mask: field_mask_pb2.FieldMask = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation.Operation:
+ r"""Updates the configurations of a node.
+
+ Args:
+ request (Union[google.cloud.tpu_v2alpha1.types.UpdateNodeRequest, dict]):
+ The request object. Request for
+ [UpdateNode][google.cloud.tpu.v2alpha1.Tpu.UpdateNode].
+ node (google.cloud.tpu_v2alpha1.types.Node):
+ Required. The node. Only fields specified in update_mask
+ are updated.
+
+ This corresponds to the ``node`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ update_mask (google.protobuf.field_mask_pb2.FieldMask):
+ Required. Mask of fields from [Node][Tpu.Node] to
+ update. Supported fields: None.
+
+ This corresponds to the ``update_mask`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation.Operation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be
+ :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU
+ instance.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([node, update_mask])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ # Minor optimization to avoid making a copy if the user passes
+ # in a cloud_tpu.UpdateNodeRequest.
+ # There's no risk of modifying the input as we've already verified
+ # there are no flattened fields.
+ if not isinstance(request, cloud_tpu.UpdateNodeRequest):
+ request = cloud_tpu.UpdateNodeRequest(request)
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if node is not None:
+ request.node = node
+ if update_mask is not None:
+ request.update_mask = update_mask
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[self._transport.update_node]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata(
+ (("node.name", request.node.name),)
+ ),
+ )
+
+ # Send the request.
+ response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Wrap the response in an operation future.
+ response = operation.from_gapic(
+ response,
+ self._transport.operations_client,
+ cloud_tpu.Node,
+ metadata_type=cloud_tpu.OperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def generate_service_identity(
+ self,
+ request: Union[cloud_tpu.GenerateServiceIdentityRequest, dict] = None,
+ *,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> cloud_tpu.GenerateServiceIdentityResponse:
+ r"""Generates the Cloud TPU service identity for the
+ project.
+
+ Args:
+ request (Union[google.cloud.tpu_v2alpha1.types.GenerateServiceIdentityRequest, dict]):
+ The request object. Request for
+ [GenerateServiceIdentity][google.cloud.tpu.v2alpha1.Tpu.GenerateServiceIdentity].
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.tpu_v2alpha1.types.GenerateServiceIdentityResponse:
+ Response for
+ [GenerateServiceIdentity][google.cloud.tpu.v2alpha1.Tpu.GenerateServiceIdentity].
+
+ """
+ # Create or coerce a protobuf request object.
+ # Minor optimization to avoid making a copy if the user passes
+ # in a cloud_tpu.GenerateServiceIdentityRequest.
+ # There's no risk of modifying the input as we've already verified
+ # there are no flattened fields.
+ if not isinstance(request, cloud_tpu.GenerateServiceIdentityRequest):
+ request = cloud_tpu.GenerateServiceIdentityRequest(request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[
+ self._transport.generate_service_identity
+ ]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+ )
+
+ # Send the request.
+ response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Done; return the response.
+ return response
+
+ def list_accelerator_types(
+ self,
+ request: Union[cloud_tpu.ListAcceleratorTypesRequest, dict] = None,
+ *,
+ parent: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> pagers.ListAcceleratorTypesPager:
+ r"""Lists accelerator types supported by this API.
+
+ Args:
+ request (Union[google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesRequest, dict]):
+ The request object. Request for
+ [ListAcceleratorTypes][google.cloud.tpu.v2alpha1.Tpu.ListAcceleratorTypes].
+ parent (str):
+ Required. The parent resource name.
+ This corresponds to the ``parent`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.tpu_v2alpha1.services.tpu.pagers.ListAcceleratorTypesPager:
+ Response for
+ [ListAcceleratorTypes][google.cloud.tpu.v2alpha1.Tpu.ListAcceleratorTypes].
+
+ Iterating over this object will yield results and
+ resolve additional pages automatically.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([parent])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ # Minor optimization to avoid making a copy if the user passes
+ # in a cloud_tpu.ListAcceleratorTypesRequest.
+ # There's no risk of modifying the input as we've already verified
+ # there are no flattened fields.
+ if not isinstance(request, cloud_tpu.ListAcceleratorTypesRequest):
+ request = cloud_tpu.ListAcceleratorTypesRequest(request)
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if parent is not None:
+ request.parent = parent
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[self._transport.list_accelerator_types]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+ )
+
+ # Send the request.
+ response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # This method is paged; wrap the response in a pager, which provides
+ # an `__iter__` convenience method.
+ response = pagers.ListAcceleratorTypesPager(
+ method=rpc, request=request, response=response, metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def get_accelerator_type(
+ self,
+ request: Union[cloud_tpu.GetAcceleratorTypeRequest, dict] = None,
+ *,
+ name: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> cloud_tpu.AcceleratorType:
+ r"""Gets AcceleratorType.
+
+ Args:
+ request (Union[google.cloud.tpu_v2alpha1.types.GetAcceleratorTypeRequest, dict]):
+ The request object. Request for
+ [GetAcceleratorType][google.cloud.tpu.v2alpha1.Tpu.GetAcceleratorType].
+ name (str):
+ Required. The resource name.
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.tpu_v2alpha1.types.AcceleratorType:
+ A accelerator type that a Node can be
+ configured with.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ # Minor optimization to avoid making a copy if the user passes
+ # in a cloud_tpu.GetAcceleratorTypeRequest.
+ # There's no risk of modifying the input as we've already verified
+ # there are no flattened fields.
+ if not isinstance(request, cloud_tpu.GetAcceleratorTypeRequest):
+ request = cloud_tpu.GetAcceleratorTypeRequest(request)
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if name is not None:
+ request.name = name
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[self._transport.get_accelerator_type]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Done; return the response.
+ return response
+
+ def list_runtime_versions(
+ self,
+ request: Union[cloud_tpu.ListRuntimeVersionsRequest, dict] = None,
+ *,
+ parent: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> pagers.ListRuntimeVersionsPager:
+ r"""Lists runtime versions supported by this API.
+
+ Args:
+ request (Union[google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsRequest, dict]):
+ The request object. Request for
+ [ListRuntimeVersions][google.cloud.tpu.v2alpha1.Tpu.ListRuntimeVersions].
+ parent (str):
+ Required. The parent resource name.
+ This corresponds to the ``parent`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.tpu_v2alpha1.services.tpu.pagers.ListRuntimeVersionsPager:
+ Response for
+ [ListRuntimeVersions][google.cloud.tpu.v2alpha1.Tpu.ListRuntimeVersions].
+
+ Iterating over this object will yield results and
+ resolve additional pages automatically.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([parent])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ # Minor optimization to avoid making a copy if the user passes
+ # in a cloud_tpu.ListRuntimeVersionsRequest.
+ # There's no risk of modifying the input as we've already verified
+ # there are no flattened fields.
+ if not isinstance(request, cloud_tpu.ListRuntimeVersionsRequest):
+ request = cloud_tpu.ListRuntimeVersionsRequest(request)
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if parent is not None:
+ request.parent = parent
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[self._transport.list_runtime_versions]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+ )
+
+ # Send the request.
+ response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # This method is paged; wrap the response in a pager, which provides
+ # an `__iter__` convenience method.
+ response = pagers.ListRuntimeVersionsPager(
+ method=rpc, request=request, response=response, metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def get_runtime_version(
+ self,
+ request: Union[cloud_tpu.GetRuntimeVersionRequest, dict] = None,
+ *,
+ name: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> cloud_tpu.RuntimeVersion:
+ r"""Gets a runtime version.
+
+ Args:
+ request (Union[google.cloud.tpu_v2alpha1.types.GetRuntimeVersionRequest, dict]):
+ The request object. Request for
+ [GetRuntimeVersion][google.cloud.tpu.v2alpha1.Tpu.GetRuntimeVersion].
+ name (str):
+ Required. The resource name.
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.tpu_v2alpha1.types.RuntimeVersion:
+ A runtime version that a Node can be
+ configured with.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ # Minor optimization to avoid making a copy if the user passes
+ # in a cloud_tpu.GetRuntimeVersionRequest.
+ # There's no risk of modifying the input as we've already verified
+ # there are no flattened fields.
+ if not isinstance(request, cloud_tpu.GetRuntimeVersionRequest):
+ request = cloud_tpu.GetRuntimeVersionRequest(request)
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if name is not None:
+ request.name = name
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[self._transport.get_runtime_version]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Done; return the response.
+ return response
+
+ def get_guest_attributes(
+ self,
+ request: Union[cloud_tpu.GetGuestAttributesRequest, dict] = None,
+ *,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> cloud_tpu.GetGuestAttributesResponse:
+ r"""Retrieves the guest attributes for the node.
+
+ Args:
+ request (Union[google.cloud.tpu_v2alpha1.types.GetGuestAttributesRequest, dict]):
+ The request object. Request for
+ [GetGuestAttributes][google.cloud.tpu.v2alpha1.Tpu.GetGuestAttributes].
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.tpu_v2alpha1.types.GetGuestAttributesResponse:
+ Response for
+ [GetGuestAttributes][google.cloud.tpu.v2alpha1.Tpu.GetGuestAttributes].
+
+ """
+ # Create or coerce a protobuf request object.
+ # Minor optimization to avoid making a copy if the user passes
+ # in a cloud_tpu.GetGuestAttributesRequest.
+ # There's no risk of modifying the input as we've already verified
+ # there are no flattened fields.
+ if not isinstance(request, cloud_tpu.GetGuestAttributesRequest):
+ request = cloud_tpu.GetGuestAttributesRequest(request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[self._transport.get_guest_attributes]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Done; return the response.
+ return response
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ """Releases underlying transport's resources.
+
+ .. warning::
+ ONLY use as a context manager if the transport is NOT shared
+ with other clients! Exiting the with block will CLOSE the transport
+ and may cause errors in other clients!
+ """
+ self.transport.close()
+
+
+try:
+ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+ gapic_version=pkg_resources.get_distribution("google-cloud-tpu",).version,
+ )
+except pkg_resources.DistributionNotFound:
+ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+
+__all__ = ("TpuClient",)
diff --git a/google/cloud/tpu_v2alpha1/services/tpu/pagers.py b/google/cloud/tpu_v2alpha1/services/tpu/pagers.py
new file mode 100644
index 0000000..c1859a1
--- /dev/null
+++ b/google/cloud/tpu_v2alpha1/services/tpu/pagers.py
@@ -0,0 +1,411 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import (
+ Any,
+ AsyncIterator,
+ Awaitable,
+ Callable,
+ Sequence,
+ Tuple,
+ Optional,
+ Iterator,
+)
+
+from google.cloud.tpu_v2alpha1.types import cloud_tpu
+
+
+class ListNodesPager:
+ """A pager for iterating through ``list_nodes`` requests.
+
+ This class thinly wraps an initial
+ :class:`google.cloud.tpu_v2alpha1.types.ListNodesResponse` object, and
+ provides an ``__iter__`` method to iterate through its
+ ``nodes`` field.
+
+ If there are more pages, the ``__iter__`` method will make additional
+ ``ListNodes`` requests and continue to iterate
+ through the ``nodes`` field on the
+ corresponding responses.
+
+ All the usual :class:`google.cloud.tpu_v2alpha1.types.ListNodesResponse`
+ attributes are available on the pager. If multiple requests are made, only
+ the most recent response is retained, and thus used for attribute lookup.
+ """
+
+ def __init__(
+ self,
+ method: Callable[..., cloud_tpu.ListNodesResponse],
+ request: cloud_tpu.ListNodesRequest,
+ response: cloud_tpu.ListNodesResponse,
+ *,
+ metadata: Sequence[Tuple[str, str]] = ()
+ ):
+ """Instantiate the pager.
+
+ Args:
+ method (Callable): The method that was originally called, and
+ which instantiated this pager.
+ request (google.cloud.tpu_v2alpha1.types.ListNodesRequest):
+ The initial request object.
+ response (google.cloud.tpu_v2alpha1.types.ListNodesResponse):
+ The initial response object.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ """
+ self._method = method
+ self._request = cloud_tpu.ListNodesRequest(request)
+ self._response = response
+ self._metadata = metadata
+
+ def __getattr__(self, name: str) -> Any:
+ return getattr(self._response, name)
+
+ @property
+ def pages(self) -> Iterator[cloud_tpu.ListNodesResponse]:
+ yield self._response
+ while self._response.next_page_token:
+ self._request.page_token = self._response.next_page_token
+ self._response = self._method(self._request, metadata=self._metadata)
+ yield self._response
+
+ def __iter__(self) -> Iterator[cloud_tpu.Node]:
+ for page in self.pages:
+ yield from page.nodes
+
+ def __repr__(self) -> str:
+ return "{0}<{1!r}>".format(self.__class__.__name__, self._response)
+
+
+class ListNodesAsyncPager:
+ """A pager for iterating through ``list_nodes`` requests.
+
+ This class thinly wraps an initial
+ :class:`google.cloud.tpu_v2alpha1.types.ListNodesResponse` object, and
+ provides an ``__aiter__`` method to iterate through its
+ ``nodes`` field.
+
+ If there are more pages, the ``__aiter__`` method will make additional
+ ``ListNodes`` requests and continue to iterate
+ through the ``nodes`` field on the
+ corresponding responses.
+
+ All the usual :class:`google.cloud.tpu_v2alpha1.types.ListNodesResponse`
+ attributes are available on the pager. If multiple requests are made, only
+ the most recent response is retained, and thus used for attribute lookup.
+ """
+
+ def __init__(
+ self,
+ method: Callable[..., Awaitable[cloud_tpu.ListNodesResponse]],
+ request: cloud_tpu.ListNodesRequest,
+ response: cloud_tpu.ListNodesResponse,
+ *,
+ metadata: Sequence[Tuple[str, str]] = ()
+ ):
+ """Instantiates the pager.
+
+ Args:
+ method (Callable): The method that was originally called, and
+ which instantiated this pager.
+ request (google.cloud.tpu_v2alpha1.types.ListNodesRequest):
+ The initial request object.
+ response (google.cloud.tpu_v2alpha1.types.ListNodesResponse):
+ The initial response object.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ """
+ self._method = method
+ self._request = cloud_tpu.ListNodesRequest(request)
+ self._response = response
+ self._metadata = metadata
+
+ def __getattr__(self, name: str) -> Any:
+ return getattr(self._response, name)
+
+ @property
+ async def pages(self) -> AsyncIterator[cloud_tpu.ListNodesResponse]:
+ yield self._response
+ while self._response.next_page_token:
+ self._request.page_token = self._response.next_page_token
+ self._response = await self._method(self._request, metadata=self._metadata)
+ yield self._response
+
+ def __aiter__(self) -> AsyncIterator[cloud_tpu.Node]:
+ async def async_generator():
+ async for page in self.pages:
+ for response in page.nodes:
+ yield response
+
+ return async_generator()
+
+ def __repr__(self) -> str:
+ return "{0}<{1!r}>".format(self.__class__.__name__, self._response)
+
+
+class ListAcceleratorTypesPager:
+ """A pager for iterating through ``list_accelerator_types`` requests.
+
+ This class thinly wraps an initial
+ :class:`google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesResponse` object, and
+ provides an ``__iter__`` method to iterate through its
+ ``accelerator_types`` field.
+
+ If there are more pages, the ``__iter__`` method will make additional
+ ``ListAcceleratorTypes`` requests and continue to iterate
+ through the ``accelerator_types`` field on the
+ corresponding responses.
+
+ All the usual :class:`google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesResponse`
+ attributes are available on the pager. If multiple requests are made, only
+ the most recent response is retained, and thus used for attribute lookup.
+ """
+
+ def __init__(
+ self,
+ method: Callable[..., cloud_tpu.ListAcceleratorTypesResponse],
+ request: cloud_tpu.ListAcceleratorTypesRequest,
+ response: cloud_tpu.ListAcceleratorTypesResponse,
+ *,
+ metadata: Sequence[Tuple[str, str]] = ()
+ ):
+ """Instantiate the pager.
+
+ Args:
+ method (Callable): The method that was originally called, and
+ which instantiated this pager.
+ request (google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesRequest):
+ The initial request object.
+ response (google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesResponse):
+ The initial response object.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ """
+ self._method = method
+ self._request = cloud_tpu.ListAcceleratorTypesRequest(request)
+ self._response = response
+ self._metadata = metadata
+
+ def __getattr__(self, name: str) -> Any:
+ return getattr(self._response, name)
+
+ @property
+ def pages(self) -> Iterator[cloud_tpu.ListAcceleratorTypesResponse]:
+ yield self._response
+ while self._response.next_page_token:
+ self._request.page_token = self._response.next_page_token
+ self._response = self._method(self._request, metadata=self._metadata)
+ yield self._response
+
+ def __iter__(self) -> Iterator[cloud_tpu.AcceleratorType]:
+ for page in self.pages:
+ yield from page.accelerator_types
+
+ def __repr__(self) -> str:
+ return "{0}<{1!r}>".format(self.__class__.__name__, self._response)
+
+
+class ListAcceleratorTypesAsyncPager:
+ """A pager for iterating through ``list_accelerator_types`` requests.
+
+ This class thinly wraps an initial
+ :class:`google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesResponse` object, and
+ provides an ``__aiter__`` method to iterate through its
+ ``accelerator_types`` field.
+
+ If there are more pages, the ``__aiter__`` method will make additional
+ ``ListAcceleratorTypes`` requests and continue to iterate
+ through the ``accelerator_types`` field on the
+ corresponding responses.
+
+ All the usual :class:`google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesResponse`
+ attributes are available on the pager. If multiple requests are made, only
+ the most recent response is retained, and thus used for attribute lookup.
+ """
+
+ def __init__(
+ self,
+ method: Callable[..., Awaitable[cloud_tpu.ListAcceleratorTypesResponse]],
+ request: cloud_tpu.ListAcceleratorTypesRequest,
+ response: cloud_tpu.ListAcceleratorTypesResponse,
+ *,
+ metadata: Sequence[Tuple[str, str]] = ()
+ ):
+ """Instantiates the pager.
+
+ Args:
+ method (Callable): The method that was originally called, and
+ which instantiated this pager.
+ request (google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesRequest):
+ The initial request object.
+ response (google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesResponse):
+ The initial response object.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ """
+ self._method = method
+ self._request = cloud_tpu.ListAcceleratorTypesRequest(request)
+ self._response = response
+ self._metadata = metadata
+
+ def __getattr__(self, name: str) -> Any:
+ return getattr(self._response, name)
+
+ @property
+ async def pages(self) -> AsyncIterator[cloud_tpu.ListAcceleratorTypesResponse]:
+ yield self._response
+ while self._response.next_page_token:
+ self._request.page_token = self._response.next_page_token
+ self._response = await self._method(self._request, metadata=self._metadata)
+ yield self._response
+
+ def __aiter__(self) -> AsyncIterator[cloud_tpu.AcceleratorType]:
+ async def async_generator():
+ async for page in self.pages:
+ for response in page.accelerator_types:
+ yield response
+
+ return async_generator()
+
+ def __repr__(self) -> str:
+ return "{0}<{1!r}>".format(self.__class__.__name__, self._response)
+
+
+class ListRuntimeVersionsPager:
+ """A pager for iterating through ``list_runtime_versions`` requests.
+
+ This class thinly wraps an initial
+ :class:`google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsResponse` object, and
+ provides an ``__iter__`` method to iterate through its
+ ``runtime_versions`` field.
+
+ If there are more pages, the ``__iter__`` method will make additional
+ ``ListRuntimeVersions`` requests and continue to iterate
+ through the ``runtime_versions`` field on the
+ corresponding responses.
+
+ All the usual :class:`google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsResponse`
+ attributes are available on the pager. If multiple requests are made, only
+ the most recent response is retained, and thus used for attribute lookup.
+ """
+
+ def __init__(
+ self,
+ method: Callable[..., cloud_tpu.ListRuntimeVersionsResponse],
+ request: cloud_tpu.ListRuntimeVersionsRequest,
+ response: cloud_tpu.ListRuntimeVersionsResponse,
+ *,
+ metadata: Sequence[Tuple[str, str]] = ()
+ ):
+ """Instantiate the pager.
+
+ Args:
+ method (Callable): The method that was originally called, and
+ which instantiated this pager.
+ request (google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsRequest):
+ The initial request object.
+ response (google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsResponse):
+ The initial response object.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ """
+ self._method = method
+ self._request = cloud_tpu.ListRuntimeVersionsRequest(request)
+ self._response = response
+ self._metadata = metadata
+
+ def __getattr__(self, name: str) -> Any:
+ return getattr(self._response, name)
+
+ @property
+ def pages(self) -> Iterator[cloud_tpu.ListRuntimeVersionsResponse]:
+ yield self._response
+ while self._response.next_page_token:
+ self._request.page_token = self._response.next_page_token
+ self._response = self._method(self._request, metadata=self._metadata)
+ yield self._response
+
+ def __iter__(self) -> Iterator[cloud_tpu.RuntimeVersion]:
+ for page in self.pages:
+ yield from page.runtime_versions
+
+ def __repr__(self) -> str:
+ return "{0}<{1!r}>".format(self.__class__.__name__, self._response)
+
+
+class ListRuntimeVersionsAsyncPager:
+ """A pager for iterating through ``list_runtime_versions`` requests.
+
+ This class thinly wraps an initial
+ :class:`google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsResponse` object, and
+ provides an ``__aiter__`` method to iterate through its
+ ``runtime_versions`` field.
+
+ If there are more pages, the ``__aiter__`` method will make additional
+ ``ListRuntimeVersions`` requests and continue to iterate
+ through the ``runtime_versions`` field on the
+ corresponding responses.
+
+ All the usual :class:`google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsResponse`
+ attributes are available on the pager. If multiple requests are made, only
+ the most recent response is retained, and thus used for attribute lookup.
+ """
+
+ def __init__(
+ self,
+ method: Callable[..., Awaitable[cloud_tpu.ListRuntimeVersionsResponse]],
+ request: cloud_tpu.ListRuntimeVersionsRequest,
+ response: cloud_tpu.ListRuntimeVersionsResponse,
+ *,
+ metadata: Sequence[Tuple[str, str]] = ()
+ ):
+ """Instantiates the pager.
+
+ Args:
+ method (Callable): The method that was originally called, and
+ which instantiated this pager.
+ request (google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsRequest):
+ The initial request object.
+ response (google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsResponse):
+ The initial response object.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ """
+ self._method = method
+ self._request = cloud_tpu.ListRuntimeVersionsRequest(request)
+ self._response = response
+ self._metadata = metadata
+
+ def __getattr__(self, name: str) -> Any:
+ return getattr(self._response, name)
+
+ @property
+ async def pages(self) -> AsyncIterator[cloud_tpu.ListRuntimeVersionsResponse]:
+ yield self._response
+ while self._response.next_page_token:
+ self._request.page_token = self._response.next_page_token
+ self._response = await self._method(self._request, metadata=self._metadata)
+ yield self._response
+
+ def __aiter__(self) -> AsyncIterator[cloud_tpu.RuntimeVersion]:
+ async def async_generator():
+ async for page in self.pages:
+ for response in page.runtime_versions:
+ yield response
+
+ return async_generator()
+
+ def __repr__(self) -> str:
+ return "{0}<{1!r}>".format(self.__class__.__name__, self._response)
diff --git a/google/cloud/tpu_v2alpha1/services/tpu/transports/__init__.py b/google/cloud/tpu_v2alpha1/services/tpu/transports/__init__.py
new file mode 100644
index 0000000..d3ede28
--- /dev/null
+++ b/google/cloud/tpu_v2alpha1/services/tpu/transports/__init__.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from collections import OrderedDict
+from typing import Dict, Type
+
+from .base import TpuTransport
+from .grpc import TpuGrpcTransport
+from .grpc_asyncio import TpuGrpcAsyncIOTransport
+
+
+# Compile a registry of transports.
+_transport_registry = OrderedDict() # type: Dict[str, Type[TpuTransport]]
+_transport_registry["grpc"] = TpuGrpcTransport
+_transport_registry["grpc_asyncio"] = TpuGrpcAsyncIOTransport
+
+__all__ = (
+ "TpuTransport",
+ "TpuGrpcTransport",
+ "TpuGrpcAsyncIOTransport",
+)
diff --git a/google/cloud/tpu_v2alpha1/services/tpu/transports/base.py b/google/cloud/tpu_v2alpha1/services/tpu/transports/base.py
new file mode 100644
index 0000000..6cc209d
--- /dev/null
+++ b/google/cloud/tpu_v2alpha1/services/tpu/transports/base.py
@@ -0,0 +1,351 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import abc
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Union
+import packaging.version
+import pkg_resources
+
+import google.auth # type: ignore
+import google.api_core # type: ignore
+from google.api_core import exceptions as core_exceptions # type: ignore
+from google.api_core import gapic_v1 # type: ignore
+from google.api_core import retry as retries # type: ignore
+from google.api_core import operations_v1 # type: ignore
+from google.auth import credentials as ga_credentials # type: ignore
+from google.oauth2 import service_account # type: ignore
+
+from google.cloud.tpu_v2alpha1.types import cloud_tpu
+from google.longrunning import operations_pb2 # type: ignore
+
+try:
+ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+ gapic_version=pkg_resources.get_distribution("google-cloud-tpu",).version,
+ )
+except pkg_resources.DistributionNotFound:
+ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+try:
+ # google.auth.__version__ was added in 1.26.0
+ _GOOGLE_AUTH_VERSION = google.auth.__version__
+except AttributeError:
+ try: # try pkg_resources if it is available
+ _GOOGLE_AUTH_VERSION = pkg_resources.get_distribution("google-auth").version
+ except pkg_resources.DistributionNotFound: # pragma: NO COVER
+ _GOOGLE_AUTH_VERSION = None
+
+
+class TpuTransport(abc.ABC):
+ """Abstract transport class for Tpu."""
+
+ AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",)
+
+ DEFAULT_HOST: str = "tpu.googleapis.com"
+
+ def __init__(
+ self,
+ *,
+ host: str = DEFAULT_HOST,
+ credentials: ga_credentials.Credentials = None,
+ credentials_file: Optional[str] = None,
+ scopes: Optional[Sequence[str]] = None,
+ quota_project_id: Optional[str] = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ always_use_jwt_access: Optional[bool] = False,
+ **kwargs,
+ ) -> None:
+ """Instantiate the transport.
+
+ Args:
+ host (Optional[str]):
+ The hostname to connect to.
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is mutually exclusive with credentials.
+ scopes (Optional[Sequence[str]]): A list of scopes.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+ be used for service account credentials.
+ """
+ # Save the hostname. Default to port 443 (HTTPS) if none is specified.
+ if ":" not in host:
+ host += ":443"
+ self._host = host
+
+ scopes_kwargs = self._get_scopes_kwargs(self._host, scopes)
+
+ # Save the scopes.
+ self._scopes = scopes
+
+ # If no credentials are provided, then determine the appropriate
+ # defaults.
+ if credentials and credentials_file:
+ raise core_exceptions.DuplicateCredentialArgs(
+ "'credentials_file' and 'credentials' are mutually exclusive"
+ )
+
+ if credentials_file is not None:
+ credentials, _ = google.auth.load_credentials_from_file(
+ credentials_file, **scopes_kwargs, quota_project_id=quota_project_id
+ )
+
+ elif credentials is None:
+ credentials, _ = google.auth.default(
+ **scopes_kwargs, quota_project_id=quota_project_id
+ )
+
+ # If the credentials are service account credentials, then always try to use self signed JWT.
+ if (
+ always_use_jwt_access
+ and isinstance(credentials, service_account.Credentials)
+ and hasattr(service_account.Credentials, "with_always_use_jwt_access")
+ ):
+ credentials = credentials.with_always_use_jwt_access(True)
+
+ # Save the credentials.
+ self._credentials = credentials
+
+ # TODO(busunkim): This method is in the base transport
+ # to avoid duplicating code across the transport classes. These functions
+ # should be deleted once the minimum required versions of google-auth is increased.
+
+ # TODO: Remove this function once google-auth >= 1.25.0 is required
+ @classmethod
+ def _get_scopes_kwargs(
+ cls, host: str, scopes: Optional[Sequence[str]]
+ ) -> Dict[str, Optional[Sequence[str]]]:
+ """Returns scopes kwargs to pass to google-auth methods depending on the google-auth version"""
+
+ scopes_kwargs = {}
+
+ if _GOOGLE_AUTH_VERSION and (
+ packaging.version.parse(_GOOGLE_AUTH_VERSION)
+ >= packaging.version.parse("1.25.0")
+ ):
+ scopes_kwargs = {"scopes": scopes, "default_scopes": cls.AUTH_SCOPES}
+ else:
+ scopes_kwargs = {"scopes": scopes or cls.AUTH_SCOPES}
+
+ return scopes_kwargs
+
+ def _prep_wrapped_messages(self, client_info):
+ # Precompute the wrapped methods.
+ self._wrapped_methods = {
+ self.list_nodes: gapic_v1.method.wrap_method(
+ self.list_nodes, default_timeout=None, client_info=client_info,
+ ),
+ self.get_node: gapic_v1.method.wrap_method(
+ self.get_node, default_timeout=None, client_info=client_info,
+ ),
+ self.create_node: gapic_v1.method.wrap_method(
+ self.create_node, default_timeout=None, client_info=client_info,
+ ),
+ self.delete_node: gapic_v1.method.wrap_method(
+ self.delete_node, default_timeout=None, client_info=client_info,
+ ),
+ self.stop_node: gapic_v1.method.wrap_method(
+ self.stop_node, default_timeout=None, client_info=client_info,
+ ),
+ self.start_node: gapic_v1.method.wrap_method(
+ self.start_node, default_timeout=None, client_info=client_info,
+ ),
+ self.update_node: gapic_v1.method.wrap_method(
+ self.update_node, default_timeout=None, client_info=client_info,
+ ),
+ self.generate_service_identity: gapic_v1.method.wrap_method(
+ self.generate_service_identity,
+ default_timeout=None,
+ client_info=client_info,
+ ),
+ self.list_accelerator_types: gapic_v1.method.wrap_method(
+ self.list_accelerator_types,
+ default_timeout=None,
+ client_info=client_info,
+ ),
+ self.get_accelerator_type: gapic_v1.method.wrap_method(
+ self.get_accelerator_type,
+ default_timeout=None,
+ client_info=client_info,
+ ),
+ self.list_runtime_versions: gapic_v1.method.wrap_method(
+ self.list_runtime_versions,
+ default_timeout=None,
+ client_info=client_info,
+ ),
+ self.get_runtime_version: gapic_v1.method.wrap_method(
+ self.get_runtime_version, default_timeout=None, client_info=client_info,
+ ),
+ self.get_guest_attributes: gapic_v1.method.wrap_method(
+ self.get_guest_attributes,
+ default_timeout=None,
+ client_info=client_info,
+ ),
+ }
+
+ def close(self):
+ """Closes resources associated with the transport.
+
+ .. warning::
+ Only call this method if the transport is NOT shared
+ with other clients - this may cause errors in other clients!
+ """
+ raise NotImplementedError()
+
+ @property
+ def operations_client(self) -> operations_v1.OperationsClient:
+ """Return the client designed to process long-running operations."""
+ raise NotImplementedError()
+
+ @property
+ def list_nodes(
+ self,
+ ) -> Callable[
+ [cloud_tpu.ListNodesRequest],
+ Union[cloud_tpu.ListNodesResponse, Awaitable[cloud_tpu.ListNodesResponse]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def get_node(
+ self,
+ ) -> Callable[
+ [cloud_tpu.GetNodeRequest], Union[cloud_tpu.Node, Awaitable[cloud_tpu.Node]]
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def create_node(
+ self,
+ ) -> Callable[
+ [cloud_tpu.CreateNodeRequest],
+ Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def delete_node(
+ self,
+ ) -> Callable[
+ [cloud_tpu.DeleteNodeRequest],
+ Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def stop_node(
+ self,
+ ) -> Callable[
+ [cloud_tpu.StopNodeRequest],
+ Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def start_node(
+ self,
+ ) -> Callable[
+ [cloud_tpu.StartNodeRequest],
+ Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def update_node(
+ self,
+ ) -> Callable[
+ [cloud_tpu.UpdateNodeRequest],
+ Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def generate_service_identity(
+ self,
+ ) -> Callable[
+ [cloud_tpu.GenerateServiceIdentityRequest],
+ Union[
+ cloud_tpu.GenerateServiceIdentityResponse,
+ Awaitable[cloud_tpu.GenerateServiceIdentityResponse],
+ ],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def list_accelerator_types(
+ self,
+ ) -> Callable[
+ [cloud_tpu.ListAcceleratorTypesRequest],
+ Union[
+ cloud_tpu.ListAcceleratorTypesResponse,
+ Awaitable[cloud_tpu.ListAcceleratorTypesResponse],
+ ],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def get_accelerator_type(
+ self,
+ ) -> Callable[
+ [cloud_tpu.GetAcceleratorTypeRequest],
+ Union[cloud_tpu.AcceleratorType, Awaitable[cloud_tpu.AcceleratorType]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def list_runtime_versions(
+ self,
+ ) -> Callable[
+ [cloud_tpu.ListRuntimeVersionsRequest],
+ Union[
+ cloud_tpu.ListRuntimeVersionsResponse,
+ Awaitable[cloud_tpu.ListRuntimeVersionsResponse],
+ ],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def get_runtime_version(
+ self,
+ ) -> Callable[
+ [cloud_tpu.GetRuntimeVersionRequest],
+ Union[cloud_tpu.RuntimeVersion, Awaitable[cloud_tpu.RuntimeVersion]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def get_guest_attributes(
+ self,
+ ) -> Callable[
+ [cloud_tpu.GetGuestAttributesRequest],
+ Union[
+ cloud_tpu.GetGuestAttributesResponse,
+ Awaitable[cloud_tpu.GetGuestAttributesResponse],
+ ],
+ ]:
+ raise NotImplementedError()
+
+
+__all__ = ("TpuTransport",)
diff --git a/google/cloud/tpu_v2alpha1/services/tpu/transports/grpc.py b/google/cloud/tpu_v2alpha1/services/tpu/transports/grpc.py
new file mode 100644
index 0000000..e31acf6
--- /dev/null
+++ b/google/cloud/tpu_v2alpha1/services/tpu/transports/grpc.py
@@ -0,0 +1,597 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import warnings
+from typing import Callable, Dict, Optional, Sequence, Tuple, Union
+
+from google.api_core import grpc_helpers # type: ignore
+from google.api_core import operations_v1 # type: ignore
+from google.api_core import gapic_v1 # type: ignore
+import google.auth # type: ignore
+from google.auth import credentials as ga_credentials # type: ignore
+from google.auth.transport.grpc import SslCredentials # type: ignore
+
+import grpc # type: ignore
+
+from google.cloud.tpu_v2alpha1.types import cloud_tpu
+from google.longrunning import operations_pb2 # type: ignore
+from .base import TpuTransport, DEFAULT_CLIENT_INFO
+
+
+class TpuGrpcTransport(TpuTransport):
+ """gRPC backend transport for Tpu.
+
+ Manages TPU nodes and other resources
+ TPU API v2alpha1
+
+ This class defines the same methods as the primary client, so the
+ primary client can load the underlying transport implementation
+ and call it.
+
+ It sends protocol buffers over the wire using gRPC (which is built on
+ top of HTTP/2); the ``grpcio`` package must be installed.
+ """
+
+ _stubs: Dict[str, Callable]
+
+ def __init__(
+ self,
+ *,
+ host: str = "tpu.googleapis.com",
+ credentials: ga_credentials.Credentials = None,
+ credentials_file: str = None,
+ scopes: Sequence[str] = None,
+ channel: grpc.Channel = None,
+ api_mtls_endpoint: str = None,
+ client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+ ssl_channel_credentials: grpc.ChannelCredentials = None,
+ client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
+ quota_project_id: Optional[str] = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ always_use_jwt_access: Optional[bool] = False,
+ ) -> None:
+ """Instantiate the transport.
+
+ Args:
+ host (Optional[str]):
+ The hostname to connect to.
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ This argument is ignored if ``channel`` is provided.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is ignored if ``channel`` is provided.
+ scopes (Optional(Sequence[str])): A list of scopes. This argument is
+ ignored if ``channel`` is provided.
+ channel (Optional[grpc.Channel]): A ``Channel`` instance through
+ which to make calls.
+ api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+ If provided, it overrides the ``host`` argument and tries to create
+ a mutual TLS channel with client SSL credentials from
+ ``client_cert_source`` or application default SSL credentials.
+ client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ Deprecated. A callback to provide client SSL certificate bytes and
+ private key bytes, both in PEM format. It is ignored if
+ ``api_mtls_endpoint`` is None.
+ ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+ for the grpc channel. It is ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ A callback to provide client certificate bytes and private key bytes,
+ both in PEM format. It is used to configure a mutual TLS channel. It is
+ ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+ be used for service account credentials.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+ creation failed for any reason.
+ google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+ and ``credentials_file`` are passed.
+ """
+ self._grpc_channel = None
+ self._ssl_channel_credentials = ssl_channel_credentials
+ self._stubs: Dict[str, Callable] = {}
+ self._operations_client = None
+
+ if api_mtls_endpoint:
+ warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+ if client_cert_source:
+ warnings.warn("client_cert_source is deprecated", DeprecationWarning)
+
+ if channel:
+ # Ignore credentials if a channel was passed.
+ credentials = False
+ # If a channel was explicitly provided, set it.
+ self._grpc_channel = channel
+ self._ssl_channel_credentials = None
+
+ else:
+ if api_mtls_endpoint:
+ host = api_mtls_endpoint
+
+ # Create SSL credentials with client_cert_source or application
+ # default SSL credentials.
+ if client_cert_source:
+ cert, key = client_cert_source()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ else:
+ self._ssl_channel_credentials = SslCredentials().ssl_credentials
+
+ else:
+ if client_cert_source_for_mtls and not ssl_channel_credentials:
+ cert, key = client_cert_source_for_mtls()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+
+ # The base transport sets the host, credentials and scopes
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ scopes=scopes,
+ quota_project_id=quota_project_id,
+ client_info=client_info,
+ always_use_jwt_access=always_use_jwt_access,
+ )
+
+ if not self._grpc_channel:
+ self._grpc_channel = type(self).create_channel(
+ self._host,
+ credentials=self._credentials,
+ credentials_file=credentials_file,
+ scopes=self._scopes,
+ ssl_credentials=self._ssl_channel_credentials,
+ quota_project_id=quota_project_id,
+ options=[
+ ("grpc.max_send_message_length", -1),
+ ("grpc.max_receive_message_length", -1),
+ ],
+ )
+
+ # Wrap messages. This must be done after self._grpc_channel exists
+ self._prep_wrapped_messages(client_info)
+
+ @classmethod
+ def create_channel(
+ cls,
+ host: str = "tpu.googleapis.com",
+ credentials: ga_credentials.Credentials = None,
+ credentials_file: str = None,
+ scopes: Optional[Sequence[str]] = None,
+ quota_project_id: Optional[str] = None,
+ **kwargs,
+ ) -> grpc.Channel:
+ """Create and return a gRPC channel object.
+ Args:
+ host (Optional[str]): The host for the channel to use.
+ credentials (Optional[~.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify this application to the service. If
+ none are specified, the client will attempt to ascertain
+ the credentials from the environment.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is mutually exclusive with credentials.
+ scopes (Optional[Sequence[str]]): A optional list of scopes needed for this
+ service. These are only used when credentials are not specified and
+ are passed to :func:`google.auth.default`.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ kwargs (Optional[dict]): Keyword arguments, which are passed to the
+ channel creation.
+ Returns:
+ grpc.Channel: A gRPC channel object.
+
+ Raises:
+ google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+ and ``credentials_file`` are passed.
+ """
+
+ return grpc_helpers.create_channel(
+ host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ quota_project_id=quota_project_id,
+ default_scopes=cls.AUTH_SCOPES,
+ scopes=scopes,
+ default_host=cls.DEFAULT_HOST,
+ **kwargs,
+ )
+
+ @property
+ def grpc_channel(self) -> grpc.Channel:
+ """Return the channel designed to connect to this service.
+ """
+ return self._grpc_channel
+
+ @property
+ def operations_client(self) -> operations_v1.OperationsClient:
+ """Create the client designed to process long-running operations.
+
+ This property caches on the instance; repeated calls return the same
+ client.
+ """
+ # Sanity check: Only create a new client if we do not already have one.
+ if self._operations_client is None:
+ self._operations_client = operations_v1.OperationsClient(self.grpc_channel)
+
+ # Return the client from cache.
+ return self._operations_client
+
+ @property
+ def list_nodes(
+ self,
+ ) -> Callable[[cloud_tpu.ListNodesRequest], cloud_tpu.ListNodesResponse]:
+ r"""Return a callable for the list nodes method over gRPC.
+
+ Lists nodes.
+
+ Returns:
+ Callable[[~.ListNodesRequest],
+ ~.ListNodesResponse]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "list_nodes" not in self._stubs:
+ self._stubs["list_nodes"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/ListNodes",
+ request_serializer=cloud_tpu.ListNodesRequest.serialize,
+ response_deserializer=cloud_tpu.ListNodesResponse.deserialize,
+ )
+ return self._stubs["list_nodes"]
+
+ @property
+ def get_node(self) -> Callable[[cloud_tpu.GetNodeRequest], cloud_tpu.Node]:
+ r"""Return a callable for the get node method over gRPC.
+
+ Gets the details of a node.
+
+ Returns:
+ Callable[[~.GetNodeRequest],
+ ~.Node]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_node" not in self._stubs:
+ self._stubs["get_node"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/GetNode",
+ request_serializer=cloud_tpu.GetNodeRequest.serialize,
+ response_deserializer=cloud_tpu.Node.deserialize,
+ )
+ return self._stubs["get_node"]
+
+ @property
+ def create_node(
+ self,
+ ) -> Callable[[cloud_tpu.CreateNodeRequest], operations_pb2.Operation]:
+ r"""Return a callable for the create node method over gRPC.
+
+ Creates a node.
+
+ Returns:
+ Callable[[~.CreateNodeRequest],
+ ~.Operation]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "create_node" not in self._stubs:
+ self._stubs["create_node"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/CreateNode",
+ request_serializer=cloud_tpu.CreateNodeRequest.serialize,
+ response_deserializer=operations_pb2.Operation.FromString,
+ )
+ return self._stubs["create_node"]
+
+ @property
+ def delete_node(
+ self,
+ ) -> Callable[[cloud_tpu.DeleteNodeRequest], operations_pb2.Operation]:
+ r"""Return a callable for the delete node method over gRPC.
+
+ Deletes a node.
+
+ Returns:
+ Callable[[~.DeleteNodeRequest],
+ ~.Operation]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "delete_node" not in self._stubs:
+ self._stubs["delete_node"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/DeleteNode",
+ request_serializer=cloud_tpu.DeleteNodeRequest.serialize,
+ response_deserializer=operations_pb2.Operation.FromString,
+ )
+ return self._stubs["delete_node"]
+
+ @property
+ def stop_node(
+ self,
+ ) -> Callable[[cloud_tpu.StopNodeRequest], operations_pb2.Operation]:
+ r"""Return a callable for the stop node method over gRPC.
+
+ Stops a node. This operation is only available with
+ single TPU nodes.
+
+ Returns:
+ Callable[[~.StopNodeRequest],
+ ~.Operation]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "stop_node" not in self._stubs:
+ self._stubs["stop_node"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/StopNode",
+ request_serializer=cloud_tpu.StopNodeRequest.serialize,
+ response_deserializer=operations_pb2.Operation.FromString,
+ )
+ return self._stubs["stop_node"]
+
+ @property
+ def start_node(
+ self,
+ ) -> Callable[[cloud_tpu.StartNodeRequest], operations_pb2.Operation]:
+ r"""Return a callable for the start node method over gRPC.
+
+ Starts a node.
+
+ Returns:
+ Callable[[~.StartNodeRequest],
+ ~.Operation]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "start_node" not in self._stubs:
+ self._stubs["start_node"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/StartNode",
+ request_serializer=cloud_tpu.StartNodeRequest.serialize,
+ response_deserializer=operations_pb2.Operation.FromString,
+ )
+ return self._stubs["start_node"]
+
+ @property
+ def update_node(
+ self,
+ ) -> Callable[[cloud_tpu.UpdateNodeRequest], operations_pb2.Operation]:
+ r"""Return a callable for the update node method over gRPC.
+
+ Updates the configurations of a node.
+
+ Returns:
+ Callable[[~.UpdateNodeRequest],
+ ~.Operation]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "update_node" not in self._stubs:
+ self._stubs["update_node"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/UpdateNode",
+ request_serializer=cloud_tpu.UpdateNodeRequest.serialize,
+ response_deserializer=operations_pb2.Operation.FromString,
+ )
+ return self._stubs["update_node"]
+
+ @property
+ def generate_service_identity(
+ self,
+ ) -> Callable[
+ [cloud_tpu.GenerateServiceIdentityRequest],
+ cloud_tpu.GenerateServiceIdentityResponse,
+ ]:
+ r"""Return a callable for the generate service identity method over gRPC.
+
+ Generates the Cloud TPU service identity for the
+ project.
+
+ Returns:
+ Callable[[~.GenerateServiceIdentityRequest],
+ ~.GenerateServiceIdentityResponse]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "generate_service_identity" not in self._stubs:
+ self._stubs["generate_service_identity"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/GenerateServiceIdentity",
+ request_serializer=cloud_tpu.GenerateServiceIdentityRequest.serialize,
+ response_deserializer=cloud_tpu.GenerateServiceIdentityResponse.deserialize,
+ )
+ return self._stubs["generate_service_identity"]
+
+ @property
+ def list_accelerator_types(
+ self,
+ ) -> Callable[
+ [cloud_tpu.ListAcceleratorTypesRequest], cloud_tpu.ListAcceleratorTypesResponse
+ ]:
+ r"""Return a callable for the list accelerator types method over gRPC.
+
+ Lists accelerator types supported by this API.
+
+ Returns:
+ Callable[[~.ListAcceleratorTypesRequest],
+ ~.ListAcceleratorTypesResponse]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "list_accelerator_types" not in self._stubs:
+ self._stubs["list_accelerator_types"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/ListAcceleratorTypes",
+ request_serializer=cloud_tpu.ListAcceleratorTypesRequest.serialize,
+ response_deserializer=cloud_tpu.ListAcceleratorTypesResponse.deserialize,
+ )
+ return self._stubs["list_accelerator_types"]
+
+ @property
+ def get_accelerator_type(
+ self,
+ ) -> Callable[[cloud_tpu.GetAcceleratorTypeRequest], cloud_tpu.AcceleratorType]:
+ r"""Return a callable for the get accelerator type method over gRPC.
+
+ Gets AcceleratorType.
+
+ Returns:
+ Callable[[~.GetAcceleratorTypeRequest],
+ ~.AcceleratorType]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_accelerator_type" not in self._stubs:
+ self._stubs["get_accelerator_type"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/GetAcceleratorType",
+ request_serializer=cloud_tpu.GetAcceleratorTypeRequest.serialize,
+ response_deserializer=cloud_tpu.AcceleratorType.deserialize,
+ )
+ return self._stubs["get_accelerator_type"]
+
+ @property
+ def list_runtime_versions(
+ self,
+ ) -> Callable[
+ [cloud_tpu.ListRuntimeVersionsRequest], cloud_tpu.ListRuntimeVersionsResponse
+ ]:
+ r"""Return a callable for the list runtime versions method over gRPC.
+
+ Lists runtime versions supported by this API.
+
+ Returns:
+ Callable[[~.ListRuntimeVersionsRequest],
+ ~.ListRuntimeVersionsResponse]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "list_runtime_versions" not in self._stubs:
+ self._stubs["list_runtime_versions"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/ListRuntimeVersions",
+ request_serializer=cloud_tpu.ListRuntimeVersionsRequest.serialize,
+ response_deserializer=cloud_tpu.ListRuntimeVersionsResponse.deserialize,
+ )
+ return self._stubs["list_runtime_versions"]
+
+ @property
+ def get_runtime_version(
+ self,
+ ) -> Callable[[cloud_tpu.GetRuntimeVersionRequest], cloud_tpu.RuntimeVersion]:
+ r"""Return a callable for the get runtime version method over gRPC.
+
+ Gets a runtime version.
+
+ Returns:
+ Callable[[~.GetRuntimeVersionRequest],
+ ~.RuntimeVersion]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_runtime_version" not in self._stubs:
+ self._stubs["get_runtime_version"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/GetRuntimeVersion",
+ request_serializer=cloud_tpu.GetRuntimeVersionRequest.serialize,
+ response_deserializer=cloud_tpu.RuntimeVersion.deserialize,
+ )
+ return self._stubs["get_runtime_version"]
+
+ @property
+ def get_guest_attributes(
+ self,
+ ) -> Callable[
+ [cloud_tpu.GetGuestAttributesRequest], cloud_tpu.GetGuestAttributesResponse
+ ]:
+ r"""Return a callable for the get guest attributes method over gRPC.
+
+ Retrieves the guest attributes for the node.
+
+ Returns:
+ Callable[[~.GetGuestAttributesRequest],
+ ~.GetGuestAttributesResponse]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_guest_attributes" not in self._stubs:
+ self._stubs["get_guest_attributes"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/GetGuestAttributes",
+ request_serializer=cloud_tpu.GetGuestAttributesRequest.serialize,
+ response_deserializer=cloud_tpu.GetGuestAttributesResponse.deserialize,
+ )
+ return self._stubs["get_guest_attributes"]
+
+ def close(self):
+ self.grpc_channel.close()
+
+
+__all__ = ("TpuGrpcTransport",)
diff --git a/google/cloud/tpu_v2alpha1/services/tpu/transports/grpc_asyncio.py b/google/cloud/tpu_v2alpha1/services/tpu/transports/grpc_asyncio.py
new file mode 100644
index 0000000..9ef622e
--- /dev/null
+++ b/google/cloud/tpu_v2alpha1/services/tpu/transports/grpc_asyncio.py
@@ -0,0 +1,611 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import warnings
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union
+
+from google.api_core import gapic_v1 # type: ignore
+from google.api_core import grpc_helpers_async # type: ignore
+from google.api_core import operations_v1 # type: ignore
+from google.auth import credentials as ga_credentials # type: ignore
+from google.auth.transport.grpc import SslCredentials # type: ignore
+import packaging.version
+
+import grpc # type: ignore
+from grpc.experimental import aio # type: ignore
+
+from google.cloud.tpu_v2alpha1.types import cloud_tpu
+from google.longrunning import operations_pb2 # type: ignore
+from .base import TpuTransport, DEFAULT_CLIENT_INFO
+from .grpc import TpuGrpcTransport
+
+
+class TpuGrpcAsyncIOTransport(TpuTransport):
+ """gRPC AsyncIO backend transport for Tpu.
+
+ Manages TPU nodes and other resources
+ TPU API v2alpha1
+
+ This class defines the same methods as the primary client, so the
+ primary client can load the underlying transport implementation
+ and call it.
+
+ It sends protocol buffers over the wire using gRPC (which is built on
+ top of HTTP/2); the ``grpcio`` package must be installed.
+ """
+
+ _grpc_channel: aio.Channel
+ _stubs: Dict[str, Callable] = {}
+
+ @classmethod
+ def create_channel(
+ cls,
+ host: str = "tpu.googleapis.com",
+ credentials: ga_credentials.Credentials = None,
+ credentials_file: Optional[str] = None,
+ scopes: Optional[Sequence[str]] = None,
+ quota_project_id: Optional[str] = None,
+ **kwargs,
+ ) -> aio.Channel:
+ """Create and return a gRPC AsyncIO channel object.
+ Args:
+ host (Optional[str]): The host for the channel to use.
+ credentials (Optional[~.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify this application to the service. If
+ none are specified, the client will attempt to ascertain
+ the credentials from the environment.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is ignored if ``channel`` is provided.
+ scopes (Optional[Sequence[str]]): A optional list of scopes needed for this
+ service. These are only used when credentials are not specified and
+ are passed to :func:`google.auth.default`.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ kwargs (Optional[dict]): Keyword arguments, which are passed to the
+ channel creation.
+ Returns:
+ aio.Channel: A gRPC AsyncIO channel object.
+ """
+
+ return grpc_helpers_async.create_channel(
+ host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ quota_project_id=quota_project_id,
+ default_scopes=cls.AUTH_SCOPES,
+ scopes=scopes,
+ default_host=cls.DEFAULT_HOST,
+ **kwargs,
+ )
+
+ def __init__(
+ self,
+ *,
+ host: str = "tpu.googleapis.com",
+ credentials: ga_credentials.Credentials = None,
+ credentials_file: Optional[str] = None,
+ scopes: Optional[Sequence[str]] = None,
+ channel: aio.Channel = None,
+ api_mtls_endpoint: str = None,
+ client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+ ssl_channel_credentials: grpc.ChannelCredentials = None,
+ client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
+ quota_project_id=None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ always_use_jwt_access: Optional[bool] = False,
+ ) -> None:
+ """Instantiate the transport.
+
+ Args:
+ host (Optional[str]):
+ The hostname to connect to.
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ This argument is ignored if ``channel`` is provided.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is ignored if ``channel`` is provided.
+ scopes (Optional[Sequence[str]]): A optional list of scopes needed for this
+ service. These are only used when credentials are not specified and
+ are passed to :func:`google.auth.default`.
+ channel (Optional[aio.Channel]): A ``Channel`` instance through
+ which to make calls.
+ api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+ If provided, it overrides the ``host`` argument and tries to create
+ a mutual TLS channel with client SSL credentials from
+ ``client_cert_source`` or application default SSL credentials.
+ client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ Deprecated. A callback to provide client SSL certificate bytes and
+ private key bytes, both in PEM format. It is ignored if
+ ``api_mtls_endpoint`` is None.
+ ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+ for the grpc channel. It is ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ A callback to provide client certificate bytes and private key bytes,
+ both in PEM format. It is used to configure a mutual TLS channel. It is
+ ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+ be used for service account credentials.
+
+ Raises:
+ google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
+ creation failed for any reason.
+ google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+ and ``credentials_file`` are passed.
+ """
+ self._grpc_channel = None
+ self._ssl_channel_credentials = ssl_channel_credentials
+ self._stubs: Dict[str, Callable] = {}
+ self._operations_client = None
+
+ if api_mtls_endpoint:
+ warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+ if client_cert_source:
+ warnings.warn("client_cert_source is deprecated", DeprecationWarning)
+
+ if channel:
+ # Ignore credentials if a channel was passed.
+ credentials = False
+ # If a channel was explicitly provided, set it.
+ self._grpc_channel = channel
+ self._ssl_channel_credentials = None
+ else:
+ if api_mtls_endpoint:
+ host = api_mtls_endpoint
+
+ # Create SSL credentials with client_cert_source or application
+ # default SSL credentials.
+ if client_cert_source:
+ cert, key = client_cert_source()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ else:
+ self._ssl_channel_credentials = SslCredentials().ssl_credentials
+
+ else:
+ if client_cert_source_for_mtls and not ssl_channel_credentials:
+ cert, key = client_cert_source_for_mtls()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+
+ # The base transport sets the host, credentials and scopes
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ scopes=scopes,
+ quota_project_id=quota_project_id,
+ client_info=client_info,
+ always_use_jwt_access=always_use_jwt_access,
+ )
+
+ if not self._grpc_channel:
+ self._grpc_channel = type(self).create_channel(
+ self._host,
+ credentials=self._credentials,
+ credentials_file=credentials_file,
+ scopes=self._scopes,
+ ssl_credentials=self._ssl_channel_credentials,
+ quota_project_id=quota_project_id,
+ options=[
+ ("grpc.max_send_message_length", -1),
+ ("grpc.max_receive_message_length", -1),
+ ],
+ )
+
+ # Wrap messages. This must be done after self._grpc_channel exists
+ self._prep_wrapped_messages(client_info)
+
+ @property
+ def grpc_channel(self) -> aio.Channel:
+ """Create the channel designed to connect to this service.
+
+ This property caches on the instance; repeated calls return
+ the same channel.
+ """
+ # Return the channel from cache.
+ return self._grpc_channel
+
+ @property
+ def operations_client(self) -> operations_v1.OperationsAsyncClient:
+ """Create the client designed to process long-running operations.
+
+ This property caches on the instance; repeated calls return the same
+ client.
+ """
+ # Sanity check: Only create a new client if we do not already have one.
+ if self._operations_client is None:
+ self._operations_client = operations_v1.OperationsAsyncClient(
+ self.grpc_channel
+ )
+
+ # Return the client from cache.
+ return self._operations_client
+
+ @property
+ def list_nodes(
+ self,
+ ) -> Callable[[cloud_tpu.ListNodesRequest], Awaitable[cloud_tpu.ListNodesResponse]]:
+ r"""Return a callable for the list nodes method over gRPC.
+
+ Lists nodes.
+
+ Returns:
+ Callable[[~.ListNodesRequest],
+ Awaitable[~.ListNodesResponse]]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "list_nodes" not in self._stubs:
+ self._stubs["list_nodes"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/ListNodes",
+ request_serializer=cloud_tpu.ListNodesRequest.serialize,
+ response_deserializer=cloud_tpu.ListNodesResponse.deserialize,
+ )
+ return self._stubs["list_nodes"]
+
+ @property
+ def get_node(
+ self,
+ ) -> Callable[[cloud_tpu.GetNodeRequest], Awaitable[cloud_tpu.Node]]:
+ r"""Return a callable for the get node method over gRPC.
+
+ Gets the details of a node.
+
+ Returns:
+ Callable[[~.GetNodeRequest],
+ Awaitable[~.Node]]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_node" not in self._stubs:
+ self._stubs["get_node"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/GetNode",
+ request_serializer=cloud_tpu.GetNodeRequest.serialize,
+ response_deserializer=cloud_tpu.Node.deserialize,
+ )
+ return self._stubs["get_node"]
+
+ @property
+ def create_node(
+ self,
+ ) -> Callable[[cloud_tpu.CreateNodeRequest], Awaitable[operations_pb2.Operation]]:
+ r"""Return a callable for the create node method over gRPC.
+
+ Creates a node.
+
+ Returns:
+ Callable[[~.CreateNodeRequest],
+ Awaitable[~.Operation]]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "create_node" not in self._stubs:
+ self._stubs["create_node"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/CreateNode",
+ request_serializer=cloud_tpu.CreateNodeRequest.serialize,
+ response_deserializer=operations_pb2.Operation.FromString,
+ )
+ return self._stubs["create_node"]
+
+ @property
+ def delete_node(
+ self,
+ ) -> Callable[[cloud_tpu.DeleteNodeRequest], Awaitable[operations_pb2.Operation]]:
+ r"""Return a callable for the delete node method over gRPC.
+
+ Deletes a node.
+
+ Returns:
+ Callable[[~.DeleteNodeRequest],
+ Awaitable[~.Operation]]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "delete_node" not in self._stubs:
+ self._stubs["delete_node"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/DeleteNode",
+ request_serializer=cloud_tpu.DeleteNodeRequest.serialize,
+ response_deserializer=operations_pb2.Operation.FromString,
+ )
+ return self._stubs["delete_node"]
+
+ @property
+ def stop_node(
+ self,
+ ) -> Callable[[cloud_tpu.StopNodeRequest], Awaitable[operations_pb2.Operation]]:
+ r"""Return a callable for the stop node method over gRPC.
+
+ Stops a node. This operation is only available with
+ single TPU nodes.
+
+ Returns:
+ Callable[[~.StopNodeRequest],
+ Awaitable[~.Operation]]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "stop_node" not in self._stubs:
+ self._stubs["stop_node"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/StopNode",
+ request_serializer=cloud_tpu.StopNodeRequest.serialize,
+ response_deserializer=operations_pb2.Operation.FromString,
+ )
+ return self._stubs["stop_node"]
+
+ @property
+ def start_node(
+ self,
+ ) -> Callable[[cloud_tpu.StartNodeRequest], Awaitable[operations_pb2.Operation]]:
+ r"""Return a callable for the start node method over gRPC.
+
+ Starts a node.
+
+ Returns:
+ Callable[[~.StartNodeRequest],
+ Awaitable[~.Operation]]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "start_node" not in self._stubs:
+ self._stubs["start_node"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/StartNode",
+ request_serializer=cloud_tpu.StartNodeRequest.serialize,
+ response_deserializer=operations_pb2.Operation.FromString,
+ )
+ return self._stubs["start_node"]
+
+ @property
+ def update_node(
+ self,
+ ) -> Callable[[cloud_tpu.UpdateNodeRequest], Awaitable[operations_pb2.Operation]]:
+ r"""Return a callable for the update node method over gRPC.
+
+ Updates the configurations of a node.
+
+ Returns:
+ Callable[[~.UpdateNodeRequest],
+ Awaitable[~.Operation]]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "update_node" not in self._stubs:
+ self._stubs["update_node"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/UpdateNode",
+ request_serializer=cloud_tpu.UpdateNodeRequest.serialize,
+ response_deserializer=operations_pb2.Operation.FromString,
+ )
+ return self._stubs["update_node"]
+
+ @property
+ def generate_service_identity(
+ self,
+ ) -> Callable[
+ [cloud_tpu.GenerateServiceIdentityRequest],
+ Awaitable[cloud_tpu.GenerateServiceIdentityResponse],
+ ]:
+ r"""Return a callable for the generate service identity method over gRPC.
+
+ Generates the Cloud TPU service identity for the
+ project.
+
+ Returns:
+ Callable[[~.GenerateServiceIdentityRequest],
+ Awaitable[~.GenerateServiceIdentityResponse]]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "generate_service_identity" not in self._stubs:
+ self._stubs["generate_service_identity"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/GenerateServiceIdentity",
+ request_serializer=cloud_tpu.GenerateServiceIdentityRequest.serialize,
+ response_deserializer=cloud_tpu.GenerateServiceIdentityResponse.deserialize,
+ )
+ return self._stubs["generate_service_identity"]
+
+ @property
+ def list_accelerator_types(
+ self,
+ ) -> Callable[
+ [cloud_tpu.ListAcceleratorTypesRequest],
+ Awaitable[cloud_tpu.ListAcceleratorTypesResponse],
+ ]:
+ r"""Return a callable for the list accelerator types method over gRPC.
+
+ Lists accelerator types supported by this API.
+
+ Returns:
+ Callable[[~.ListAcceleratorTypesRequest],
+ Awaitable[~.ListAcceleratorTypesResponse]]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "list_accelerator_types" not in self._stubs:
+ self._stubs["list_accelerator_types"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/ListAcceleratorTypes",
+ request_serializer=cloud_tpu.ListAcceleratorTypesRequest.serialize,
+ response_deserializer=cloud_tpu.ListAcceleratorTypesResponse.deserialize,
+ )
+ return self._stubs["list_accelerator_types"]
+
+ @property
+ def get_accelerator_type(
+ self,
+ ) -> Callable[
+ [cloud_tpu.GetAcceleratorTypeRequest], Awaitable[cloud_tpu.AcceleratorType]
+ ]:
+ r"""Return a callable for the get accelerator type method over gRPC.
+
+ Gets AcceleratorType.
+
+ Returns:
+ Callable[[~.GetAcceleratorTypeRequest],
+ Awaitable[~.AcceleratorType]]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_accelerator_type" not in self._stubs:
+ self._stubs["get_accelerator_type"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/GetAcceleratorType",
+ request_serializer=cloud_tpu.GetAcceleratorTypeRequest.serialize,
+ response_deserializer=cloud_tpu.AcceleratorType.deserialize,
+ )
+ return self._stubs["get_accelerator_type"]
+
+ @property
+ def list_runtime_versions(
+ self,
+ ) -> Callable[
+ [cloud_tpu.ListRuntimeVersionsRequest],
+ Awaitable[cloud_tpu.ListRuntimeVersionsResponse],
+ ]:
+ r"""Return a callable for the list runtime versions method over gRPC.
+
+ Lists runtime versions supported by this API.
+
+ Returns:
+ Callable[[~.ListRuntimeVersionsRequest],
+ Awaitable[~.ListRuntimeVersionsResponse]]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "list_runtime_versions" not in self._stubs:
+ self._stubs["list_runtime_versions"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/ListRuntimeVersions",
+ request_serializer=cloud_tpu.ListRuntimeVersionsRequest.serialize,
+ response_deserializer=cloud_tpu.ListRuntimeVersionsResponse.deserialize,
+ )
+ return self._stubs["list_runtime_versions"]
+
+ @property
+ def get_runtime_version(
+ self,
+ ) -> Callable[
+ [cloud_tpu.GetRuntimeVersionRequest], Awaitable[cloud_tpu.RuntimeVersion]
+ ]:
+ r"""Return a callable for the get runtime version method over gRPC.
+
+ Gets a runtime version.
+
+ Returns:
+ Callable[[~.GetRuntimeVersionRequest],
+ Awaitable[~.RuntimeVersion]]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_runtime_version" not in self._stubs:
+ self._stubs["get_runtime_version"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/GetRuntimeVersion",
+ request_serializer=cloud_tpu.GetRuntimeVersionRequest.serialize,
+ response_deserializer=cloud_tpu.RuntimeVersion.deserialize,
+ )
+ return self._stubs["get_runtime_version"]
+
+ @property
+ def get_guest_attributes(
+ self,
+ ) -> Callable[
+ [cloud_tpu.GetGuestAttributesRequest],
+ Awaitable[cloud_tpu.GetGuestAttributesResponse],
+ ]:
+ r"""Return a callable for the get guest attributes method over gRPC.
+
+ Retrieves the guest attributes for the node.
+
+ Returns:
+ Callable[[~.GetGuestAttributesRequest],
+ Awaitable[~.GetGuestAttributesResponse]]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_guest_attributes" not in self._stubs:
+ self._stubs["get_guest_attributes"] = self.grpc_channel.unary_unary(
+ "/google.cloud.tpu.v2alpha1.Tpu/GetGuestAttributes",
+ request_serializer=cloud_tpu.GetGuestAttributesRequest.serialize,
+ response_deserializer=cloud_tpu.GetGuestAttributesResponse.deserialize,
+ )
+ return self._stubs["get_guest_attributes"]
+
+ def close(self):
+ return self.grpc_channel.close()
+
+
+__all__ = ("TpuGrpcAsyncIOTransport",)
diff --git a/google/cloud/tpu_v2alpha1/types/__init__.py b/google/cloud/tpu_v2alpha1/types/__init__.py
new file mode 100644
index 0000000..fd80d1e
--- /dev/null
+++ b/google/cloud/tpu_v2alpha1/types/__init__.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from .cloud_tpu import (
+ AcceleratorType,
+ AccessConfig,
+ AttachedDisk,
+ CreateNodeRequest,
+ DeleteNodeRequest,
+ GenerateServiceIdentityRequest,
+ GenerateServiceIdentityResponse,
+ GetAcceleratorTypeRequest,
+ GetGuestAttributesRequest,
+ GetGuestAttributesResponse,
+ GetNodeRequest,
+ GetRuntimeVersionRequest,
+ GuestAttributes,
+ GuestAttributesEntry,
+ GuestAttributesValue,
+ ListAcceleratorTypesRequest,
+ ListAcceleratorTypesResponse,
+ ListNodesRequest,
+ ListNodesResponse,
+ ListRuntimeVersionsRequest,
+ ListRuntimeVersionsResponse,
+ NetworkConfig,
+ NetworkEndpoint,
+ Node,
+ OperationMetadata,
+ RuntimeVersion,
+ SchedulingConfig,
+ ServiceAccount,
+ ServiceIdentity,
+ StartNodeRequest,
+ StopNodeRequest,
+ Symptom,
+ UpdateNodeRequest,
+)
+
+__all__ = (
+ "AcceleratorType",
+ "AccessConfig",
+ "AttachedDisk",
+ "CreateNodeRequest",
+ "DeleteNodeRequest",
+ "GenerateServiceIdentityRequest",
+ "GenerateServiceIdentityResponse",
+ "GetAcceleratorTypeRequest",
+ "GetGuestAttributesRequest",
+ "GetGuestAttributesResponse",
+ "GetNodeRequest",
+ "GetRuntimeVersionRequest",
+ "GuestAttributes",
+ "GuestAttributesEntry",
+ "GuestAttributesValue",
+ "ListAcceleratorTypesRequest",
+ "ListAcceleratorTypesResponse",
+ "ListNodesRequest",
+ "ListNodesResponse",
+ "ListRuntimeVersionsRequest",
+ "ListRuntimeVersionsResponse",
+ "NetworkConfig",
+ "NetworkEndpoint",
+ "Node",
+ "OperationMetadata",
+ "RuntimeVersion",
+ "SchedulingConfig",
+ "ServiceAccount",
+ "ServiceIdentity",
+ "StartNodeRequest",
+ "StopNodeRequest",
+ "Symptom",
+ "UpdateNodeRequest",
+)
diff --git a/google/cloud/tpu_v2alpha1/types/cloud_tpu.py b/google/cloud/tpu_v2alpha1/types/cloud_tpu.py
new file mode 100644
index 0000000..e6d937c
--- /dev/null
+++ b/google/cloud/tpu_v2alpha1/types/cloud_tpu.py
@@ -0,0 +1,766 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import proto # type: ignore
+
+from google.protobuf import field_mask_pb2 # type: ignore
+from google.protobuf import timestamp_pb2 # type: ignore
+
+
+__protobuf__ = proto.module(
+ package="google.cloud.tpu.v2alpha1",
+ manifest={
+ "GuestAttributes",
+ "GuestAttributesValue",
+ "GuestAttributesEntry",
+ "AttachedDisk",
+ "SchedulingConfig",
+ "NetworkEndpoint",
+ "AccessConfig",
+ "NetworkConfig",
+ "ServiceAccount",
+ "Node",
+ "ListNodesRequest",
+ "ListNodesResponse",
+ "GetNodeRequest",
+ "CreateNodeRequest",
+ "DeleteNodeRequest",
+ "StopNodeRequest",
+ "StartNodeRequest",
+ "UpdateNodeRequest",
+ "ServiceIdentity",
+ "GenerateServiceIdentityRequest",
+ "GenerateServiceIdentityResponse",
+ "AcceleratorType",
+ "GetAcceleratorTypeRequest",
+ "ListAcceleratorTypesRequest",
+ "ListAcceleratorTypesResponse",
+ "OperationMetadata",
+ "RuntimeVersion",
+ "GetRuntimeVersionRequest",
+ "ListRuntimeVersionsRequest",
+ "ListRuntimeVersionsResponse",
+ "Symptom",
+ "GetGuestAttributesRequest",
+ "GetGuestAttributesResponse",
+ },
+)
+
+
+class GuestAttributes(proto.Message):
+ r"""A guest attributes.
+
+ Attributes:
+ query_path (str):
+ The path to be queried. This can be the
+ default namespace ('/') or a nested namespace
+ ('/\/') or a specified key
+ ('/\/\')
+ query_value (google.cloud.tpu_v2alpha1.types.GuestAttributesValue):
+ The value of the requested queried path.
+ """
+
+ query_path = proto.Field(proto.STRING, number=1,)
+ query_value = proto.Field(proto.MESSAGE, number=2, message="GuestAttributesValue",)
+
+
+class GuestAttributesValue(proto.Message):
+ r"""Array of guest attribute namespace/key/value tuples.
+
+ Attributes:
+ items (Sequence[google.cloud.tpu_v2alpha1.types.GuestAttributesEntry]):
+ The list of guest attributes entries.
+ """
+
+ items = proto.RepeatedField(
+ proto.MESSAGE, number=1, message="GuestAttributesEntry",
+ )
+
+
+class GuestAttributesEntry(proto.Message):
+ r"""A guest attributes namespace/key/value entry.
+
+ Attributes:
+ namespace (str):
+ Namespace for the guest attribute entry.
+ key (str):
+ Key for the guest attribute entry.
+ value (str):
+ Value for the guest attribute entry.
+ """
+
+ namespace = proto.Field(proto.STRING, number=1,)
+ key = proto.Field(proto.STRING, number=2,)
+ value = proto.Field(proto.STRING, number=3,)
+
+
+class AttachedDisk(proto.Message):
+ r"""A node-attached disk resource.
+ Next ID: 8;
+
+ Attributes:
+ source_disk (str):
+ Specifies the full path to an existing disk.
+ For example: "projects/my-project/zones/us-
+ central1-c/disks/my-disk".
+ mode (google.cloud.tpu_v2alpha1.types.AttachedDisk.DiskMode):
+ The mode in which to attach this disk. If not specified, the
+ default is READ_WRITE mode. Only applicable to data_disks.
+ """
+
+ class DiskMode(proto.Enum):
+ r"""The different mode of the attached disk."""
+ DISK_MODE_UNSPECIFIED = 0
+ READ_WRITE = 1
+ READ_ONLY = 2
+
+ source_disk = proto.Field(proto.STRING, number=3,)
+ mode = proto.Field(proto.ENUM, number=4, enum=DiskMode,)
+
+
+class SchedulingConfig(proto.Message):
+ r"""Sets the scheduling options for this node.
+
+ Attributes:
+ preemptible (bool):
+ Defines whether the node is preemptible.
+ reserved (bool):
+ Whether the node is created under a
+ reservation.
+ """
+
+ preemptible = proto.Field(proto.BOOL, number=1,)
+ reserved = proto.Field(proto.BOOL, number=2,)
+
+
+class NetworkEndpoint(proto.Message):
+ r"""A network endpoint over which a TPU worker can be reached.
+
+ Attributes:
+ ip_address (str):
+ The internal IP address of this network
+ endpoint.
+ port (int):
+ The port of this network endpoint.
+ access_config (google.cloud.tpu_v2alpha1.types.AccessConfig):
+ The access config for the TPU worker.
+ """
+
+ ip_address = proto.Field(proto.STRING, number=1,)
+ port = proto.Field(proto.INT32, number=2,)
+ access_config = proto.Field(proto.MESSAGE, number=5, message="AccessConfig",)
+
+
+class AccessConfig(proto.Message):
+ r"""An access config attached to the TPU worker.
+
+ Attributes:
+ external_ip (str):
+ Output only. An external IP address
+ associated with the TPU worker.
+ """
+
+ external_ip = proto.Field(proto.STRING, number=1,)
+
+
+class NetworkConfig(proto.Message):
+ r"""Network related configurations.
+
+ Attributes:
+ network (str):
+ The name of the network for the TPU node. It
+ must be a preexisting Google Compute Engine
+ network. If none is provided, "default" will be
+ used.
+ subnetwork (str):
+ The name of the subnetwork for the TPU node.
+ It must be a preexisting Google Compute Engine
+ subnetwork. If none is provided, "default" will
+ be used.
+ enable_external_ips (bool):
+ Indicates that external IP addresses would be
+ associated with the TPU workers. If set to
+ false, the specified subnetwork or network
+ should have Private Google Access enabled.
+ """
+
+ network = proto.Field(proto.STRING, number=1,)
+ subnetwork = proto.Field(proto.STRING, number=2,)
+ enable_external_ips = proto.Field(proto.BOOL, number=3,)
+
+
+class ServiceAccount(proto.Message):
+ r"""A service account.
+
+ Attributes:
+ email (str):
+ Email address of the service account. If
+ empty, default Compute service account will be
+ used.
+ scope (Sequence[str]):
+ The list of scopes to be made available for
+ this service account. If empty, access to all
+ Cloud APIs will be allowed.
+ """
+
+ email = proto.Field(proto.STRING, number=1,)
+ scope = proto.RepeatedField(proto.STRING, number=2,)
+
+
+class Node(proto.Message):
+ r"""A TPU instance.
+
+ Attributes:
+ name (str):
+ Output only. Immutable. The name of the TPU.
+ description (str):
+ The user-supplied description of the TPU.
+ Maximum of 512 characters.
+ accelerator_type (str):
+ Required. The type of hardware accelerators
+ associated with this node.
+ state (google.cloud.tpu_v2alpha1.types.Node.State):
+ Output only. The current state for the TPU
+ Node.
+ health_description (str):
+ Output only. If this field is populated, it
+ contains a description of why the TPU Node is
+ unhealthy.
+ runtime_version (str):
+ Required. The runtime version running in the
+ Node.
+ network_config (google.cloud.tpu_v2alpha1.types.NetworkConfig):
+ Network configurations for the TPU node.
+ cidr_block (str):
+ The CIDR block that the TPU node will use
+ when selecting an IP address. This CIDR block
+ must be a /29 block; the Compute Engine networks
+ API forbids a smaller block, and using a larger
+ block would be wasteful (a node can only consume
+ one IP address). Errors will occur if the CIDR
+ block has already been used for a currently
+ existing TPU node, the CIDR block conflicts with
+ any subnetworks in the user's provided network,
+ or the provided network is peered with another
+ network that is using that CIDR block.
+ service_account (google.cloud.tpu_v2alpha1.types.ServiceAccount):
+ The Google Cloud Platform Service Account to
+ be used by the TPU node VMs. If None is
+ specified, the default compute service account
+ will be used.
+ create_time (google.protobuf.timestamp_pb2.Timestamp):
+ Output only. The time when the node was
+ created.
+ scheduling_config (google.cloud.tpu_v2alpha1.types.SchedulingConfig):
+ The scheduling options for this node.
+ network_endpoints (Sequence[google.cloud.tpu_v2alpha1.types.NetworkEndpoint]):
+ Output only. The network endpoints where TPU
+ workers can be accessed and sent work. It is
+ recommended that runtime clients of the node
+ reach out to the 0th entry in this map first.
+ health (google.cloud.tpu_v2alpha1.types.Node.Health):
+ The health status of the TPU node.
+ labels (Sequence[google.cloud.tpu_v2alpha1.types.Node.LabelsEntry]):
+ Resource labels to represent user-provided
+ metadata.
+ metadata (Sequence[google.cloud.tpu_v2alpha1.types.Node.MetadataEntry]):
+ Custom metadata to apply to the TPU Node.
+ Can set startup-script and shutdown-script
+ tags (Sequence[str]):
+ Tags to apply to the TPU Node. Tags are used
+ to identify valid sources or targets for network
+ firewalls.
+ id (int):
+ Output only. The unique identifier for the
+ TPU Node.
+ data_disks (Sequence[google.cloud.tpu_v2alpha1.types.AttachedDisk]):
+ The additional data disks for the Node.
+ api_version (google.cloud.tpu_v2alpha1.types.Node.ApiVersion):
+ Output only. The API version that created
+ this Node.
+ symptoms (Sequence[google.cloud.tpu_v2alpha1.types.Symptom]):
+ Output only. The Symptoms that have occurred
+ to the TPU Node.
+ """
+
+ class State(proto.Enum):
+ r"""Represents the different states of a TPU node during its
+ lifecycle.
+ """
+ STATE_UNSPECIFIED = 0
+ CREATING = 1
+ READY = 2
+ RESTARTING = 3
+ REIMAGING = 4
+ DELETING = 5
+ REPAIRING = 6
+ STOPPED = 8
+ STOPPING = 9
+ STARTING = 10
+ PREEMPTED = 11
+ TERMINATED = 12
+ HIDING = 13
+ HIDDEN = 14
+ UNHIDING = 15
+
+ class Health(proto.Enum):
+ r"""Health defines the status of a TPU node as reported by
+ Health Monitor.
+ """
+ HEALTH_UNSPECIFIED = 0
+ HEALTHY = 1
+ TIMEOUT = 3
+ UNHEALTHY_TENSORFLOW = 4
+ UNHEALTHY_MAINTENANCE = 5
+
+ class ApiVersion(proto.Enum):
+ r"""TPU API Version."""
+ API_VERSION_UNSPECIFIED = 0
+ V1_ALPHA1 = 1
+ V1 = 2
+ V2_ALPHA1 = 3
+
+ name = proto.Field(proto.STRING, number=1,)
+ description = proto.Field(proto.STRING, number=3,)
+ accelerator_type = proto.Field(proto.STRING, number=5,)
+ state = proto.Field(proto.ENUM, number=9, enum=State,)
+ health_description = proto.Field(proto.STRING, number=10,)
+ runtime_version = proto.Field(proto.STRING, number=11,)
+ network_config = proto.Field(proto.MESSAGE, number=36, message="NetworkConfig",)
+ cidr_block = proto.Field(proto.STRING, number=13,)
+ service_account = proto.Field(proto.MESSAGE, number=37, message="ServiceAccount",)
+ create_time = proto.Field(
+ proto.MESSAGE, number=16, message=timestamp_pb2.Timestamp,
+ )
+ scheduling_config = proto.Field(
+ proto.MESSAGE, number=17, message="SchedulingConfig",
+ )
+ network_endpoints = proto.RepeatedField(
+ proto.MESSAGE, number=21, message="NetworkEndpoint",
+ )
+ health = proto.Field(proto.ENUM, number=22, enum=Health,)
+ labels = proto.MapField(proto.STRING, proto.STRING, number=24,)
+ metadata = proto.MapField(proto.STRING, proto.STRING, number=34,)
+ tags = proto.RepeatedField(proto.STRING, number=40,)
+ id = proto.Field(proto.INT64, number=33,)
+ data_disks = proto.RepeatedField(proto.MESSAGE, number=41, message="AttachedDisk",)
+ api_version = proto.Field(proto.ENUM, number=38, enum=ApiVersion,)
+ symptoms = proto.RepeatedField(proto.MESSAGE, number=39, message="Symptom",)
+
+
+class ListNodesRequest(proto.Message):
+ r"""Request for [ListNodes][google.cloud.tpu.v2alpha1.Tpu.ListNodes].
+
+ Attributes:
+ parent (str):
+ Required. The parent resource name.
+ page_size (int):
+ The maximum number of items to return.
+ page_token (str):
+ The next_page_token value returned from a previous List
+ request, if any.
+ """
+
+ parent = proto.Field(proto.STRING, number=1,)
+ page_size = proto.Field(proto.INT32, number=2,)
+ page_token = proto.Field(proto.STRING, number=3,)
+
+
+class ListNodesResponse(proto.Message):
+ r"""Response for [ListNodes][google.cloud.tpu.v2alpha1.Tpu.ListNodes].
+
+ Attributes:
+ nodes (Sequence[google.cloud.tpu_v2alpha1.types.Node]):
+ The listed nodes.
+ next_page_token (str):
+ The next page token or empty if none.
+ unreachable (Sequence[str]):
+ Locations that could not be reached.
+ """
+
+ @property
+ def raw_page(self):
+ return self
+
+ nodes = proto.RepeatedField(proto.MESSAGE, number=1, message="Node",)
+ next_page_token = proto.Field(proto.STRING, number=2,)
+ unreachable = proto.RepeatedField(proto.STRING, number=3,)
+
+
+class GetNodeRequest(proto.Message):
+ r"""Request for [GetNode][google.cloud.tpu.v2alpha1.Tpu.GetNode].
+
+ Attributes:
+ name (str):
+ Required. The resource name.
+ """
+
+ name = proto.Field(proto.STRING, number=1,)
+
+
+class CreateNodeRequest(proto.Message):
+ r"""Request for [CreateNode][google.cloud.tpu.v2alpha1.Tpu.CreateNode].
+
+ Attributes:
+ parent (str):
+ Required. The parent resource name.
+ node_id (str):
+ The unqualified resource name.
+ node (google.cloud.tpu_v2alpha1.types.Node):
+ Required. The node.
+ """
+
+ parent = proto.Field(proto.STRING, number=1,)
+ node_id = proto.Field(proto.STRING, number=2,)
+ node = proto.Field(proto.MESSAGE, number=3, message="Node",)
+
+
+class DeleteNodeRequest(proto.Message):
+ r"""Request for [DeleteNode][google.cloud.tpu.v2alpha1.Tpu.DeleteNode].
+
+ Attributes:
+ name (str):
+ Required. The resource name.
+ """
+
+ name = proto.Field(proto.STRING, number=1,)
+
+
+class StopNodeRequest(proto.Message):
+ r"""Request for [StopNode][google.cloud.tpu.v2alpha1.Tpu.StopNode].
+
+ Attributes:
+ name (str):
+ The resource name.
+ """
+
+ name = proto.Field(proto.STRING, number=1,)
+
+
+class StartNodeRequest(proto.Message):
+ r"""Request for [StartNode][google.cloud.tpu.v2alpha1.Tpu.StartNode].
+
+ Attributes:
+ name (str):
+ The resource name.
+ """
+
+ name = proto.Field(proto.STRING, number=1,)
+
+
+class UpdateNodeRequest(proto.Message):
+ r"""Request for [UpdateNode][google.cloud.tpu.v2alpha1.Tpu.UpdateNode].
+
+ Attributes:
+ update_mask (google.protobuf.field_mask_pb2.FieldMask):
+ Required. Mask of fields from [Node][Tpu.Node] to update.
+ Supported fields: None.
+ node (google.cloud.tpu_v2alpha1.types.Node):
+ Required. The node. Only fields specified in update_mask are
+ updated.
+ """
+
+ update_mask = proto.Field(
+ proto.MESSAGE, number=1, message=field_mask_pb2.FieldMask,
+ )
+ node = proto.Field(proto.MESSAGE, number=2, message="Node",)
+
+
+class ServiceIdentity(proto.Message):
+ r"""The per-product per-project service identity for Cloud TPU
+ service.
+
+ Attributes:
+ email (str):
+ The email address of the service identity.
+ """
+
+ email = proto.Field(proto.STRING, number=1,)
+
+
+class GenerateServiceIdentityRequest(proto.Message):
+ r"""Request for
+ [GenerateServiceIdentity][google.cloud.tpu.v2alpha1.Tpu.GenerateServiceIdentity].
+
+ Attributes:
+ parent (str):
+ Required. The parent resource name.
+ """
+
+ parent = proto.Field(proto.STRING, number=1,)
+
+
+class GenerateServiceIdentityResponse(proto.Message):
+ r"""Response for
+ [GenerateServiceIdentity][google.cloud.tpu.v2alpha1.Tpu.GenerateServiceIdentity].
+
+ Attributes:
+ identity (google.cloud.tpu_v2alpha1.types.ServiceIdentity):
+ ServiceIdentity that was created or
+ retrieved.
+ """
+
+ identity = proto.Field(proto.MESSAGE, number=1, message="ServiceIdentity",)
+
+
+class AcceleratorType(proto.Message):
+ r"""A accelerator type that a Node can be configured with.
+
+ Attributes:
+ name (str):
+ The resource name.
+ type_ (str):
+ the accelerator type.
+ """
+
+ name = proto.Field(proto.STRING, number=1,)
+ type_ = proto.Field(proto.STRING, number=2,)
+
+
+class GetAcceleratorTypeRequest(proto.Message):
+ r"""Request for
+ [GetAcceleratorType][google.cloud.tpu.v2alpha1.Tpu.GetAcceleratorType].
+
+ Attributes:
+ name (str):
+ Required. The resource name.
+ """
+
+ name = proto.Field(proto.STRING, number=1,)
+
+
+class ListAcceleratorTypesRequest(proto.Message):
+ r"""Request for
+ [ListAcceleratorTypes][google.cloud.tpu.v2alpha1.Tpu.ListAcceleratorTypes].
+
+ Attributes:
+ parent (str):
+ Required. The parent resource name.
+ page_size (int):
+ The maximum number of items to return.
+ page_token (str):
+ The next_page_token value returned from a previous List
+ request, if any.
+ filter (str):
+ List filter.
+ order_by (str):
+ Sort results.
+ """
+
+ parent = proto.Field(proto.STRING, number=1,)
+ page_size = proto.Field(proto.INT32, number=2,)
+ page_token = proto.Field(proto.STRING, number=3,)
+ filter = proto.Field(proto.STRING, number=5,)
+ order_by = proto.Field(proto.STRING, number=6,)
+
+
+class ListAcceleratorTypesResponse(proto.Message):
+ r"""Response for
+ [ListAcceleratorTypes][google.cloud.tpu.v2alpha1.Tpu.ListAcceleratorTypes].
+
+ Attributes:
+ accelerator_types (Sequence[google.cloud.tpu_v2alpha1.types.AcceleratorType]):
+ The listed nodes.
+ next_page_token (str):
+ The next page token or empty if none.
+ unreachable (Sequence[str]):
+ Locations that could not be reached.
+ """
+
+ @property
+ def raw_page(self):
+ return self
+
+ accelerator_types = proto.RepeatedField(
+ proto.MESSAGE, number=1, message="AcceleratorType",
+ )
+ next_page_token = proto.Field(proto.STRING, number=2,)
+ unreachable = proto.RepeatedField(proto.STRING, number=3,)
+
+
+class OperationMetadata(proto.Message):
+ r"""Metadata describing an [Operation][google.longrunning.Operation]
+
+ Attributes:
+ create_time (google.protobuf.timestamp_pb2.Timestamp):
+ The time the operation was created.
+ end_time (google.protobuf.timestamp_pb2.Timestamp):
+ The time the operation finished running.
+ target (str):
+ Target of the operation - for example
+ projects/project-1/connectivityTests/test-1
+ verb (str):
+ Name of the verb executed by the operation.
+ status_detail (str):
+ Human-readable status of the operation, if
+ any.
+ cancel_requested (bool):
+ Specifies if cancellation was requested for
+ the operation.
+ api_version (str):
+ API version.
+ """
+
+ create_time = proto.Field(proto.MESSAGE, number=1, message=timestamp_pb2.Timestamp,)
+ end_time = proto.Field(proto.MESSAGE, number=2, message=timestamp_pb2.Timestamp,)
+ target = proto.Field(proto.STRING, number=3,)
+ verb = proto.Field(proto.STRING, number=4,)
+ status_detail = proto.Field(proto.STRING, number=5,)
+ cancel_requested = proto.Field(proto.BOOL, number=6,)
+ api_version = proto.Field(proto.STRING, number=7,)
+
+
+class RuntimeVersion(proto.Message):
+ r"""A runtime version that a Node can be configured with.
+
+ Attributes:
+ name (str):
+ The resource name.
+ version (str):
+ The runtime version.
+ """
+
+ name = proto.Field(proto.STRING, number=1,)
+ version = proto.Field(proto.STRING, number=2,)
+
+
+class GetRuntimeVersionRequest(proto.Message):
+ r"""Request for
+ [GetRuntimeVersion][google.cloud.tpu.v2alpha1.Tpu.GetRuntimeVersion].
+
+ Attributes:
+ name (str):
+ Required. The resource name.
+ """
+
+ name = proto.Field(proto.STRING, number=1,)
+
+
+class ListRuntimeVersionsRequest(proto.Message):
+ r"""Request for
+ [ListRuntimeVersions][google.cloud.tpu.v2alpha1.Tpu.ListRuntimeVersions].
+
+ Attributes:
+ parent (str):
+ Required. The parent resource name.
+ page_size (int):
+ The maximum number of items to return.
+ page_token (str):
+ The next_page_token value returned from a previous List
+ request, if any.
+ filter (str):
+ List filter.
+ order_by (str):
+ Sort results.
+ """
+
+ parent = proto.Field(proto.STRING, number=1,)
+ page_size = proto.Field(proto.INT32, number=2,)
+ page_token = proto.Field(proto.STRING, number=3,)
+ filter = proto.Field(proto.STRING, number=5,)
+ order_by = proto.Field(proto.STRING, number=6,)
+
+
+class ListRuntimeVersionsResponse(proto.Message):
+ r"""Response for
+ [ListRuntimeVersions][google.cloud.tpu.v2alpha1.Tpu.ListRuntimeVersions].
+
+ Attributes:
+ runtime_versions (Sequence[google.cloud.tpu_v2alpha1.types.RuntimeVersion]):
+ The listed nodes.
+ next_page_token (str):
+ The next page token or empty if none.
+ unreachable (Sequence[str]):
+ Locations that could not be reached.
+ """
+
+ @property
+ def raw_page(self):
+ return self
+
+ runtime_versions = proto.RepeatedField(
+ proto.MESSAGE, number=1, message="RuntimeVersion",
+ )
+ next_page_token = proto.Field(proto.STRING, number=2,)
+ unreachable = proto.RepeatedField(proto.STRING, number=3,)
+
+
+class Symptom(proto.Message):
+ r"""A Symptom instance.
+
+ Attributes:
+ create_time (google.protobuf.timestamp_pb2.Timestamp):
+ Timestamp when the Symptom is created.
+ symptom_type (google.cloud.tpu_v2alpha1.types.Symptom.SymptomType):
+ Type of the Symptom.
+ details (str):
+ Detailed information of the current Symptom.
+ worker_id (str):
+ A string used to uniquely distinguish a
+ worker within a TPU node.
+ """
+
+ class SymptomType(proto.Enum):
+ r"""SymptomType represents the different types of Symptoms that a
+ TPU can be at.
+ """
+ SYMPTOM_TYPE_UNSPECIFIED = 0
+ LOW_MEMORY = 1
+ OUT_OF_MEMORY = 2
+ EXECUTE_TIMED_OUT = 3
+ MESH_BUILD_FAIL = 4
+ HBM_OUT_OF_MEMORY = 5
+ PROJECT_ABUSE = 6
+
+ create_time = proto.Field(proto.MESSAGE, number=1, message=timestamp_pb2.Timestamp,)
+ symptom_type = proto.Field(proto.ENUM, number=2, enum=SymptomType,)
+ details = proto.Field(proto.STRING, number=3,)
+ worker_id = proto.Field(proto.STRING, number=4,)
+
+
+class GetGuestAttributesRequest(proto.Message):
+ r"""Request for
+ [GetGuestAttributes][google.cloud.tpu.v2alpha1.Tpu.GetGuestAttributes].
+
+ Attributes:
+ name (str):
+ Required. The resource name.
+ query_path (str):
+ The guest attributes path to be queried.
+ worker_ids (Sequence[str]):
+ The 0-based worker ID. If it is empty, all
+ workers' GuestAttributes will be returned.
+ """
+
+ name = proto.Field(proto.STRING, number=1,)
+ query_path = proto.Field(proto.STRING, number=2,)
+ worker_ids = proto.RepeatedField(proto.STRING, number=3,)
+
+
+class GetGuestAttributesResponse(proto.Message):
+ r"""Response for
+ [GetGuestAttributes][google.cloud.tpu.v2alpha1.Tpu.GetGuestAttributes].
+
+ Attributes:
+ guest_attributes (Sequence[google.cloud.tpu_v2alpha1.types.GuestAttributes]):
+ The guest attributes for the TPU workers.
+ """
+
+ guest_attributes = proto.RepeatedField(
+ proto.MESSAGE, number=1, message="GuestAttributes",
+ )
+
+
+__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/scripts/fixup_tpu_v2alpha1_keywords.py b/scripts/fixup_tpu_v2alpha1_keywords.py
new file mode 100644
index 0000000..10a2b39
--- /dev/null
+++ b/scripts/fixup_tpu_v2alpha1_keywords.py
@@ -0,0 +1,188 @@
+#! /usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import argparse
+import os
+import libcst as cst
+import pathlib
+import sys
+from typing import (Any, Callable, Dict, List, Sequence, Tuple)
+
+
+def partition(
+ predicate: Callable[[Any], bool],
+ iterator: Sequence[Any]
+) -> Tuple[List[Any], List[Any]]:
+ """A stable, out-of-place partition."""
+ results = ([], [])
+
+ for i in iterator:
+ results[int(predicate(i))].append(i)
+
+ # Returns trueList, falseList
+ return results[1], results[0]
+
+
+class tpuCallTransformer(cst.CSTTransformer):
+ CTRL_PARAMS: Tuple[str] = ('retry', 'timeout', 'metadata')
+ METHOD_TO_PARAMS: Dict[str, Tuple[str]] = {
+ 'create_node': ('parent', 'node', 'node_id', ),
+ 'delete_node': ('name', ),
+ 'generate_service_identity': ('parent', ),
+ 'get_accelerator_type': ('name', ),
+ 'get_guest_attributes': ('name', 'query_path', 'worker_ids', ),
+ 'get_node': ('name', ),
+ 'get_runtime_version': ('name', ),
+ 'list_accelerator_types': ('parent', 'page_size', 'page_token', 'filter', 'order_by', ),
+ 'list_nodes': ('parent', 'page_size', 'page_token', ),
+ 'list_runtime_versions': ('parent', 'page_size', 'page_token', 'filter', 'order_by', ),
+ 'start_node': ('name', ),
+ 'stop_node': ('name', ),
+ 'update_node': ('update_mask', 'node', ),
+ }
+
+ def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode:
+ try:
+ key = original.func.attr.value
+ kword_params = self.METHOD_TO_PARAMS[key]
+ except (AttributeError, KeyError):
+ # Either not a method from the API or too convoluted to be sure.
+ return updated
+
+ # If the existing code is valid, keyword args come after positional args.
+ # Therefore, all positional args must map to the first parameters.
+ args, kwargs = partition(lambda a: not bool(a.keyword), updated.args)
+ if any(k.keyword.value == "request" for k in kwargs):
+ # We've already fixed this file, don't fix it again.
+ return updated
+
+ kwargs, ctrl_kwargs = partition(
+ lambda a: a.keyword.value not in self.CTRL_PARAMS,
+ kwargs
+ )
+
+ args, ctrl_args = args[:len(kword_params)], args[len(kword_params):]
+ ctrl_kwargs.extend(cst.Arg(value=a.value, keyword=cst.Name(value=ctrl))
+ for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS))
+
+ request_arg = cst.Arg(
+ value=cst.Dict([
+ cst.DictElement(
+ cst.SimpleString("'{}'".format(name)),
+cst.Element(value=arg.value)
+ )
+ # Note: the args + kwargs looks silly, but keep in mind that
+ # the control parameters had to be stripped out, and that
+ # those could have been passed positionally or by keyword.
+ for name, arg in zip(kword_params, args + kwargs)]),
+ keyword=cst.Name("request")
+ )
+
+ return updated.with_changes(
+ args=[request_arg] + ctrl_kwargs
+ )
+
+
+def fix_files(
+ in_dir: pathlib.Path,
+ out_dir: pathlib.Path,
+ *,
+ transformer=tpuCallTransformer(),
+):
+ """Duplicate the input dir to the output dir, fixing file method calls.
+
+ Preconditions:
+ * in_dir is a real directory
+ * out_dir is a real, empty directory
+ """
+ pyfile_gen = (
+ pathlib.Path(os.path.join(root, f))
+ for root, _, files in os.walk(in_dir)
+ for f in files if os.path.splitext(f)[1] == ".py"
+ )
+
+ for fpath in pyfile_gen:
+ with open(fpath, 'r') as f:
+ src = f.read()
+
+ # Parse the code and insert method call fixes.
+ tree = cst.parse_module(src)
+ updated = tree.visit(transformer)
+
+ # Create the path and directory structure for the new file.
+ updated_path = out_dir.joinpath(fpath.relative_to(in_dir))
+ updated_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # Generate the updated source file at the corresponding path.
+ with open(updated_path, 'w') as f:
+ f.write(updated.code)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(
+ description="""Fix up source that uses the tpu client library.
+
+The existing sources are NOT overwritten but are copied to output_dir with changes made.
+
+Note: This tool operates at a best-effort level at converting positional
+ parameters in client method calls to keyword based parameters.
+ Cases where it WILL FAIL include
+ A) * or ** expansion in a method call.
+ B) Calls via function or method alias (includes free function calls)
+ C) Indirect or dispatched calls (e.g. the method is looked up dynamically)
+
+ These all constitute false negatives. The tool will also detect false
+ positives when an API method shares a name with another method.
+""")
+ parser.add_argument(
+ '-d',
+ '--input-directory',
+ required=True,
+ dest='input_dir',
+ help='the input directory to walk for python files to fix up',
+ )
+ parser.add_argument(
+ '-o',
+ '--output-directory',
+ required=True,
+ dest='output_dir',
+ help='the directory to output files fixed via un-flattening',
+ )
+ args = parser.parse_args()
+ input_dir = pathlib.Path(args.input_dir)
+ output_dir = pathlib.Path(args.output_dir)
+ if not input_dir.is_dir():
+ print(
+ f"input directory '{input_dir}' does not exist or is not a directory",
+ file=sys.stderr,
+ )
+ sys.exit(-1)
+
+ if not output_dir.is_dir():
+ print(
+ f"output directory '{output_dir}' does not exist or is not a directory",
+ file=sys.stderr,
+ )
+ sys.exit(-1)
+
+ if os.listdir(output_dir):
+ print(
+ f"output directory '{output_dir}' is not empty",
+ file=sys.stderr,
+ )
+ sys.exit(-1)
+
+ fix_files(input_dir, output_dir)
diff --git a/tests/unit/gapic/tpu_v2alpha1/__init__.py b/tests/unit/gapic/tpu_v2alpha1/__init__.py
new file mode 100644
index 0000000..4de6597
--- /dev/null
+++ b/tests/unit/gapic/tpu_v2alpha1/__init__.py
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/tests/unit/gapic/tpu_v2alpha1/test_tpu.py b/tests/unit/gapic/tpu_v2alpha1/test_tpu.py
new file mode 100644
index 0000000..d6f8b7b
--- /dev/null
+++ b/tests/unit/gapic/tpu_v2alpha1/test_tpu.py
@@ -0,0 +1,4026 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import mock
+import packaging.version
+
+import grpc
+from grpc.experimental import aio
+import math
+import pytest
+from proto.marshal.rules.dates import DurationRule, TimestampRule
+
+
+from google.api_core import client_options
+from google.api_core import exceptions as core_exceptions
+from google.api_core import future
+from google.api_core import gapic_v1
+from google.api_core import grpc_helpers
+from google.api_core import grpc_helpers_async
+from google.api_core import operation_async # type: ignore
+from google.api_core import operations_v1
+from google.api_core import path_template
+from google.auth import credentials as ga_credentials
+from google.auth.exceptions import MutualTLSChannelError
+from google.cloud.tpu_v2alpha1.services.tpu import TpuAsyncClient
+from google.cloud.tpu_v2alpha1.services.tpu import TpuClient
+from google.cloud.tpu_v2alpha1.services.tpu import pagers
+from google.cloud.tpu_v2alpha1.services.tpu import transports
+from google.cloud.tpu_v2alpha1.services.tpu.transports.base import _GOOGLE_AUTH_VERSION
+from google.cloud.tpu_v2alpha1.types import cloud_tpu
+from google.longrunning import operations_pb2
+from google.oauth2 import service_account
+from google.protobuf import field_mask_pb2 # type: ignore
+from google.protobuf import timestamp_pb2 # type: ignore
+import google.auth
+
+
+# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively
+# through google-api-core:
+# - Delete the auth "less than" test cases
+# - Delete these pytest markers (Make the "greater than or equal to" tests the default).
+requires_google_auth_lt_1_25_0 = pytest.mark.skipif(
+ packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"),
+ reason="This test requires google-auth < 1.25.0",
+)
+requires_google_auth_gte_1_25_0 = pytest.mark.skipif(
+ packaging.version.parse(_GOOGLE_AUTH_VERSION) < packaging.version.parse("1.25.0"),
+ reason="This test requires google-auth >= 1.25.0",
+)
+
+
+def client_cert_source_callback():
+ return b"cert bytes", b"key bytes"
+
+
+# If default endpoint is localhost, then default mtls endpoint will be the same.
+# This method modifies the default endpoint so the client can produce a different
+# mtls endpoint for endpoint testing purposes.
+def modify_default_endpoint(client):
+ return (
+ "foo.googleapis.com"
+ if ("localhost" in client.DEFAULT_ENDPOINT)
+ else client.DEFAULT_ENDPOINT
+ )
+
+
+def test__get_default_mtls_endpoint():
+ api_endpoint = "example.googleapis.com"
+ api_mtls_endpoint = "example.mtls.googleapis.com"
+ sandbox_endpoint = "example.sandbox.googleapis.com"
+ sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com"
+ non_googleapi = "api.example.com"
+
+ assert TpuClient._get_default_mtls_endpoint(None) is None
+ assert TpuClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint
+ assert TpuClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint
+ assert (
+ TpuClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint
+ )
+ assert (
+ TpuClient._get_default_mtls_endpoint(sandbox_mtls_endpoint)
+ == sandbox_mtls_endpoint
+ )
+ assert TpuClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi
+
+
+@pytest.mark.parametrize("client_class", [TpuClient, TpuAsyncClient,])
+def test_tpu_client_from_service_account_info(client_class):
+ creds = ga_credentials.AnonymousCredentials()
+ with mock.patch.object(
+ service_account.Credentials, "from_service_account_info"
+ ) as factory:
+ factory.return_value = creds
+ info = {"valid": True}
+ client = client_class.from_service_account_info(info)
+ assert client.transport._credentials == creds
+ assert isinstance(client, client_class)
+
+ assert client.transport._host == "tpu.googleapis.com:443"
+
+
+@pytest.mark.parametrize(
+ "transport_class,transport_name",
+ [
+ (transports.TpuGrpcTransport, "grpc"),
+ (transports.TpuGrpcAsyncIOTransport, "grpc_asyncio"),
+ ],
+)
+def test_tpu_client_service_account_always_use_jwt(transport_class, transport_name):
+ with mock.patch.object(
+ service_account.Credentials, "with_always_use_jwt_access", create=True
+ ) as use_jwt:
+ creds = service_account.Credentials(None, None, None)
+ transport = transport_class(credentials=creds, always_use_jwt_access=True)
+ use_jwt.assert_called_once_with(True)
+
+ with mock.patch.object(
+ service_account.Credentials, "with_always_use_jwt_access", create=True
+ ) as use_jwt:
+ creds = service_account.Credentials(None, None, None)
+ transport = transport_class(credentials=creds, always_use_jwt_access=False)
+ use_jwt.assert_not_called()
+
+
+@pytest.mark.parametrize("client_class", [TpuClient, TpuAsyncClient,])
+def test_tpu_client_from_service_account_file(client_class):
+ creds = ga_credentials.AnonymousCredentials()
+ with mock.patch.object(
+ service_account.Credentials, "from_service_account_file"
+ ) as factory:
+ factory.return_value = creds
+ client = client_class.from_service_account_file("dummy/file/path.json")
+ assert client.transport._credentials == creds
+ assert isinstance(client, client_class)
+
+ client = client_class.from_service_account_json("dummy/file/path.json")
+ assert client.transport._credentials == creds
+ assert isinstance(client, client_class)
+
+ assert client.transport._host == "tpu.googleapis.com:443"
+
+
+def test_tpu_client_get_transport_class():
+ transport = TpuClient.get_transport_class()
+ available_transports = [
+ transports.TpuGrpcTransport,
+ ]
+ assert transport in available_transports
+
+ transport = TpuClient.get_transport_class("grpc")
+ assert transport == transports.TpuGrpcTransport
+
+
+@pytest.mark.parametrize(
+ "client_class,transport_class,transport_name",
+ [
+ (TpuClient, transports.TpuGrpcTransport, "grpc"),
+ (TpuAsyncClient, transports.TpuGrpcAsyncIOTransport, "grpc_asyncio"),
+ ],
+)
+@mock.patch.object(TpuClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TpuClient))
+@mock.patch.object(
+ TpuAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TpuAsyncClient)
+)
+def test_tpu_client_client_options(client_class, transport_class, transport_name):
+ # Check that if channel is provided we won't create a new one.
+ with mock.patch.object(TpuClient, "get_transport_class") as gtc:
+ transport = transport_class(credentials=ga_credentials.AnonymousCredentials())
+ client = client_class(transport=transport)
+ gtc.assert_not_called()
+
+ # Check that if channel is provided via str we will create a new one.
+ with mock.patch.object(TpuClient, "get_transport_class") as gtc:
+ client = client_class(transport=transport_name)
+ gtc.assert_called()
+
+ # Check the case api_endpoint is provided.
+ options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
+ with mock.patch.object(transport_class, "__init__") as patched:
+ patched.return_value = None
+ client = client_class(client_options=options)
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host="squid.clam.whelk",
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ )
+
+ # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
+ # "never".
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
+ with mock.patch.object(transport_class, "__init__") as patched:
+ patched.return_value = None
+ client = client_class()
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=client.DEFAULT_ENDPOINT,
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ )
+
+ # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
+ # "always".
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
+ with mock.patch.object(transport_class, "__init__") as patched:
+ patched.return_value = None
+ client = client_class()
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=client.DEFAULT_MTLS_ENDPOINT,
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ )
+
+ # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
+ # unsupported value.
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
+ with pytest.raises(MutualTLSChannelError):
+ client = client_class()
+
+ # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
+ with mock.patch.dict(
+ os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
+ ):
+ with pytest.raises(ValueError):
+ client = client_class()
+
+ # Check the case quota_project_id is provided
+ options = client_options.ClientOptions(quota_project_id="octopus")
+ with mock.patch.object(transport_class, "__init__") as patched:
+ patched.return_value = None
+ client = client_class(client_options=options)
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=client.DEFAULT_ENDPOINT,
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id="octopus",
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ )
+
+
+@pytest.mark.parametrize(
+ "client_class,transport_class,transport_name,use_client_cert_env",
+ [
+ (TpuClient, transports.TpuGrpcTransport, "grpc", "true"),
+ (TpuAsyncClient, transports.TpuGrpcAsyncIOTransport, "grpc_asyncio", "true"),
+ (TpuClient, transports.TpuGrpcTransport, "grpc", "false"),
+ (TpuAsyncClient, transports.TpuGrpcAsyncIOTransport, "grpc_asyncio", "false"),
+ ],
+)
+@mock.patch.object(TpuClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TpuClient))
+@mock.patch.object(
+ TpuAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TpuAsyncClient)
+)
+@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"})
+def test_tpu_client_mtls_env_auto(
+ client_class, transport_class, transport_name, use_client_cert_env
+):
+ # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default
+ # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists.
+
+ # Check the case client_cert_source is provided. Whether client cert is used depends on
+ # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
+ with mock.patch.dict(
+ os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+ ):
+ options = client_options.ClientOptions(
+ client_cert_source=client_cert_source_callback
+ )
+ with mock.patch.object(transport_class, "__init__") as patched:
+ patched.return_value = None
+ client = client_class(client_options=options)
+
+ if use_client_cert_env == "false":
+ expected_client_cert_source = None
+ expected_host = client.DEFAULT_ENDPOINT
+ else:
+ expected_client_cert_source = client_cert_source_callback
+ expected_host = client.DEFAULT_MTLS_ENDPOINT
+
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=expected_host,
+ scopes=None,
+ client_cert_source_for_mtls=expected_client_cert_source,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ )
+
+ # Check the case ADC client cert is provided. Whether client cert is used depends on
+ # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
+ with mock.patch.dict(
+ os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+ ):
+ with mock.patch.object(transport_class, "__init__") as patched:
+ with mock.patch(
+ "google.auth.transport.mtls.has_default_client_cert_source",
+ return_value=True,
+ ):
+ with mock.patch(
+ "google.auth.transport.mtls.default_client_cert_source",
+ return_value=client_cert_source_callback,
+ ):
+ if use_client_cert_env == "false":
+ expected_host = client.DEFAULT_ENDPOINT
+ expected_client_cert_source = None
+ else:
+ expected_host = client.DEFAULT_MTLS_ENDPOINT
+ expected_client_cert_source = client_cert_source_callback
+
+ patched.return_value = None
+ client = client_class()
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=expected_host,
+ scopes=None,
+ client_cert_source_for_mtls=expected_client_cert_source,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ )
+
+ # Check the case client_cert_source and ADC client cert are not provided.
+ with mock.patch.dict(
+ os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+ ):
+ with mock.patch.object(transport_class, "__init__") as patched:
+ with mock.patch(
+ "google.auth.transport.mtls.has_default_client_cert_source",
+ return_value=False,
+ ):
+ patched.return_value = None
+ client = client_class()
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=client.DEFAULT_ENDPOINT,
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ )
+
+
+@pytest.mark.parametrize(
+ "client_class,transport_class,transport_name",
+ [
+ (TpuClient, transports.TpuGrpcTransport, "grpc"),
+ (TpuAsyncClient, transports.TpuGrpcAsyncIOTransport, "grpc_asyncio"),
+ ],
+)
+def test_tpu_client_client_options_scopes(
+ client_class, transport_class, transport_name
+):
+ # Check the case scopes are provided.
+ options = client_options.ClientOptions(scopes=["1", "2"],)
+ with mock.patch.object(transport_class, "__init__") as patched:
+ patched.return_value = None
+ client = client_class(client_options=options)
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=client.DEFAULT_ENDPOINT,
+ scopes=["1", "2"],
+ client_cert_source_for_mtls=None,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ )
+
+
+@pytest.mark.parametrize(
+ "client_class,transport_class,transport_name",
+ [
+ (TpuClient, transports.TpuGrpcTransport, "grpc"),
+ (TpuAsyncClient, transports.TpuGrpcAsyncIOTransport, "grpc_asyncio"),
+ ],
+)
+def test_tpu_client_client_options_credentials_file(
+ client_class, transport_class, transport_name
+):
+ # Check the case credentials file is provided.
+ options = client_options.ClientOptions(credentials_file="credentials.json")
+ with mock.patch.object(transport_class, "__init__") as patched:
+ patched.return_value = None
+ client = client_class(client_options=options)
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file="credentials.json",
+ host=client.DEFAULT_ENDPOINT,
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ )
+
+
+def test_tpu_client_client_options_from_dict():
+ with mock.patch(
+ "google.cloud.tpu_v2alpha1.services.tpu.transports.TpuGrpcTransport.__init__"
+ ) as grpc_transport:
+ grpc_transport.return_value = None
+ client = TpuClient(client_options={"api_endpoint": "squid.clam.whelk"})
+ grpc_transport.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host="squid.clam.whelk",
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ )
+
+
+def test_list_nodes(transport: str = "grpc", request_type=cloud_tpu.ListNodesRequest):
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.list_nodes), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.ListNodesResponse(
+ next_page_token="next_page_token_value", unreachable=["unreachable_value"],
+ )
+ response = client.list_nodes(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.ListNodesRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, pagers.ListNodesPager)
+ assert response.next_page_token == "next_page_token_value"
+ assert response.unreachable == ["unreachable_value"]
+
+
+def test_list_nodes_from_dict():
+ test_list_nodes(request_type=dict)
+
+
+def test_list_nodes_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.list_nodes), "__call__") as call:
+ client.list_nodes()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.ListNodesRequest()
+
+
+@pytest.mark.asyncio
+async def test_list_nodes_async(
+ transport: str = "grpc_asyncio", request_type=cloud_tpu.ListNodesRequest
+):
+ client = TpuAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.list_nodes), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.ListNodesResponse(
+ next_page_token="next_page_token_value",
+ unreachable=["unreachable_value"],
+ )
+ )
+ response = await client.list_nodes(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.ListNodesRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, pagers.ListNodesAsyncPager)
+ assert response.next_page_token == "next_page_token_value"
+ assert response.unreachable == ["unreachable_value"]
+
+
+@pytest.mark.asyncio
+async def test_list_nodes_async_from_dict():
+ await test_list_nodes_async(request_type=dict)
+
+
+def test_list_nodes_field_headers():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.ListNodesRequest()
+
+ request.parent = "parent/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.list_nodes), "__call__") as call:
+ call.return_value = cloud_tpu.ListNodesResponse()
+ client.list_nodes(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_list_nodes_field_headers_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.ListNodesRequest()
+
+ request.parent = "parent/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.list_nodes), "__call__") as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.ListNodesResponse()
+ )
+ await client.list_nodes(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_nodes_flattened():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.list_nodes), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.ListNodesResponse()
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.list_nodes(parent="parent_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0].parent == "parent_value"
+
+
+def test_list_nodes_flattened_error():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.list_nodes(
+ cloud_tpu.ListNodesRequest(), parent="parent_value",
+ )
+
+
+@pytest.mark.asyncio
+async def test_list_nodes_flattened_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.list_nodes), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.ListNodesResponse()
+
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.ListNodesResponse()
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.list_nodes(parent="parent_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_nodes_flattened_error_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.list_nodes(
+ cloud_tpu.ListNodesRequest(), parent="parent_value",
+ )
+
+
+def test_list_nodes_pager():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials,)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.list_nodes), "__call__") as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ cloud_tpu.ListNodesResponse(
+ nodes=[cloud_tpu.Node(), cloud_tpu.Node(), cloud_tpu.Node(),],
+ next_page_token="abc",
+ ),
+ cloud_tpu.ListNodesResponse(nodes=[], next_page_token="def",),
+ cloud_tpu.ListNodesResponse(
+ nodes=[cloud_tpu.Node(),], next_page_token="ghi",
+ ),
+ cloud_tpu.ListNodesResponse(nodes=[cloud_tpu.Node(), cloud_tpu.Node(),],),
+ RuntimeError,
+ )
+
+ metadata = ()
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+ )
+ pager = client.list_nodes(request={})
+
+ assert pager._metadata == metadata
+
+ results = [i for i in pager]
+ assert len(results) == 6
+ assert all(isinstance(i, cloud_tpu.Node) for i in results)
+
+
+def test_list_nodes_pages():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials,)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.list_nodes), "__call__") as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ cloud_tpu.ListNodesResponse(
+ nodes=[cloud_tpu.Node(), cloud_tpu.Node(), cloud_tpu.Node(),],
+ next_page_token="abc",
+ ),
+ cloud_tpu.ListNodesResponse(nodes=[], next_page_token="def",),
+ cloud_tpu.ListNodesResponse(
+ nodes=[cloud_tpu.Node(),], next_page_token="ghi",
+ ),
+ cloud_tpu.ListNodesResponse(nodes=[cloud_tpu.Node(), cloud_tpu.Node(),],),
+ RuntimeError,
+ )
+ pages = list(client.list_nodes(request={}).pages)
+ for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+ assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_nodes_async_pager():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials,)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_nodes), "__call__", new_callable=mock.AsyncMock
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ cloud_tpu.ListNodesResponse(
+ nodes=[cloud_tpu.Node(), cloud_tpu.Node(), cloud_tpu.Node(),],
+ next_page_token="abc",
+ ),
+ cloud_tpu.ListNodesResponse(nodes=[], next_page_token="def",),
+ cloud_tpu.ListNodesResponse(
+ nodes=[cloud_tpu.Node(),], next_page_token="ghi",
+ ),
+ cloud_tpu.ListNodesResponse(nodes=[cloud_tpu.Node(), cloud_tpu.Node(),],),
+ RuntimeError,
+ )
+ async_pager = await client.list_nodes(request={},)
+ assert async_pager.next_page_token == "abc"
+ responses = []
+ async for response in async_pager:
+ responses.append(response)
+
+ assert len(responses) == 6
+ assert all(isinstance(i, cloud_tpu.Node) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_nodes_async_pages():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials,)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_nodes), "__call__", new_callable=mock.AsyncMock
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ cloud_tpu.ListNodesResponse(
+ nodes=[cloud_tpu.Node(), cloud_tpu.Node(), cloud_tpu.Node(),],
+ next_page_token="abc",
+ ),
+ cloud_tpu.ListNodesResponse(nodes=[], next_page_token="def",),
+ cloud_tpu.ListNodesResponse(
+ nodes=[cloud_tpu.Node(),], next_page_token="ghi",
+ ),
+ cloud_tpu.ListNodesResponse(nodes=[cloud_tpu.Node(), cloud_tpu.Node(),],),
+ RuntimeError,
+ )
+ pages = []
+ async for page_ in (await client.list_nodes(request={})).pages:
+ pages.append(page_)
+ for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+ assert page_.raw_page.next_page_token == token
+
+
+def test_get_node(transport: str = "grpc", request_type=cloud_tpu.GetNodeRequest):
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.get_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.Node(
+ name="name_value",
+ description="description_value",
+ accelerator_type="accelerator_type_value",
+ state=cloud_tpu.Node.State.CREATING,
+ health_description="health_description_value",
+ runtime_version="runtime_version_value",
+ cidr_block="cidr_block_value",
+ health=cloud_tpu.Node.Health.HEALTHY,
+ tags=["tags_value"],
+ id=205,
+ api_version=cloud_tpu.Node.ApiVersion.V1_ALPHA1,
+ )
+ response = client.get_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.GetNodeRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, cloud_tpu.Node)
+ assert response.name == "name_value"
+ assert response.description == "description_value"
+ assert response.accelerator_type == "accelerator_type_value"
+ assert response.state == cloud_tpu.Node.State.CREATING
+ assert response.health_description == "health_description_value"
+ assert response.runtime_version == "runtime_version_value"
+ assert response.cidr_block == "cidr_block_value"
+ assert response.health == cloud_tpu.Node.Health.HEALTHY
+ assert response.tags == ["tags_value"]
+ assert response.id == 205
+ assert response.api_version == cloud_tpu.Node.ApiVersion.V1_ALPHA1
+
+
+def test_get_node_from_dict():
+ test_get_node(request_type=dict)
+
+
+def test_get_node_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.get_node), "__call__") as call:
+ client.get_node()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.GetNodeRequest()
+
+
+@pytest.mark.asyncio
+async def test_get_node_async(
+ transport: str = "grpc_asyncio", request_type=cloud_tpu.GetNodeRequest
+):
+ client = TpuAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.get_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.Node(
+ name="name_value",
+ description="description_value",
+ accelerator_type="accelerator_type_value",
+ state=cloud_tpu.Node.State.CREATING,
+ health_description="health_description_value",
+ runtime_version="runtime_version_value",
+ cidr_block="cidr_block_value",
+ health=cloud_tpu.Node.Health.HEALTHY,
+ tags=["tags_value"],
+ id=205,
+ api_version=cloud_tpu.Node.ApiVersion.V1_ALPHA1,
+ )
+ )
+ response = await client.get_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.GetNodeRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, cloud_tpu.Node)
+ assert response.name == "name_value"
+ assert response.description == "description_value"
+ assert response.accelerator_type == "accelerator_type_value"
+ assert response.state == cloud_tpu.Node.State.CREATING
+ assert response.health_description == "health_description_value"
+ assert response.runtime_version == "runtime_version_value"
+ assert response.cidr_block == "cidr_block_value"
+ assert response.health == cloud_tpu.Node.Health.HEALTHY
+ assert response.tags == ["tags_value"]
+ assert response.id == 205
+ assert response.api_version == cloud_tpu.Node.ApiVersion.V1_ALPHA1
+
+
+@pytest.mark.asyncio
+async def test_get_node_async_from_dict():
+ await test_get_node_async(request_type=dict)
+
+
+def test_get_node_field_headers():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.GetNodeRequest()
+
+ request.name = "name/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.get_node), "__call__") as call:
+ call.return_value = cloud_tpu.Node()
+ client.get_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_get_node_field_headers_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.GetNodeRequest()
+
+ request.name = "name/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.get_node), "__call__") as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(cloud_tpu.Node())
+ await client.get_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_get_node_flattened():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.get_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.Node()
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.get_node(name="name_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0].name == "name_value"
+
+
+def test_get_node_flattened_error():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.get_node(
+ cloud_tpu.GetNodeRequest(), name="name_value",
+ )
+
+
+@pytest.mark.asyncio
+async def test_get_node_flattened_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.get_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.Node()
+
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(cloud_tpu.Node())
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.get_node(name="name_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_get_node_flattened_error_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.get_node(
+ cloud_tpu.GetNodeRequest(), name="name_value",
+ )
+
+
+def test_create_node(transport: str = "grpc", request_type=cloud_tpu.CreateNodeRequest):
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.create_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name="operations/spam")
+ response = client.create_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.CreateNodeRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, future.Future)
+
+
+def test_create_node_from_dict():
+ test_create_node(request_type=dict)
+
+
+def test_create_node_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.create_node), "__call__") as call:
+ client.create_node()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.CreateNodeRequest()
+
+
+@pytest.mark.asyncio
+async def test_create_node_async(
+ transport: str = "grpc_asyncio", request_type=cloud_tpu.CreateNodeRequest
+):
+ client = TpuAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.create_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name="operations/spam")
+ )
+ response = await client.create_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.CreateNodeRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, future.Future)
+
+
+@pytest.mark.asyncio
+async def test_create_node_async_from_dict():
+ await test_create_node_async(request_type=dict)
+
+
+def test_create_node_field_headers():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.CreateNodeRequest()
+
+ request.parent = "parent/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.create_node), "__call__") as call:
+ call.return_value = operations_pb2.Operation(name="operations/op")
+ client.create_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_create_node_field_headers_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.CreateNodeRequest()
+
+ request.parent = "parent/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.create_node), "__call__") as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name="operations/op")
+ )
+ await client.create_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_create_node_flattened():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.create_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name="operations/op")
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.create_node(
+ parent="parent_value",
+ node=cloud_tpu.Node(name="name_value"),
+ node_id="node_id_value",
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0].parent == "parent_value"
+ assert args[0].node == cloud_tpu.Node(name="name_value")
+ assert args[0].node_id == "node_id_value"
+
+
+def test_create_node_flattened_error():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.create_node(
+ cloud_tpu.CreateNodeRequest(),
+ parent="parent_value",
+ node=cloud_tpu.Node(name="name_value"),
+ node_id="node_id_value",
+ )
+
+
+@pytest.mark.asyncio
+async def test_create_node_flattened_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.create_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name="operations/op")
+
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name="operations/spam")
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.create_node(
+ parent="parent_value",
+ node=cloud_tpu.Node(name="name_value"),
+ node_id="node_id_value",
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0].parent == "parent_value"
+ assert args[0].node == cloud_tpu.Node(name="name_value")
+ assert args[0].node_id == "node_id_value"
+
+
+@pytest.mark.asyncio
+async def test_create_node_flattened_error_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.create_node(
+ cloud_tpu.CreateNodeRequest(),
+ parent="parent_value",
+ node=cloud_tpu.Node(name="name_value"),
+ node_id="node_id_value",
+ )
+
+
+def test_delete_node(transport: str = "grpc", request_type=cloud_tpu.DeleteNodeRequest):
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.delete_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name="operations/spam")
+ response = client.delete_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.DeleteNodeRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, future.Future)
+
+
+def test_delete_node_from_dict():
+ test_delete_node(request_type=dict)
+
+
+def test_delete_node_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.delete_node), "__call__") as call:
+ client.delete_node()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.DeleteNodeRequest()
+
+
+@pytest.mark.asyncio
+async def test_delete_node_async(
+ transport: str = "grpc_asyncio", request_type=cloud_tpu.DeleteNodeRequest
+):
+ client = TpuAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.delete_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name="operations/spam")
+ )
+ response = await client.delete_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.DeleteNodeRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, future.Future)
+
+
+@pytest.mark.asyncio
+async def test_delete_node_async_from_dict():
+ await test_delete_node_async(request_type=dict)
+
+
+def test_delete_node_field_headers():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.DeleteNodeRequest()
+
+ request.name = "name/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.delete_node), "__call__") as call:
+ call.return_value = operations_pb2.Operation(name="operations/op")
+ client.delete_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_delete_node_field_headers_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.DeleteNodeRequest()
+
+ request.name = "name/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.delete_node), "__call__") as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name="operations/op")
+ )
+ await client.delete_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_delete_node_flattened():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.delete_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name="operations/op")
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.delete_node(name="name_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0].name == "name_value"
+
+
+def test_delete_node_flattened_error():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.delete_node(
+ cloud_tpu.DeleteNodeRequest(), name="name_value",
+ )
+
+
+@pytest.mark.asyncio
+async def test_delete_node_flattened_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.delete_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name="operations/op")
+
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name="operations/spam")
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.delete_node(name="name_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_delete_node_flattened_error_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.delete_node(
+ cloud_tpu.DeleteNodeRequest(), name="name_value",
+ )
+
+
+def test_stop_node(transport: str = "grpc", request_type=cloud_tpu.StopNodeRequest):
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.stop_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name="operations/spam")
+ response = client.stop_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.StopNodeRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, future.Future)
+
+
+def test_stop_node_from_dict():
+ test_stop_node(request_type=dict)
+
+
+def test_stop_node_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.stop_node), "__call__") as call:
+ client.stop_node()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.StopNodeRequest()
+
+
+@pytest.mark.asyncio
+async def test_stop_node_async(
+ transport: str = "grpc_asyncio", request_type=cloud_tpu.StopNodeRequest
+):
+ client = TpuAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.stop_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name="operations/spam")
+ )
+ response = await client.stop_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.StopNodeRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, future.Future)
+
+
+@pytest.mark.asyncio
+async def test_stop_node_async_from_dict():
+ await test_stop_node_async(request_type=dict)
+
+
+def test_stop_node_field_headers():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.StopNodeRequest()
+
+ request.name = "name/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.stop_node), "__call__") as call:
+ call.return_value = operations_pb2.Operation(name="operations/op")
+ client.stop_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_stop_node_field_headers_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.StopNodeRequest()
+
+ request.name = "name/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.stop_node), "__call__") as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name="operations/op")
+ )
+ await client.stop_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_start_node(transport: str = "grpc", request_type=cloud_tpu.StartNodeRequest):
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.start_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name="operations/spam")
+ response = client.start_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.StartNodeRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, future.Future)
+
+
+def test_start_node_from_dict():
+ test_start_node(request_type=dict)
+
+
+def test_start_node_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.start_node), "__call__") as call:
+ client.start_node()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.StartNodeRequest()
+
+
+@pytest.mark.asyncio
+async def test_start_node_async(
+ transport: str = "grpc_asyncio", request_type=cloud_tpu.StartNodeRequest
+):
+ client = TpuAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.start_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name="operations/spam")
+ )
+ response = await client.start_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.StartNodeRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, future.Future)
+
+
+@pytest.mark.asyncio
+async def test_start_node_async_from_dict():
+ await test_start_node_async(request_type=dict)
+
+
+def test_start_node_field_headers():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.StartNodeRequest()
+
+ request.name = "name/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.start_node), "__call__") as call:
+ call.return_value = operations_pb2.Operation(name="operations/op")
+ client.start_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_start_node_field_headers_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.StartNodeRequest()
+
+ request.name = "name/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.start_node), "__call__") as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name="operations/op")
+ )
+ await client.start_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_update_node(transport: str = "grpc", request_type=cloud_tpu.UpdateNodeRequest):
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.update_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name="operations/spam")
+ response = client.update_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.UpdateNodeRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, future.Future)
+
+
+def test_update_node_from_dict():
+ test_update_node(request_type=dict)
+
+
+def test_update_node_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.update_node), "__call__") as call:
+ client.update_node()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.UpdateNodeRequest()
+
+
+@pytest.mark.asyncio
+async def test_update_node_async(
+ transport: str = "grpc_asyncio", request_type=cloud_tpu.UpdateNodeRequest
+):
+ client = TpuAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.update_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name="operations/spam")
+ )
+ response = await client.update_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.UpdateNodeRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, future.Future)
+
+
+@pytest.mark.asyncio
+async def test_update_node_async_from_dict():
+ await test_update_node_async(request_type=dict)
+
+
+def test_update_node_field_headers():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.UpdateNodeRequest()
+
+ request.node.name = "node.name/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.update_node), "__call__") as call:
+ call.return_value = operations_pb2.Operation(name="operations/op")
+ client.update_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "node.name=node.name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_update_node_field_headers_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.UpdateNodeRequest()
+
+ request.node.name = "node.name/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.update_node), "__call__") as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name="operations/op")
+ )
+ await client.update_node(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "node.name=node.name/value",) in kw["metadata"]
+
+
+def test_update_node_flattened():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.update_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name="operations/op")
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.update_node(
+ node=cloud_tpu.Node(name="name_value"),
+ update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]),
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0].node == cloud_tpu.Node(name="name_value")
+ assert args[0].update_mask == field_mask_pb2.FieldMask(paths=["paths_value"])
+
+
+def test_update_node_flattened_error():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.update_node(
+ cloud_tpu.UpdateNodeRequest(),
+ node=cloud_tpu.Node(name="name_value"),
+ update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]),
+ )
+
+
+@pytest.mark.asyncio
+async def test_update_node_flattened_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.update_node), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name="operations/op")
+
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name="operations/spam")
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.update_node(
+ node=cloud_tpu.Node(name="name_value"),
+ update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]),
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0].node == cloud_tpu.Node(name="name_value")
+ assert args[0].update_mask == field_mask_pb2.FieldMask(paths=["paths_value"])
+
+
+@pytest.mark.asyncio
+async def test_update_node_flattened_error_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.update_node(
+ cloud_tpu.UpdateNodeRequest(),
+ node=cloud_tpu.Node(name="name_value"),
+ update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]),
+ )
+
+
+def test_generate_service_identity(
+ transport: str = "grpc", request_type=cloud_tpu.GenerateServiceIdentityRequest
+):
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.generate_service_identity), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.GenerateServiceIdentityResponse()
+ response = client.generate_service_identity(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.GenerateServiceIdentityRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, cloud_tpu.GenerateServiceIdentityResponse)
+
+
+def test_generate_service_identity_from_dict():
+ test_generate_service_identity(request_type=dict)
+
+
+def test_generate_service_identity_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.generate_service_identity), "__call__"
+ ) as call:
+ client.generate_service_identity()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.GenerateServiceIdentityRequest()
+
+
+@pytest.mark.asyncio
+async def test_generate_service_identity_async(
+ transport: str = "grpc_asyncio",
+ request_type=cloud_tpu.GenerateServiceIdentityRequest,
+):
+ client = TpuAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.generate_service_identity), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.GenerateServiceIdentityResponse()
+ )
+ response = await client.generate_service_identity(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.GenerateServiceIdentityRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, cloud_tpu.GenerateServiceIdentityResponse)
+
+
+@pytest.mark.asyncio
+async def test_generate_service_identity_async_from_dict():
+ await test_generate_service_identity_async(request_type=dict)
+
+
+def test_generate_service_identity_field_headers():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.GenerateServiceIdentityRequest()
+
+ request.parent = "parent/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.generate_service_identity), "__call__"
+ ) as call:
+ call.return_value = cloud_tpu.GenerateServiceIdentityResponse()
+ client.generate_service_identity(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_generate_service_identity_field_headers_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.GenerateServiceIdentityRequest()
+
+ request.parent = "parent/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.generate_service_identity), "__call__"
+ ) as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.GenerateServiceIdentityResponse()
+ )
+ await client.generate_service_identity(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_accelerator_types(
+ transport: str = "grpc", request_type=cloud_tpu.ListAcceleratorTypesRequest
+):
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_accelerator_types), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.ListAcceleratorTypesResponse(
+ next_page_token="next_page_token_value", unreachable=["unreachable_value"],
+ )
+ response = client.list_accelerator_types(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.ListAcceleratorTypesRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, pagers.ListAcceleratorTypesPager)
+ assert response.next_page_token == "next_page_token_value"
+ assert response.unreachable == ["unreachable_value"]
+
+
+def test_list_accelerator_types_from_dict():
+ test_list_accelerator_types(request_type=dict)
+
+
+def test_list_accelerator_types_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_accelerator_types), "__call__"
+ ) as call:
+ client.list_accelerator_types()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.ListAcceleratorTypesRequest()
+
+
+@pytest.mark.asyncio
+async def test_list_accelerator_types_async(
+ transport: str = "grpc_asyncio", request_type=cloud_tpu.ListAcceleratorTypesRequest
+):
+ client = TpuAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_accelerator_types), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.ListAcceleratorTypesResponse(
+ next_page_token="next_page_token_value",
+ unreachable=["unreachable_value"],
+ )
+ )
+ response = await client.list_accelerator_types(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.ListAcceleratorTypesRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, pagers.ListAcceleratorTypesAsyncPager)
+ assert response.next_page_token == "next_page_token_value"
+ assert response.unreachable == ["unreachable_value"]
+
+
+@pytest.mark.asyncio
+async def test_list_accelerator_types_async_from_dict():
+ await test_list_accelerator_types_async(request_type=dict)
+
+
+def test_list_accelerator_types_field_headers():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.ListAcceleratorTypesRequest()
+
+ request.parent = "parent/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_accelerator_types), "__call__"
+ ) as call:
+ call.return_value = cloud_tpu.ListAcceleratorTypesResponse()
+ client.list_accelerator_types(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_list_accelerator_types_field_headers_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.ListAcceleratorTypesRequest()
+
+ request.parent = "parent/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_accelerator_types), "__call__"
+ ) as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.ListAcceleratorTypesResponse()
+ )
+ await client.list_accelerator_types(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_accelerator_types_flattened():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_accelerator_types), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.ListAcceleratorTypesResponse()
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.list_accelerator_types(parent="parent_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0].parent == "parent_value"
+
+
+def test_list_accelerator_types_flattened_error():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.list_accelerator_types(
+ cloud_tpu.ListAcceleratorTypesRequest(), parent="parent_value",
+ )
+
+
+@pytest.mark.asyncio
+async def test_list_accelerator_types_flattened_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_accelerator_types), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.ListAcceleratorTypesResponse()
+
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.ListAcceleratorTypesResponse()
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.list_accelerator_types(parent="parent_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_accelerator_types_flattened_error_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.list_accelerator_types(
+ cloud_tpu.ListAcceleratorTypesRequest(), parent="parent_value",
+ )
+
+
+def test_list_accelerator_types_pager():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials,)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_accelerator_types), "__call__"
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ cloud_tpu.ListAcceleratorTypesResponse(
+ accelerator_types=[
+ cloud_tpu.AcceleratorType(),
+ cloud_tpu.AcceleratorType(),
+ cloud_tpu.AcceleratorType(),
+ ],
+ next_page_token="abc",
+ ),
+ cloud_tpu.ListAcceleratorTypesResponse(
+ accelerator_types=[], next_page_token="def",
+ ),
+ cloud_tpu.ListAcceleratorTypesResponse(
+ accelerator_types=[cloud_tpu.AcceleratorType(),], next_page_token="ghi",
+ ),
+ cloud_tpu.ListAcceleratorTypesResponse(
+ accelerator_types=[
+ cloud_tpu.AcceleratorType(),
+ cloud_tpu.AcceleratorType(),
+ ],
+ ),
+ RuntimeError,
+ )
+
+ metadata = ()
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+ )
+ pager = client.list_accelerator_types(request={})
+
+ assert pager._metadata == metadata
+
+ results = [i for i in pager]
+ assert len(results) == 6
+ assert all(isinstance(i, cloud_tpu.AcceleratorType) for i in results)
+
+
+def test_list_accelerator_types_pages():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials,)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_accelerator_types), "__call__"
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ cloud_tpu.ListAcceleratorTypesResponse(
+ accelerator_types=[
+ cloud_tpu.AcceleratorType(),
+ cloud_tpu.AcceleratorType(),
+ cloud_tpu.AcceleratorType(),
+ ],
+ next_page_token="abc",
+ ),
+ cloud_tpu.ListAcceleratorTypesResponse(
+ accelerator_types=[], next_page_token="def",
+ ),
+ cloud_tpu.ListAcceleratorTypesResponse(
+ accelerator_types=[cloud_tpu.AcceleratorType(),], next_page_token="ghi",
+ ),
+ cloud_tpu.ListAcceleratorTypesResponse(
+ accelerator_types=[
+ cloud_tpu.AcceleratorType(),
+ cloud_tpu.AcceleratorType(),
+ ],
+ ),
+ RuntimeError,
+ )
+ pages = list(client.list_accelerator_types(request={}).pages)
+ for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+ assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_accelerator_types_async_pager():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials,)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_accelerator_types),
+ "__call__",
+ new_callable=mock.AsyncMock,
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ cloud_tpu.ListAcceleratorTypesResponse(
+ accelerator_types=[
+ cloud_tpu.AcceleratorType(),
+ cloud_tpu.AcceleratorType(),
+ cloud_tpu.AcceleratorType(),
+ ],
+ next_page_token="abc",
+ ),
+ cloud_tpu.ListAcceleratorTypesResponse(
+ accelerator_types=[], next_page_token="def",
+ ),
+ cloud_tpu.ListAcceleratorTypesResponse(
+ accelerator_types=[cloud_tpu.AcceleratorType(),], next_page_token="ghi",
+ ),
+ cloud_tpu.ListAcceleratorTypesResponse(
+ accelerator_types=[
+ cloud_tpu.AcceleratorType(),
+ cloud_tpu.AcceleratorType(),
+ ],
+ ),
+ RuntimeError,
+ )
+ async_pager = await client.list_accelerator_types(request={},)
+ assert async_pager.next_page_token == "abc"
+ responses = []
+ async for response in async_pager:
+ responses.append(response)
+
+ assert len(responses) == 6
+ assert all(isinstance(i, cloud_tpu.AcceleratorType) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_accelerator_types_async_pages():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials,)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_accelerator_types),
+ "__call__",
+ new_callable=mock.AsyncMock,
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ cloud_tpu.ListAcceleratorTypesResponse(
+ accelerator_types=[
+ cloud_tpu.AcceleratorType(),
+ cloud_tpu.AcceleratorType(),
+ cloud_tpu.AcceleratorType(),
+ ],
+ next_page_token="abc",
+ ),
+ cloud_tpu.ListAcceleratorTypesResponse(
+ accelerator_types=[], next_page_token="def",
+ ),
+ cloud_tpu.ListAcceleratorTypesResponse(
+ accelerator_types=[cloud_tpu.AcceleratorType(),], next_page_token="ghi",
+ ),
+ cloud_tpu.ListAcceleratorTypesResponse(
+ accelerator_types=[
+ cloud_tpu.AcceleratorType(),
+ cloud_tpu.AcceleratorType(),
+ ],
+ ),
+ RuntimeError,
+ )
+ pages = []
+ async for page_ in (await client.list_accelerator_types(request={})).pages:
+ pages.append(page_)
+ for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+ assert page_.raw_page.next_page_token == token
+
+
+def test_get_accelerator_type(
+ transport: str = "grpc", request_type=cloud_tpu.GetAcceleratorTypeRequest
+):
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_accelerator_type), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.AcceleratorType(
+ name="name_value", type_="type__value",
+ )
+ response = client.get_accelerator_type(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.GetAcceleratorTypeRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, cloud_tpu.AcceleratorType)
+ assert response.name == "name_value"
+ assert response.type_ == "type__value"
+
+
+def test_get_accelerator_type_from_dict():
+ test_get_accelerator_type(request_type=dict)
+
+
+def test_get_accelerator_type_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_accelerator_type), "__call__"
+ ) as call:
+ client.get_accelerator_type()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.GetAcceleratorTypeRequest()
+
+
+@pytest.mark.asyncio
+async def test_get_accelerator_type_async(
+ transport: str = "grpc_asyncio", request_type=cloud_tpu.GetAcceleratorTypeRequest
+):
+ client = TpuAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_accelerator_type), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.AcceleratorType(name="name_value", type_="type__value",)
+ )
+ response = await client.get_accelerator_type(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.GetAcceleratorTypeRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, cloud_tpu.AcceleratorType)
+ assert response.name == "name_value"
+ assert response.type_ == "type__value"
+
+
+@pytest.mark.asyncio
+async def test_get_accelerator_type_async_from_dict():
+ await test_get_accelerator_type_async(request_type=dict)
+
+
+def test_get_accelerator_type_field_headers():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.GetAcceleratorTypeRequest()
+
+ request.name = "name/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_accelerator_type), "__call__"
+ ) as call:
+ call.return_value = cloud_tpu.AcceleratorType()
+ client.get_accelerator_type(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_get_accelerator_type_field_headers_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.GetAcceleratorTypeRequest()
+
+ request.name = "name/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_accelerator_type), "__call__"
+ ) as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.AcceleratorType()
+ )
+ await client.get_accelerator_type(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_get_accelerator_type_flattened():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_accelerator_type), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.AcceleratorType()
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.get_accelerator_type(name="name_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0].name == "name_value"
+
+
+def test_get_accelerator_type_flattened_error():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.get_accelerator_type(
+ cloud_tpu.GetAcceleratorTypeRequest(), name="name_value",
+ )
+
+
+@pytest.mark.asyncio
+async def test_get_accelerator_type_flattened_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_accelerator_type), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.AcceleratorType()
+
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.AcceleratorType()
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.get_accelerator_type(name="name_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_get_accelerator_type_flattened_error_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.get_accelerator_type(
+ cloud_tpu.GetAcceleratorTypeRequest(), name="name_value",
+ )
+
+
+def test_list_runtime_versions(
+ transport: str = "grpc", request_type=cloud_tpu.ListRuntimeVersionsRequest
+):
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_runtime_versions), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.ListRuntimeVersionsResponse(
+ next_page_token="next_page_token_value", unreachable=["unreachable_value"],
+ )
+ response = client.list_runtime_versions(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.ListRuntimeVersionsRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, pagers.ListRuntimeVersionsPager)
+ assert response.next_page_token == "next_page_token_value"
+ assert response.unreachable == ["unreachable_value"]
+
+
+def test_list_runtime_versions_from_dict():
+ test_list_runtime_versions(request_type=dict)
+
+
+def test_list_runtime_versions_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_runtime_versions), "__call__"
+ ) as call:
+ client.list_runtime_versions()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.ListRuntimeVersionsRequest()
+
+
+@pytest.mark.asyncio
+async def test_list_runtime_versions_async(
+ transport: str = "grpc_asyncio", request_type=cloud_tpu.ListRuntimeVersionsRequest
+):
+ client = TpuAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_runtime_versions), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.ListRuntimeVersionsResponse(
+ next_page_token="next_page_token_value",
+ unreachable=["unreachable_value"],
+ )
+ )
+ response = await client.list_runtime_versions(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.ListRuntimeVersionsRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, pagers.ListRuntimeVersionsAsyncPager)
+ assert response.next_page_token == "next_page_token_value"
+ assert response.unreachable == ["unreachable_value"]
+
+
+@pytest.mark.asyncio
+async def test_list_runtime_versions_async_from_dict():
+ await test_list_runtime_versions_async(request_type=dict)
+
+
+def test_list_runtime_versions_field_headers():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.ListRuntimeVersionsRequest()
+
+ request.parent = "parent/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_runtime_versions), "__call__"
+ ) as call:
+ call.return_value = cloud_tpu.ListRuntimeVersionsResponse()
+ client.list_runtime_versions(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_list_runtime_versions_field_headers_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.ListRuntimeVersionsRequest()
+
+ request.parent = "parent/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_runtime_versions), "__call__"
+ ) as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.ListRuntimeVersionsResponse()
+ )
+ await client.list_runtime_versions(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_runtime_versions_flattened():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_runtime_versions), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.ListRuntimeVersionsResponse()
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.list_runtime_versions(parent="parent_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0].parent == "parent_value"
+
+
+def test_list_runtime_versions_flattened_error():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.list_runtime_versions(
+ cloud_tpu.ListRuntimeVersionsRequest(), parent="parent_value",
+ )
+
+
+@pytest.mark.asyncio
+async def test_list_runtime_versions_flattened_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_runtime_versions), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.ListRuntimeVersionsResponse()
+
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.ListRuntimeVersionsResponse()
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.list_runtime_versions(parent="parent_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_runtime_versions_flattened_error_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.list_runtime_versions(
+ cloud_tpu.ListRuntimeVersionsRequest(), parent="parent_value",
+ )
+
+
+def test_list_runtime_versions_pager():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials,)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_runtime_versions), "__call__"
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ cloud_tpu.ListRuntimeVersionsResponse(
+ runtime_versions=[
+ cloud_tpu.RuntimeVersion(),
+ cloud_tpu.RuntimeVersion(),
+ cloud_tpu.RuntimeVersion(),
+ ],
+ next_page_token="abc",
+ ),
+ cloud_tpu.ListRuntimeVersionsResponse(
+ runtime_versions=[], next_page_token="def",
+ ),
+ cloud_tpu.ListRuntimeVersionsResponse(
+ runtime_versions=[cloud_tpu.RuntimeVersion(),], next_page_token="ghi",
+ ),
+ cloud_tpu.ListRuntimeVersionsResponse(
+ runtime_versions=[
+ cloud_tpu.RuntimeVersion(),
+ cloud_tpu.RuntimeVersion(),
+ ],
+ ),
+ RuntimeError,
+ )
+
+ metadata = ()
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+ )
+ pager = client.list_runtime_versions(request={})
+
+ assert pager._metadata == metadata
+
+ results = [i for i in pager]
+ assert len(results) == 6
+ assert all(isinstance(i, cloud_tpu.RuntimeVersion) for i in results)
+
+
+def test_list_runtime_versions_pages():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials,)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_runtime_versions), "__call__"
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ cloud_tpu.ListRuntimeVersionsResponse(
+ runtime_versions=[
+ cloud_tpu.RuntimeVersion(),
+ cloud_tpu.RuntimeVersion(),
+ cloud_tpu.RuntimeVersion(),
+ ],
+ next_page_token="abc",
+ ),
+ cloud_tpu.ListRuntimeVersionsResponse(
+ runtime_versions=[], next_page_token="def",
+ ),
+ cloud_tpu.ListRuntimeVersionsResponse(
+ runtime_versions=[cloud_tpu.RuntimeVersion(),], next_page_token="ghi",
+ ),
+ cloud_tpu.ListRuntimeVersionsResponse(
+ runtime_versions=[
+ cloud_tpu.RuntimeVersion(),
+ cloud_tpu.RuntimeVersion(),
+ ],
+ ),
+ RuntimeError,
+ )
+ pages = list(client.list_runtime_versions(request={}).pages)
+ for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+ assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_runtime_versions_async_pager():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials,)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_runtime_versions),
+ "__call__",
+ new_callable=mock.AsyncMock,
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ cloud_tpu.ListRuntimeVersionsResponse(
+ runtime_versions=[
+ cloud_tpu.RuntimeVersion(),
+ cloud_tpu.RuntimeVersion(),
+ cloud_tpu.RuntimeVersion(),
+ ],
+ next_page_token="abc",
+ ),
+ cloud_tpu.ListRuntimeVersionsResponse(
+ runtime_versions=[], next_page_token="def",
+ ),
+ cloud_tpu.ListRuntimeVersionsResponse(
+ runtime_versions=[cloud_tpu.RuntimeVersion(),], next_page_token="ghi",
+ ),
+ cloud_tpu.ListRuntimeVersionsResponse(
+ runtime_versions=[
+ cloud_tpu.RuntimeVersion(),
+ cloud_tpu.RuntimeVersion(),
+ ],
+ ),
+ RuntimeError,
+ )
+ async_pager = await client.list_runtime_versions(request={},)
+ assert async_pager.next_page_token == "abc"
+ responses = []
+ async for response in async_pager:
+ responses.append(response)
+
+ assert len(responses) == 6
+ assert all(isinstance(i, cloud_tpu.RuntimeVersion) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_runtime_versions_async_pages():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials,)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_runtime_versions),
+ "__call__",
+ new_callable=mock.AsyncMock,
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ cloud_tpu.ListRuntimeVersionsResponse(
+ runtime_versions=[
+ cloud_tpu.RuntimeVersion(),
+ cloud_tpu.RuntimeVersion(),
+ cloud_tpu.RuntimeVersion(),
+ ],
+ next_page_token="abc",
+ ),
+ cloud_tpu.ListRuntimeVersionsResponse(
+ runtime_versions=[], next_page_token="def",
+ ),
+ cloud_tpu.ListRuntimeVersionsResponse(
+ runtime_versions=[cloud_tpu.RuntimeVersion(),], next_page_token="ghi",
+ ),
+ cloud_tpu.ListRuntimeVersionsResponse(
+ runtime_versions=[
+ cloud_tpu.RuntimeVersion(),
+ cloud_tpu.RuntimeVersion(),
+ ],
+ ),
+ RuntimeError,
+ )
+ pages = []
+ async for page_ in (await client.list_runtime_versions(request={})).pages:
+ pages.append(page_)
+ for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+ assert page_.raw_page.next_page_token == token
+
+
+def test_get_runtime_version(
+ transport: str = "grpc", request_type=cloud_tpu.GetRuntimeVersionRequest
+):
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_runtime_version), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.RuntimeVersion(
+ name="name_value", version="version_value",
+ )
+ response = client.get_runtime_version(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.GetRuntimeVersionRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, cloud_tpu.RuntimeVersion)
+ assert response.name == "name_value"
+ assert response.version == "version_value"
+
+
+def test_get_runtime_version_from_dict():
+ test_get_runtime_version(request_type=dict)
+
+
+def test_get_runtime_version_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_runtime_version), "__call__"
+ ) as call:
+ client.get_runtime_version()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.GetRuntimeVersionRequest()
+
+
+@pytest.mark.asyncio
+async def test_get_runtime_version_async(
+ transport: str = "grpc_asyncio", request_type=cloud_tpu.GetRuntimeVersionRequest
+):
+ client = TpuAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_runtime_version), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.RuntimeVersion(name="name_value", version="version_value",)
+ )
+ response = await client.get_runtime_version(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.GetRuntimeVersionRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, cloud_tpu.RuntimeVersion)
+ assert response.name == "name_value"
+ assert response.version == "version_value"
+
+
+@pytest.mark.asyncio
+async def test_get_runtime_version_async_from_dict():
+ await test_get_runtime_version_async(request_type=dict)
+
+
+def test_get_runtime_version_field_headers():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.GetRuntimeVersionRequest()
+
+ request.name = "name/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_runtime_version), "__call__"
+ ) as call:
+ call.return_value = cloud_tpu.RuntimeVersion()
+ client.get_runtime_version(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_get_runtime_version_field_headers_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.GetRuntimeVersionRequest()
+
+ request.name = "name/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_runtime_version), "__call__"
+ ) as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.RuntimeVersion()
+ )
+ await client.get_runtime_version(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_get_runtime_version_flattened():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_runtime_version), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.RuntimeVersion()
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.get_runtime_version(name="name_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0].name == "name_value"
+
+
+def test_get_runtime_version_flattened_error():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.get_runtime_version(
+ cloud_tpu.GetRuntimeVersionRequest(), name="name_value",
+ )
+
+
+@pytest.mark.asyncio
+async def test_get_runtime_version_flattened_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_runtime_version), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.RuntimeVersion()
+
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.RuntimeVersion()
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.get_runtime_version(name="name_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_get_runtime_version_flattened_error_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.get_runtime_version(
+ cloud_tpu.GetRuntimeVersionRequest(), name="name_value",
+ )
+
+
+def test_get_guest_attributes(
+ transport: str = "grpc", request_type=cloud_tpu.GetGuestAttributesRequest
+):
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_guest_attributes), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = cloud_tpu.GetGuestAttributesResponse()
+ response = client.get_guest_attributes(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.GetGuestAttributesRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, cloud_tpu.GetGuestAttributesResponse)
+
+
+def test_get_guest_attributes_from_dict():
+ test_get_guest_attributes(request_type=dict)
+
+
+def test_get_guest_attributes_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport="grpc",
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_guest_attributes), "__call__"
+ ) as call:
+ client.get_guest_attributes()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.GetGuestAttributesRequest()
+
+
+@pytest.mark.asyncio
+async def test_get_guest_attributes_async(
+ transport: str = "grpc_asyncio", request_type=cloud_tpu.GetGuestAttributesRequest
+):
+ client = TpuAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_guest_attributes), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.GetGuestAttributesResponse()
+ )
+ response = await client.get_guest_attributes(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == cloud_tpu.GetGuestAttributesRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, cloud_tpu.GetGuestAttributesResponse)
+
+
+@pytest.mark.asyncio
+async def test_get_guest_attributes_async_from_dict():
+ await test_get_guest_attributes_async(request_type=dict)
+
+
+def test_get_guest_attributes_field_headers():
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.GetGuestAttributesRequest()
+
+ request.name = "name/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_guest_attributes), "__call__"
+ ) as call:
+ call.return_value = cloud_tpu.GetGuestAttributesResponse()
+ client.get_guest_attributes(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_get_guest_attributes_field_headers_async():
+ client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = cloud_tpu.GetGuestAttributesRequest()
+
+ request.name = "name/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_guest_attributes), "__call__"
+ ) as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ cloud_tpu.GetGuestAttributesResponse()
+ )
+ await client.get_guest_attributes(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_credentials_transport_error():
+ # It is an error to provide credentials and a transport instance.
+ transport = transports.TpuGrpcTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+ with pytest.raises(ValueError):
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # It is an error to provide a credentials file and a transport instance.
+ transport = transports.TpuGrpcTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+ with pytest.raises(ValueError):
+ client = TpuClient(
+ client_options={"credentials_file": "credentials.json"},
+ transport=transport,
+ )
+
+ # It is an error to provide scopes and a transport instance.
+ transport = transports.TpuGrpcTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+ with pytest.raises(ValueError):
+ client = TpuClient(client_options={"scopes": ["1", "2"]}, transport=transport,)
+
+
+def test_transport_instance():
+ # A client may be instantiated with a custom transport instance.
+ transport = transports.TpuGrpcTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+ client = TpuClient(transport=transport)
+ assert client.transport is transport
+
+
+def test_transport_get_channel():
+ # A client may be instantiated with a custom transport instance.
+ transport = transports.TpuGrpcTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+ channel = transport.grpc_channel
+ assert channel
+
+ transport = transports.TpuGrpcAsyncIOTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+ channel = transport.grpc_channel
+ assert channel
+
+
+@pytest.mark.parametrize(
+ "transport_class",
+ [transports.TpuGrpcTransport, transports.TpuGrpcAsyncIOTransport,],
+)
+def test_transport_adc(transport_class):
+ # Test default credentials are used if not provided.
+ with mock.patch.object(google.auth, "default") as adc:
+ adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+ transport_class()
+ adc.assert_called_once()
+
+
+def test_transport_grpc_default():
+ # A client should use the gRPC transport by default.
+ client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),)
+ assert isinstance(client.transport, transports.TpuGrpcTransport,)
+
+
+def test_tpu_base_transport_error():
+ # Passing both a credentials object and credentials_file should raise an error
+ with pytest.raises(core_exceptions.DuplicateCredentialArgs):
+ transport = transports.TpuTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ credentials_file="credentials.json",
+ )
+
+
+def test_tpu_base_transport():
+ # Instantiate the base transport.
+ with mock.patch(
+ "google.cloud.tpu_v2alpha1.services.tpu.transports.TpuTransport.__init__"
+ ) as Transport:
+ Transport.return_value = None
+ transport = transports.TpuTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Every method on the transport should just blindly
+ # raise NotImplementedError.
+ methods = (
+ "list_nodes",
+ "get_node",
+ "create_node",
+ "delete_node",
+ "stop_node",
+ "start_node",
+ "update_node",
+ "generate_service_identity",
+ "list_accelerator_types",
+ "get_accelerator_type",
+ "list_runtime_versions",
+ "get_runtime_version",
+ "get_guest_attributes",
+ )
+ for method in methods:
+ with pytest.raises(NotImplementedError):
+ getattr(transport, method)(request=object())
+
+ with pytest.raises(NotImplementedError):
+ transport.close()
+
+ # Additionally, the LRO client (a property) should
+ # also raise NotImplementedError
+ with pytest.raises(NotImplementedError):
+ transport.operations_client
+
+
+@requires_google_auth_gte_1_25_0
+def test_tpu_base_transport_with_credentials_file():
+ # Instantiate the base transport with a credentials file
+ with mock.patch.object(
+ google.auth, "load_credentials_from_file", autospec=True
+ ) as load_creds, mock.patch(
+ "google.cloud.tpu_v2alpha1.services.tpu.transports.TpuTransport._prep_wrapped_messages"
+ ) as Transport:
+ Transport.return_value = None
+ load_creds.return_value = (ga_credentials.AnonymousCredentials(), None)
+ transport = transports.TpuTransport(
+ credentials_file="credentials.json", quota_project_id="octopus",
+ )
+ load_creds.assert_called_once_with(
+ "credentials.json",
+ scopes=None,
+ default_scopes=("https://www.googleapis.com/auth/cloud-platform",),
+ quota_project_id="octopus",
+ )
+
+
+@requires_google_auth_lt_1_25_0
+def test_tpu_base_transport_with_credentials_file_old_google_auth():
+ # Instantiate the base transport with a credentials file
+ with mock.patch.object(
+ google.auth, "load_credentials_from_file", autospec=True
+ ) as load_creds, mock.patch(
+ "google.cloud.tpu_v2alpha1.services.tpu.transports.TpuTransport._prep_wrapped_messages"
+ ) as Transport:
+ Transport.return_value = None
+ load_creds.return_value = (ga_credentials.AnonymousCredentials(), None)
+ transport = transports.TpuTransport(
+ credentials_file="credentials.json", quota_project_id="octopus",
+ )
+ load_creds.assert_called_once_with(
+ "credentials.json",
+ scopes=("https://www.googleapis.com/auth/cloud-platform",),
+ quota_project_id="octopus",
+ )
+
+
+def test_tpu_base_transport_with_adc():
+ # Test the default credentials are used if credentials and credentials_file are None.
+ with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch(
+ "google.cloud.tpu_v2alpha1.services.tpu.transports.TpuTransport._prep_wrapped_messages"
+ ) as Transport:
+ Transport.return_value = None
+ adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+ transport = transports.TpuTransport()
+ adc.assert_called_once()
+
+
+@requires_google_auth_gte_1_25_0
+def test_tpu_auth_adc():
+ # If no credentials are provided, we should use ADC credentials.
+ with mock.patch.object(google.auth, "default", autospec=True) as adc:
+ adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+ TpuClient()
+ adc.assert_called_once_with(
+ scopes=None,
+ default_scopes=("https://www.googleapis.com/auth/cloud-platform",),
+ quota_project_id=None,
+ )
+
+
+@requires_google_auth_lt_1_25_0
+def test_tpu_auth_adc_old_google_auth():
+ # If no credentials are provided, we should use ADC credentials.
+ with mock.patch.object(google.auth, "default", autospec=True) as adc:
+ adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+ TpuClient()
+ adc.assert_called_once_with(
+ scopes=("https://www.googleapis.com/auth/cloud-platform",),
+ quota_project_id=None,
+ )
+
+
+@pytest.mark.parametrize(
+ "transport_class",
+ [transports.TpuGrpcTransport, transports.TpuGrpcAsyncIOTransport,],
+)
+@requires_google_auth_gte_1_25_0
+def test_tpu_transport_auth_adc(transport_class):
+ # If credentials and host are not provided, the transport class should use
+ # ADC credentials.
+ with mock.patch.object(google.auth, "default", autospec=True) as adc:
+ adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+ transport_class(quota_project_id="octopus", scopes=["1", "2"])
+ adc.assert_called_once_with(
+ scopes=["1", "2"],
+ default_scopes=("https://www.googleapis.com/auth/cloud-platform",),
+ quota_project_id="octopus",
+ )
+
+
+@pytest.mark.parametrize(
+ "transport_class",
+ [transports.TpuGrpcTransport, transports.TpuGrpcAsyncIOTransport,],
+)
+@requires_google_auth_lt_1_25_0
+def test_tpu_transport_auth_adc_old_google_auth(transport_class):
+ # If credentials and host are not provided, the transport class should use
+ # ADC credentials.
+ with mock.patch.object(google.auth, "default", autospec=True) as adc:
+ adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+ transport_class(quota_project_id="octopus")
+ adc.assert_called_once_with(
+ scopes=("https://www.googleapis.com/auth/cloud-platform",),
+ quota_project_id="octopus",
+ )
+
+
+@pytest.mark.parametrize(
+ "transport_class,grpc_helpers",
+ [
+ (transports.TpuGrpcTransport, grpc_helpers),
+ (transports.TpuGrpcAsyncIOTransport, grpc_helpers_async),
+ ],
+)
+def test_tpu_transport_create_channel(transport_class, grpc_helpers):
+ # If credentials and host are not provided, the transport class should use
+ # ADC credentials.
+ with mock.patch.object(
+ google.auth, "default", autospec=True
+ ) as adc, mock.patch.object(
+ grpc_helpers, "create_channel", autospec=True
+ ) as create_channel:
+ creds = ga_credentials.AnonymousCredentials()
+ adc.return_value = (creds, None)
+ transport_class(quota_project_id="octopus", scopes=["1", "2"])
+
+ create_channel.assert_called_with(
+ "tpu.googleapis.com:443",
+ credentials=creds,
+ credentials_file=None,
+ quota_project_id="octopus",
+ default_scopes=("https://www.googleapis.com/auth/cloud-platform",),
+ scopes=["1", "2"],
+ default_host="tpu.googleapis.com",
+ ssl_credentials=None,
+ options=[
+ ("grpc.max_send_message_length", -1),
+ ("grpc.max_receive_message_length", -1),
+ ],
+ )
+
+
+@pytest.mark.parametrize(
+ "transport_class", [transports.TpuGrpcTransport, transports.TpuGrpcAsyncIOTransport]
+)
+def test_tpu_grpc_transport_client_cert_source_for_mtls(transport_class):
+ cred = ga_credentials.AnonymousCredentials()
+
+ # Check ssl_channel_credentials is used if provided.
+ with mock.patch.object(transport_class, "create_channel") as mock_create_channel:
+ mock_ssl_channel_creds = mock.Mock()
+ transport_class(
+ host="squid.clam.whelk",
+ credentials=cred,
+ ssl_channel_credentials=mock_ssl_channel_creds,
+ )
+ mock_create_channel.assert_called_once_with(
+ "squid.clam.whelk:443",
+ credentials=cred,
+ credentials_file=None,
+ scopes=None,
+ ssl_credentials=mock_ssl_channel_creds,
+ quota_project_id=None,
+ options=[
+ ("grpc.max_send_message_length", -1),
+ ("grpc.max_receive_message_length", -1),
+ ],
+ )
+
+ # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls
+ # is used.
+ with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
+ with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred:
+ transport_class(
+ credentials=cred,
+ client_cert_source_for_mtls=client_cert_source_callback,
+ )
+ expected_cert, expected_key = client_cert_source_callback()
+ mock_ssl_cred.assert_called_once_with(
+ certificate_chain=expected_cert, private_key=expected_key
+ )
+
+
+def test_tpu_host_no_port():
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ client_options=client_options.ClientOptions(api_endpoint="tpu.googleapis.com"),
+ )
+ assert client.transport._host == "tpu.googleapis.com:443"
+
+
+def test_tpu_host_with_port():
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ client_options=client_options.ClientOptions(
+ api_endpoint="tpu.googleapis.com:8000"
+ ),
+ )
+ assert client.transport._host == "tpu.googleapis.com:8000"
+
+
+def test_tpu_grpc_transport_channel():
+ channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials())
+
+ # Check that channel is used if provided.
+ transport = transports.TpuGrpcTransport(host="squid.clam.whelk", channel=channel,)
+ assert transport.grpc_channel == channel
+ assert transport._host == "squid.clam.whelk:443"
+ assert transport._ssl_channel_credentials == None
+
+
+def test_tpu_grpc_asyncio_transport_channel():
+ channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials())
+
+ # Check that channel is used if provided.
+ transport = transports.TpuGrpcAsyncIOTransport(
+ host="squid.clam.whelk", channel=channel,
+ )
+ assert transport.grpc_channel == channel
+ assert transport._host == "squid.clam.whelk:443"
+ assert transport._ssl_channel_credentials == None
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize(
+ "transport_class", [transports.TpuGrpcTransport, transports.TpuGrpcAsyncIOTransport]
+)
+def test_tpu_transport_channel_mtls_with_client_cert_source(transport_class):
+ with mock.patch(
+ "grpc.ssl_channel_credentials", autospec=True
+ ) as grpc_ssl_channel_cred:
+ with mock.patch.object(
+ transport_class, "create_channel"
+ ) as grpc_create_channel:
+ mock_ssl_cred = mock.Mock()
+ grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+ mock_grpc_channel = mock.Mock()
+ grpc_create_channel.return_value = mock_grpc_channel
+
+ cred = ga_credentials.AnonymousCredentials()
+ with pytest.warns(DeprecationWarning):
+ with mock.patch.object(google.auth, "default") as adc:
+ adc.return_value = (cred, None)
+ transport = transport_class(
+ host="squid.clam.whelk",
+ api_mtls_endpoint="mtls.squid.clam.whelk",
+ client_cert_source=client_cert_source_callback,
+ )
+ adc.assert_called_once()
+
+ grpc_ssl_channel_cred.assert_called_once_with(
+ certificate_chain=b"cert bytes", private_key=b"key bytes"
+ )
+ grpc_create_channel.assert_called_once_with(
+ "mtls.squid.clam.whelk:443",
+ credentials=cred,
+ credentials_file=None,
+ scopes=None,
+ ssl_credentials=mock_ssl_cred,
+ quota_project_id=None,
+ options=[
+ ("grpc.max_send_message_length", -1),
+ ("grpc.max_receive_message_length", -1),
+ ],
+ )
+ assert transport.grpc_channel == mock_grpc_channel
+ assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize(
+ "transport_class", [transports.TpuGrpcTransport, transports.TpuGrpcAsyncIOTransport]
+)
+def test_tpu_transport_channel_mtls_with_adc(transport_class):
+ mock_ssl_cred = mock.Mock()
+ with mock.patch.multiple(
+ "google.auth.transport.grpc.SslCredentials",
+ __init__=mock.Mock(return_value=None),
+ ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+ ):
+ with mock.patch.object(
+ transport_class, "create_channel"
+ ) as grpc_create_channel:
+ mock_grpc_channel = mock.Mock()
+ grpc_create_channel.return_value = mock_grpc_channel
+ mock_cred = mock.Mock()
+
+ with pytest.warns(DeprecationWarning):
+ transport = transport_class(
+ host="squid.clam.whelk",
+ credentials=mock_cred,
+ api_mtls_endpoint="mtls.squid.clam.whelk",
+ client_cert_source=None,
+ )
+
+ grpc_create_channel.assert_called_once_with(
+ "mtls.squid.clam.whelk:443",
+ credentials=mock_cred,
+ credentials_file=None,
+ scopes=None,
+ ssl_credentials=mock_ssl_cred,
+ quota_project_id=None,
+ options=[
+ ("grpc.max_send_message_length", -1),
+ ("grpc.max_receive_message_length", -1),
+ ],
+ )
+ assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_tpu_grpc_lro_client():
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport="grpc",
+ )
+ transport = client.transport
+
+ # Ensure that we have a api-core operations client.
+ assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
+
+ # Ensure that subsequent calls to the property send the exact same object.
+ assert transport.operations_client is transport.operations_client
+
+
+def test_tpu_grpc_lro_async_client():
+ client = TpuAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio",
+ )
+ transport = client.transport
+
+ # Ensure that we have a api-core operations client.
+ assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)
+
+ # Ensure that subsequent calls to the property send the exact same object.
+ assert transport.operations_client is transport.operations_client
+
+
+def test_accelerator_type_path():
+ project = "squid"
+ location = "clam"
+ accelerator_type = "whelk"
+ expected = "projects/{project}/locations/{location}/acceleratorTypes/{accelerator_type}".format(
+ project=project, location=location, accelerator_type=accelerator_type,
+ )
+ actual = TpuClient.accelerator_type_path(project, location, accelerator_type)
+ assert expected == actual
+
+
+def test_parse_accelerator_type_path():
+ expected = {
+ "project": "octopus",
+ "location": "oyster",
+ "accelerator_type": "nudibranch",
+ }
+ path = TpuClient.accelerator_type_path(**expected)
+
+ # Check that the path construction is reversible.
+ actual = TpuClient.parse_accelerator_type_path(path)
+ assert expected == actual
+
+
+def test_node_path():
+ project = "cuttlefish"
+ location = "mussel"
+ node = "winkle"
+ expected = "projects/{project}/locations/{location}/nodes/{node}".format(
+ project=project, location=location, node=node,
+ )
+ actual = TpuClient.node_path(project, location, node)
+ assert expected == actual
+
+
+def test_parse_node_path():
+ expected = {
+ "project": "nautilus",
+ "location": "scallop",
+ "node": "abalone",
+ }
+ path = TpuClient.node_path(**expected)
+
+ # Check that the path construction is reversible.
+ actual = TpuClient.parse_node_path(path)
+ assert expected == actual
+
+
+def test_runtime_version_path():
+ project = "squid"
+ location = "clam"
+ runtime_version = "whelk"
+ expected = "projects/{project}/locations/{location}/runtimeVersions/{runtime_version}".format(
+ project=project, location=location, runtime_version=runtime_version,
+ )
+ actual = TpuClient.runtime_version_path(project, location, runtime_version)
+ assert expected == actual
+
+
+def test_parse_runtime_version_path():
+ expected = {
+ "project": "octopus",
+ "location": "oyster",
+ "runtime_version": "nudibranch",
+ }
+ path = TpuClient.runtime_version_path(**expected)
+
+ # Check that the path construction is reversible.
+ actual = TpuClient.parse_runtime_version_path(path)
+ assert expected == actual
+
+
+def test_common_billing_account_path():
+ billing_account = "cuttlefish"
+ expected = "billingAccounts/{billing_account}".format(
+ billing_account=billing_account,
+ )
+ actual = TpuClient.common_billing_account_path(billing_account)
+ assert expected == actual
+
+
+def test_parse_common_billing_account_path():
+ expected = {
+ "billing_account": "mussel",
+ }
+ path = TpuClient.common_billing_account_path(**expected)
+
+ # Check that the path construction is reversible.
+ actual = TpuClient.parse_common_billing_account_path(path)
+ assert expected == actual
+
+
+def test_common_folder_path():
+ folder = "winkle"
+ expected = "folders/{folder}".format(folder=folder,)
+ actual = TpuClient.common_folder_path(folder)
+ assert expected == actual
+
+
+def test_parse_common_folder_path():
+ expected = {
+ "folder": "nautilus",
+ }
+ path = TpuClient.common_folder_path(**expected)
+
+ # Check that the path construction is reversible.
+ actual = TpuClient.parse_common_folder_path(path)
+ assert expected == actual
+
+
+def test_common_organization_path():
+ organization = "scallop"
+ expected = "organizations/{organization}".format(organization=organization,)
+ actual = TpuClient.common_organization_path(organization)
+ assert expected == actual
+
+
+def test_parse_common_organization_path():
+ expected = {
+ "organization": "abalone",
+ }
+ path = TpuClient.common_organization_path(**expected)
+
+ # Check that the path construction is reversible.
+ actual = TpuClient.parse_common_organization_path(path)
+ assert expected == actual
+
+
+def test_common_project_path():
+ project = "squid"
+ expected = "projects/{project}".format(project=project,)
+ actual = TpuClient.common_project_path(project)
+ assert expected == actual
+
+
+def test_parse_common_project_path():
+ expected = {
+ "project": "clam",
+ }
+ path = TpuClient.common_project_path(**expected)
+
+ # Check that the path construction is reversible.
+ actual = TpuClient.parse_common_project_path(path)
+ assert expected == actual
+
+
+def test_common_location_path():
+ project = "whelk"
+ location = "octopus"
+ expected = "projects/{project}/locations/{location}".format(
+ project=project, location=location,
+ )
+ actual = TpuClient.common_location_path(project, location)
+ assert expected == actual
+
+
+def test_parse_common_location_path():
+ expected = {
+ "project": "oyster",
+ "location": "nudibranch",
+ }
+ path = TpuClient.common_location_path(**expected)
+
+ # Check that the path construction is reversible.
+ actual = TpuClient.parse_common_location_path(path)
+ assert expected == actual
+
+
+def test_client_withDEFAULT_CLIENT_INFO():
+ client_info = gapic_v1.client_info.ClientInfo()
+
+ with mock.patch.object(transports.TpuTransport, "_prep_wrapped_messages") as prep:
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), client_info=client_info,
+ )
+ prep.assert_called_once_with(client_info)
+
+ with mock.patch.object(transports.TpuTransport, "_prep_wrapped_messages") as prep:
+ transport_class = TpuClient.get_transport_class()
+ transport = transport_class(
+ credentials=ga_credentials.AnonymousCredentials(), client_info=client_info,
+ )
+ prep.assert_called_once_with(client_info)
+
+
+@pytest.mark.asyncio
+async def test_transport_close_async():
+ client = TpuAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio",
+ )
+ with mock.patch.object(
+ type(getattr(client.transport, "grpc_channel")), "close"
+ ) as close:
+ async with client:
+ close.assert_not_called()
+ close.assert_called_once()
+
+
+def test_transport_close():
+ transports = {
+ "grpc": "_grpc_channel",
+ }
+
+ for transport, close_name in transports.items():
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport
+ )
+ with mock.patch.object(
+ type(getattr(client.transport, close_name)), "close"
+ ) as close:
+ with client:
+ close.assert_not_called()
+ close.assert_called_once()
+
+
+def test_client_ctx():
+ transports = [
+ "grpc",
+ ]
+ for transport in transports:
+ client = TpuClient(
+ credentials=ga_credentials.AnonymousCredentials(), transport=transport
+ )
+ # Test client calls underlying transport.
+ with mock.patch.object(type(client.transport), "close") as close:
+ close.assert_not_called()
+ with client:
+ pass
+ close.assert_called()
From e8e77451a83a3d85a78bb21043c2fe08a1ab9180 Mon Sep 17 00:00:00 2001
From: "release-please[bot]"
<55107282+release-please[bot]@users.noreply.github.com>
Date: Mon, 18 Oct 2021 19:36:10 +0000
Subject: [PATCH 6/6] chore: release 1.2.0 (#54)
:robot: I have created a release \*beep\* \*boop\*
---
## [1.2.0](https://www.github.com/googleapis/python-tpu/compare/v1.1.0...v1.2.0) (2021-10-15)
### Features
* add support for python 3.10 ([#52](https://www.github.com/googleapis/python-tpu/issues/52)) ([18b9ee0](https://www.github.com/googleapis/python-tpu/commit/18b9ee0cff03b4f97071ef6c7a2bc3e613a01242))
* add TPU v2alpha1 ([#55](https://www.github.com/googleapis/python-tpu/issues/55)) ([72e3e8b](https://www.github.com/googleapis/python-tpu/commit/72e3e8b955690b5f180af89a0a15a8870fd556a8))
---
This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please).
---
CHANGELOG.md | 8 ++++++++
setup.py | 2 +-
2 files changed, 9 insertions(+), 1 deletion(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index cf49633..5edd9cb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,13 @@
# Changelog
+## [1.2.0](https://www.github.com/googleapis/python-tpu/compare/v1.1.0...v1.2.0) (2021-10-15)
+
+
+### Features
+
+* add support for python 3.10 ([#52](https://www.github.com/googleapis/python-tpu/issues/52)) ([18b9ee0](https://www.github.com/googleapis/python-tpu/commit/18b9ee0cff03b4f97071ef6c7a2bc3e613a01242))
+* add TPU v2alpha1 ([#55](https://www.github.com/googleapis/python-tpu/issues/55)) ([72e3e8b](https://www.github.com/googleapis/python-tpu/commit/72e3e8b955690b5f180af89a0a15a8870fd556a8))
+
## [1.1.0](https://www.github.com/googleapis/python-tpu/compare/v1.0.2...v1.1.0) (2021-10-07)
diff --git a/setup.py b/setup.py
index 1276af2..efa88ac 100644
--- a/setup.py
+++ b/setup.py
@@ -22,7 +22,7 @@
name = "google-cloud-tpu"
description = "Cloud TPU API client library"
-version = "1.1.0"
+version = "1.2.0"
release_status = "Development Status :: 5 - Production/Stable"
url = "https://github.com/googleapis/python-tpu"
dependencies = [