diff --git a/.flake8 b/.flake8
index 29227d4cf4..2e43874986 100644
--- a/.flake8
+++ b/.flake8
@@ -16,7 +16,7 @@
# Generated by synthtool. DO NOT EDIT!
[flake8]
-ignore = E203, E266, E501, W503
+ignore = E203, E231, E266, E501, W503
exclude =
# Exclude generated code.
**/proto/**
diff --git a/.github/.OwlBot.lock.yaml b/.github/.OwlBot.lock.yaml
new file mode 100644
index 0000000000..757c9dca75
--- /dev/null
+++ b/.github/.OwlBot.lock.yaml
@@ -0,0 +1,17 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+docker:
+ image: gcr.io/cloud-devrel-public-resources/owlbot-python:latest
+ digest: sha256:81ed5ecdfc7cac5b699ba4537376f3563f6f04122c4ec9e735d3b3dc1d43dd32
+# created: 2022-05-05T22:08:23.383410683Z
diff --git a/.github/.OwlBot.yaml b/.github/.OwlBot.yaml
new file mode 100644
index 0000000000..62483abbee
--- /dev/null
+++ b/.github/.OwlBot.yaml
@@ -0,0 +1,28 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+docker:
+ image: gcr.io/cloud-devrel-public-resources/owlbot-python:latest
+
+deep-remove-regex:
+ - /owl-bot-staging
+
+deep-copy-regex:
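+  # Copy generated GAPIC code (per API version) and schema packages into owl-bot-staging.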
+ - source: /google/cloud/aiplatform/(v.*)/aiplatform-.*-py/(.*)
+ dest: /owl-bot-staging/$1/$2
+ - source: /google/cloud/aiplatform/v*/schema/.*/.*-py/(google/cloud/.*)
+ dest: /owl-bot-staging/$1
+
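+# Only pick up googleapis-gen commits made after this hash.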
+begin-after-commit-hash: 7774246dfb7839067cd64bba0600089b1c91bd85
+
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index f0c14d5c2b..14998d13e5 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -4,18 +4,28 @@
# For syntax help see:
# https://help.github.com/en/github/creating-cloning-and-archiving-repositories/about-code-owners#codeowners-syntax
-# yoshi-python is the default owner
-* @googleapis/yoshi-python
+# @googleapis/cdpe-cloudai and yoshi-python are the default owners
+* @googleapis/cdpe-cloudai @googleapis/yoshi-python
# The AI Platform GAPIC libraries are owned by Cloud AI DPE
-/google/cloud/aiplatform* @googleapis/cdpe-cloudai
+/google/cloud/aiplatform_*/** @googleapis/cdpe-cloudai
-# The Vertex SDK is owned by Model Builder SDK Dev team
-/google/cloud/aiplatform/* @googleapis/cloud-aiplatform-model-builder-sdk
-/tests/unit/aiplatform/* @googleapis/cloud-aiplatform-model-builder-sdk
+# The Vertex SDK is owned by Vertex SDK Dev team
+/google/cloud/aiplatform/** @googleapis/cloud-aiplatform-model-builder-sdk
+/tests/system/aiplatform/** @googleapis/cloud-aiplatform-model-builder-sdk
+/tests/unit/aiplatform/** @googleapis/cloud-aiplatform-model-builder-sdk
# The Cloud AI DPE team is the default owner for samples
-/samples/**/*.py @googleapis/cdpe-cloudai @googleapis/python-samples-owners
+/samples/**/*.py @googleapis/cdpe-cloudai @googleapis/python-samples-reviewers
+/.sample_configs/** @googleapis/cdpe-cloudai
# The enhanced client library tests are owned by Cloud AI DPE
/tests/unit/enhanced_library/*.py @googleapis/cdpe-cloudai
+
+# Core library files owned by Cloud AI DPE and Vertex SDK Dev teams
+CHANGELOG.md @googleapis/cloud-aiplatform-model-builder-sdk @googleapis/cdpe-cloudai
+README.rst @googleapis/cloud-aiplatform-model-builder-sdk @googleapis/cdpe-cloudai
+setup.py @googleapis/cloud-aiplatform-model-builder-sdk @googleapis/cdpe-cloudai
+
+# Vertex AI product team-specific ownership
+/google/cloud/aiplatform/constants/prediction.py @googleapis/vertex-prediction-team
diff --git a/.github/auto-approve.yml b/.github/auto-approve.yml
new file mode 100644
index 0000000000..311ebbb853
--- /dev/null
+++ b/.github/auto-approve.yml
@@ -0,0 +1,3 @@
+# https://github.com/googleapis/repo-automation-bots/tree/main/packages/auto-approve
+processes:
+ - "OwlBotTemplateChanges"
diff --git a/.github/auto-label.yaml b/.github/auto-label.yaml
new file mode 100644
index 0000000000..41bff0b537
--- /dev/null
+++ b/.github/auto-label.yaml
@@ -0,0 +1,15 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+requestsize:
+ enabled: true
diff --git a/.github/release-please.yml b/.github/release-please.yml
index 4507ad0598..6def37a84c 100644
--- a/.github/release-please.yml
+++ b/.github/release-please.yml
@@ -1 +1,8 @@
releaseType: python
+handleGHRelease: true
+# NOTE: this section is generated by synthtool.languages.python
+# See https://github.com/googleapis/synthtool/blob/master/synthtool/languages/python.py
+branches:
+- branch: v0
+ handleGHRelease: true
+ releaseType: python
diff --git a/.github/release-trigger.yml b/.github/release-trigger.yml
new file mode 100644
index 0000000000..d4ca94189e
--- /dev/null
+++ b/.github/release-trigger.yml
@@ -0,0 +1 @@
+enabled: true
diff --git a/.github/sync-repo-settings.yaml b/.github/sync-repo-settings.yaml
index 1e00173609..e0ac340c75 100644
--- a/.github/sync-repo-settings.yaml
+++ b/.github/sync-repo-settings.yaml
@@ -1,11 +1,12 @@
-# https://github.com/googleapis/repo-automation-bots/tree/master/packages/sync-repo-settings
-# Rules for master branch protection
+# https://github.com/googleapis/repo-automation-bots/tree/main/packages/sync-repo-settings
+# Rules for main branch protection
mergeCommitAllowed: true
branchProtectionRules:
# Identifies the protection rule pattern. Name of the branch to be protected.
-# Defaults to `master`
-- pattern: master
+# Defaults to `main`
+- pattern: main
+ requiresCodeOwnerReviews: true
+ requiresStrictStatusChecks: true
requiredStatusCheckContexts:
- - 'Kokoro'
- 'cla/google'
- - 'Samples - Lint'
+ - 'Presubmit - Unit Tests'
diff --git a/.kokoro/build.sh b/.kokoro/build.sh
index 35e4a0f6ce..32e6d625b4 100755
--- a/.kokoro/build.sh
+++ b/.kokoro/build.sh
@@ -41,7 +41,7 @@ python3 -m pip install --upgrade --quiet nox
python3 -m nox --version
# If this is a continuous build, send the test log to the FlakyBot.
-# See https://github.com/googleapis/repo-automation-bots/tree/master/packages/flakybot.
+# See https://github.com/googleapis/repo-automation-bots/tree/main/packages/flakybot.
if [[ $KOKORO_BUILD_ARTIFACTS_SUBDIR = *"continuous"* ]]; then
cleanup() {
chmod +x $KOKORO_GFILE_DIR/linux_amd64/flakybot
diff --git a/.kokoro/continuous/common.cfg b/.kokoro/continuous/common.cfg
index 7e71cb43e9..c8f353660a 100644
--- a/.kokoro/continuous/common.cfg
+++ b/.kokoro/continuous/common.cfg
@@ -25,3 +25,11 @@ env_vars: {
key: "TRAMPOLINE_BUILD_FILE"
value: "github/python-aiplatform/.kokoro/build.sh"
}
+env_vars: {
+ key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
+ value: "ucaip-sample-tests"
+}
+env_vars {
+ key: "_VPC_NETWORK_URI"
+ value: "projects/580378083368/global/networks/system-tests"
+}
diff --git a/.kokoro/continuous/system.cfg b/.kokoro/continuous/system.cfg
new file mode 100644
index 0000000000..1fdefbaa72
--- /dev/null
+++ b/.kokoro/continuous/system.cfg
@@ -0,0 +1,15 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+env_vars: {
+ key: "NOX_SESSION"
+ value: "system-3.8"
+}
+
+# Run system tests in parallel, splitting up by file
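+# (-n=auto starts one pytest-xdist worker per CPU; --dist=loadscope keeps each test module's tests on one worker)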
+env_vars: {
+ key: "PYTEST_ADDOPTS"
+ value: "-n=auto --dist=loadscope"
+}
+
+# Kokoro VM timeout of 5 hours for system tests
+timeout_mins: 300
diff --git a/.kokoro/continuous/unit.cfg b/.kokoro/continuous/unit.cfg
new file mode 100644
index 0000000000..57188a42d1
--- /dev/null
+++ b/.kokoro/continuous/unit.cfg
@@ -0,0 +1,13 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+# Run all unit test sessions in Python 3.7 through 3.9
+env_vars: {
+ key: "NOX_SESSION"
+ value: "unit"
+}
+
+# Run unit tests in parallel, splitting up by module
+env_vars: {
+ key: "PYTEST_ADDOPTS"
+ value: "-n=auto --dist=loadscope"
+}
diff --git a/.kokoro/docker/docs/Dockerfile b/.kokoro/docker/docs/Dockerfile
index 412b0b56a9..238b87b9d1 100644
--- a/.kokoro/docker/docs/Dockerfile
+++ b/.kokoro/docker/docs/Dockerfile
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ubuntu:20.04
+from ubuntu:22.04
ENV DEBIAN_FRONTEND noninteractive
@@ -40,6 +40,7 @@ RUN apt-get update \
libssl-dev \
libsqlite3-dev \
portaudio19-dev \
+ python3-distutils \
redis-server \
software-properties-common \
ssh \
@@ -59,40 +60,24 @@ RUN apt-get update \
&& rm -rf /var/lib/apt/lists/* \
&& rm -f /var/cache/apt/archives/*.deb
+###################### Install python 3.8.11
-COPY fetch_gpg_keys.sh /tmp
-# Install the desired versions of Python.
-RUN set -ex \
- && export GNUPGHOME="$(mktemp -d)" \
- && echo "disable-ipv6" >> "${GNUPGHOME}/dirmngr.conf" \
- && /tmp/fetch_gpg_keys.sh \
- && for PYTHON_VERSION in 3.7.8 3.8.5; do \
- wget --no-check-certificate -O python-${PYTHON_VERSION}.tar.xz "https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz" \
- && wget --no-check-certificate -O python-${PYTHON_VERSION}.tar.xz.asc "https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz.asc" \
- && gpg --batch --verify python-${PYTHON_VERSION}.tar.xz.asc python-${PYTHON_VERSION}.tar.xz \
- && rm -r python-${PYTHON_VERSION}.tar.xz.asc \
- && mkdir -p /usr/src/python-${PYTHON_VERSION} \
- && tar -xJC /usr/src/python-${PYTHON_VERSION} --strip-components=1 -f python-${PYTHON_VERSION}.tar.xz \
- && rm python-${PYTHON_VERSION}.tar.xz \
- && cd /usr/src/python-${PYTHON_VERSION} \
- && ./configure \
- --enable-shared \
- # This works only on Python 2.7 and throws a warning on every other
- # version, but seems otherwise harmless.
- --enable-unicode=ucs4 \
- --with-system-ffi \
- --without-ensurepip \
- && make -j$(nproc) \
- && make install \
- && ldconfig \
- ; done \
- && rm -rf "${GNUPGHOME}" \
- && rm -rf /usr/src/python* \
- && rm -rf ~/.cache/
+# Download python 3.8.11
+RUN wget https://www.python.org/ftp/python/3.8.11/Python-3.8.11.tgz
+# Extract files
+RUN tar -xvf Python-3.8.11.tgz
+
+# Install python 3.8.11
+RUN ./Python-3.8.11/configure --enable-optimizations
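+# altinstall installs /usr/local/bin/python3.8 without overwriting the distro's default python3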
+RUN make altinstall
+
+###################### Install pip
RUN wget -O /tmp/get-pip.py 'https://bootstrap.pypa.io/get-pip.py' \
- && python3.7 /tmp/get-pip.py \
- && python3.8 /tmp/get-pip.py \
+ && python3 /tmp/get-pip.py \
&& rm /tmp/get-pip.py
-CMD ["python3.7"]
+# Test pip
+RUN python3 -m pip
+
+CMD ["python3.8"]
diff --git a/.kokoro/docker/docs/fetch_gpg_keys.sh b/.kokoro/docker/docs/fetch_gpg_keys.sh
deleted file mode 100755
index d653dd868e..0000000000
--- a/.kokoro/docker/docs/fetch_gpg_keys.sh
+++ /dev/null
@@ -1,45 +0,0 @@
-#!/bin/bash
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# A script to fetch gpg keys with retry.
-# Avoid jinja parsing the file.
-#
-
-function retry {
- if [[ "${#}" -le 1 ]]; then
- echo "Usage: ${0} retry_count commands.."
- exit 1
- fi
- local retries=${1}
- local command="${@:2}"
- until [[ "${retries}" -le 0 ]]; do
- $command && return 0
- if [[ $? -ne 0 ]]; then
- echo "command failed, retrying"
- ((retries--))
- fi
- done
- return 1
-}
-
-# 3.6.9, 3.7.5 (Ned Deily)
-retry 3 gpg --keyserver ha.pool.sks-keyservers.net --recv-keys \
- 0D96DF4D4110E5C43FBFB17F2D347EA6AA65421D
-
-# 3.8.0 (Łukasz Langa)
-retry 3 gpg --keyserver ha.pool.sks-keyservers.net --recv-keys \
- E3FF2839C048B25C084DEBE9B26995E310250568
-
-#
diff --git a/.kokoro/docs/common.cfg b/.kokoro/docs/common.cfg
index 5adc161f36..fe691d5317 100644
--- a/.kokoro/docs/common.cfg
+++ b/.kokoro/docs/common.cfg
@@ -30,6 +30,7 @@ env_vars: {
env_vars: {
key: "V2_STAGING_BUCKET"
+ # Push google cloud library docs to the Cloud RAD bucket `docs-staging-v2`
value: "docs-staging-v2"
}
diff --git a/.kokoro/presubmit/presubmit.cfg b/.kokoro/presubmit/presubmit.cfg
index 8f43917d92..66f7c8a934 100644
--- a/.kokoro/presubmit/presubmit.cfg
+++ b/.kokoro/presubmit/presubmit.cfg
@@ -1 +1,13 @@
-# Format: //devtools/kokoro/config/proto/build.proto
\ No newline at end of file
+# Format: //devtools/kokoro/config/proto/build.proto
+
+# Run all sessions except system tests and docs builds
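+# (space-separated nox session names; build.sh presumably passes them through to `nox -s`)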
+env_vars: {
+ key: "NOX_SESSION"
+ value: "unit lint lint_setup_py blacken cover"
+}
+
+# Run unit tests in parallel, splitting up by file
+env_vars: {
+ key: "PYTEST_ADDOPTS"
+ value: "-n=auto --dist=loadscope"
+}
diff --git a/.kokoro/presubmit/release.cfg b/.kokoro/presubmit/release.cfg
new file mode 100644
index 0000000000..fc047df824
--- /dev/null
+++ b/.kokoro/presubmit/release.cfg
@@ -0,0 +1,13 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+# Run system tests in presubmit for library releases
+env_vars: {
+ key: "NOX_SESSION"
+ value: "system-3.8 unit"
+}
+
+# Run system tests in parallel, splitting up by file
+env_vars: {
+ key: "PYTEST_ADDOPTS"
+ value: "-n=auto --dist=loadscope"
+}
diff --git a/.kokoro/presubmit/system.cfg b/.kokoro/presubmit/system.cfg
new file mode 100644
index 0000000000..29bcaf044c
--- /dev/null
+++ b/.kokoro/presubmit/system.cfg
@@ -0,0 +1,13 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+# Run system tests when test files are modified
+env_vars: {
+ key: "NOX_SESSION"
+ value: "system-3.8"
+}
+
+# Run system tests in parallel, splitting up by file
+env_vars: {
+ key: "PYTEST_ADDOPTS"
+ value: "-n=auto --dist=loadscope"
+}
diff --git a/.kokoro/release.sh b/.kokoro/release.sh
index 62bdb892ff..7ca397271d 100755
--- a/.kokoro/release.sh
+++ b/.kokoro/release.sh
@@ -26,7 +26,7 @@ python3 -m pip install --upgrade twine wheel setuptools
export PYTHONUNBUFFERED=1
# Move into the package, build the distribution and upload.
-TWINE_PASSWORD=$(cat "${KOKORO_GFILE_DIR}/secret_manager/google-cloud-pypi-token")
+TWINE_PASSWORD=$(cat "${KOKORO_KEYSTORE_DIR}/73713_google-cloud-pypi-token-keystore-1")
cd github/python-aiplatform
python3 setup.py sdist bdist_wheel
twine upload --username __token__ --password "${TWINE_PASSWORD}" dist/*
diff --git a/.kokoro/release/common.cfg b/.kokoro/release/common.cfg
index 5293e75110..08012edcbd 100644
--- a/.kokoro/release/common.cfg
+++ b/.kokoro/release/common.cfg
@@ -23,8 +23,18 @@ env_vars: {
value: "github/python-aiplatform/.kokoro/release.sh"
}
+# Fetch PyPI password
+before_action {
+ fetch_keystore {
+ keystore_resource {
+ keystore_config_id: 73713
+ keyname: "google-cloud-pypi-token-keystore-1"
+ }
+ }
+}
+
# Tokens needed to report release status back to GitHub
env_vars: {
key: "SECRET_MANAGER_KEYS"
- value: "releasetool-publish-reporter-app,releasetool-publish-reporter-googleapis-installation,releasetool-publish-reporter-pem,google-cloud-pypi-token"
+ value: "releasetool-publish-reporter-app,releasetool-publish-reporter-googleapis-installation,releasetool-publish-reporter-pem"
}
diff --git a/.kokoro/samples/lint/common.cfg b/.kokoro/samples/lint/common.cfg
index 87d5295efb..a239d54498 100644
--- a/.kokoro/samples/lint/common.cfg
+++ b/.kokoro/samples/lint/common.cfg
@@ -31,4 +31,4 @@ gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/python-docs-samples"
gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/trampoline"
# Use the trampoline script to run in docker.
-build_file: "python-aiplatform/.kokoro/trampoline.sh"
\ No newline at end of file
+build_file: "python-aiplatform/.kokoro/trampoline_v2.sh"
\ No newline at end of file
diff --git a/.kokoro/samples/python3.10/common.cfg b/.kokoro/samples/python3.10/common.cfg
new file mode 100644
index 0000000000..a49138fd0a
--- /dev/null
+++ b/.kokoro/samples/python3.10/common.cfg
@@ -0,0 +1,40 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+# Build logs will be here
+action {
+ define_artifacts {
+ regex: "**/*sponge_log.xml"
+ }
+}
+
+# Specify which tests to run
+env_vars: {
+ key: "RUN_TESTS_SESSION"
+ value: "py-3.10"
+}
+
+# Declare build specific Cloud project.
+env_vars: {
+ key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
+ value: "ucaip-sample-tests"
+}
+
+env_vars: {
+ key: "TRAMPOLINE_BUILD_FILE"
+ value: "github/python-aiplatform/.kokoro/test-samples.sh"
+}
+
+# Configure the docker image for kokoro-trampoline.
+env_vars: {
+ key: "TRAMPOLINE_IMAGE"
+ value: "gcr.io/cloud-devrel-kokoro-resources/python-samples-testing-docker"
+}
+
+# Download secrets for samples
+gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/python-docs-samples"
+
+# Download trampoline resources.
+gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/trampoline"
+
+# Use the trampoline script to run in docker.
+build_file: "python-aiplatform/.kokoro/trampoline_v2.sh"
\ No newline at end of file
diff --git a/.kokoro/samples/python3.10/continuous.cfg b/.kokoro/samples/python3.10/continuous.cfg
new file mode 100644
index 0000000000..a1c8d9759c
--- /dev/null
+++ b/.kokoro/samples/python3.10/continuous.cfg
@@ -0,0 +1,6 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+env_vars: {
+ key: "INSTALL_LIBRARY_FROM_SOURCE"
+ value: "True"
+}
\ No newline at end of file
diff --git a/.kokoro/samples/python3.10/periodic-head.cfg b/.kokoro/samples/python3.10/periodic-head.cfg
new file mode 100644
index 0000000000..88d5235e34
--- /dev/null
+++ b/.kokoro/samples/python3.10/periodic-head.cfg
@@ -0,0 +1,11 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+env_vars: {
+ key: "INSTALL_LIBRARY_FROM_SOURCE"
+ value: "True"
+}
+
+env_vars: {
+ key: "TRAMPOLINE_BUILD_FILE"
+ value: "github/python-aiplatform/.kokoro/test-samples-against-head.sh"
+}
diff --git a/.kokoro/samples/python3.6/periodic.cfg b/.kokoro/samples/python3.10/periodic.cfg
similarity index 98%
rename from .kokoro/samples/python3.6/periodic.cfg
rename to .kokoro/samples/python3.10/periodic.cfg
index 50fec96497..71cd1e597e 100644
--- a/.kokoro/samples/python3.6/periodic.cfg
+++ b/.kokoro/samples/python3.10/periodic.cfg
@@ -3,4 +3,4 @@
env_vars: {
key: "INSTALL_LIBRARY_FROM_SOURCE"
value: "False"
-}
\ No newline at end of file
+}
diff --git a/.kokoro/samples/python3.10/presubmit.cfg b/.kokoro/samples/python3.10/presubmit.cfg
new file mode 100644
index 0000000000..a1c8d9759c
--- /dev/null
+++ b/.kokoro/samples/python3.10/presubmit.cfg
@@ -0,0 +1,6 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+env_vars: {
+ key: "INSTALL_LIBRARY_FROM_SOURCE"
+ value: "True"
+}
\ No newline at end of file
diff --git a/.kokoro/samples/python3.6/common.cfg b/.kokoro/samples/python3.6/common.cfg
index 9801281f05..72bfadc9f4 100644
--- a/.kokoro/samples/python3.6/common.cfg
+++ b/.kokoro/samples/python3.6/common.cfg
@@ -7,18 +7,18 @@ action {
}
}
-# Declare build specific Cloud project.
-env_vars: {
- key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
- value: "ucaip-sample-tests"
-}
-
# Specify which tests to run
env_vars: {
key: "RUN_TESTS_SESSION"
value: "py-3.6"
}
+# Declare build specific Cloud project.
+env_vars: {
+ key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
+ value: "ucaip-sample-tests"
+}
+
env_vars: {
key: "TRAMPOLINE_BUILD_FILE"
value: "github/python-aiplatform/.kokoro/test-samples.sh"
@@ -37,4 +37,4 @@ gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/python-docs-samples"
gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/trampoline"
# Use the trampoline script to run in docker.
-build_file: "python-aiplatform/.kokoro/trampoline.sh"
\ No newline at end of file
+build_file: "python-aiplatform/.kokoro/trampoline_v2.sh"
\ No newline at end of file
diff --git a/.kokoro/samples/python3.6/periodic-head.cfg b/.kokoro/samples/python3.6/periodic-head.cfg
new file mode 100644
index 0000000000..88d5235e34
--- /dev/null
+++ b/.kokoro/samples/python3.6/periodic-head.cfg
@@ -0,0 +1,11 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+env_vars: {
+ key: "INSTALL_LIBRARY_FROM_SOURCE"
+ value: "True"
+}
+
+env_vars: {
+ key: "TRAMPOLINE_BUILD_FILE"
+ value: "github/python-aiplatform/.kokoro/test-samples-against-head.sh"
+}
diff --git a/.kokoro/samples/python3.7/common.cfg b/.kokoro/samples/python3.7/common.cfg
index c71cf79c57..cc8296c89d 100644
--- a/.kokoro/samples/python3.7/common.cfg
+++ b/.kokoro/samples/python3.7/common.cfg
@@ -7,18 +7,18 @@ action {
}
}
-# Declare build specific Cloud project.
-env_vars: {
- key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
- value: "ucaip-sample-tests"
-}
-
# Specify which tests to run
env_vars: {
key: "RUN_TESTS_SESSION"
value: "py-3.7"
}
+# Declare build specific Cloud project.
+env_vars: {
+ key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
+ value: "ucaip-sample-tests"
+}
+
env_vars: {
key: "TRAMPOLINE_BUILD_FILE"
value: "github/python-aiplatform/.kokoro/test-samples.sh"
@@ -37,4 +37,4 @@ gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/python-docs-samples"
gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/trampoline"
# Use the trampoline script to run in docker.
-build_file: "python-aiplatform/.kokoro/trampoline.sh"
\ No newline at end of file
+build_file: "python-aiplatform/.kokoro/trampoline_v2.sh"
\ No newline at end of file
diff --git a/.kokoro/samples/python3.7/periodic-head.cfg b/.kokoro/samples/python3.7/periodic-head.cfg
new file mode 100644
index 0000000000..88d5235e34
--- /dev/null
+++ b/.kokoro/samples/python3.7/periodic-head.cfg
@@ -0,0 +1,11 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+env_vars: {
+ key: "INSTALL_LIBRARY_FROM_SOURCE"
+ value: "True"
+}
+
+env_vars: {
+ key: "TRAMPOLINE_BUILD_FILE"
+ value: "github/python-aiplatform/.kokoro/test-samples-against-head.sh"
+}
diff --git a/.kokoro/samples/python3.7/periodic.cfg b/.kokoro/samples/python3.7/periodic.cfg
index 50fec96497..b196817872 100644
--- a/.kokoro/samples/python3.7/periodic.cfg
+++ b/.kokoro/samples/python3.7/periodic.cfg
@@ -2,5 +2,5 @@
env_vars: {
key: "INSTALL_LIBRARY_FROM_SOURCE"
- value: "False"
-}
\ No newline at end of file
+ value: "True"
+}
diff --git a/.kokoro/samples/python3.8/common.cfg b/.kokoro/samples/python3.8/common.cfg
index 21b411c8e1..a118253a82 100644
--- a/.kokoro/samples/python3.8/common.cfg
+++ b/.kokoro/samples/python3.8/common.cfg
@@ -7,22 +7,16 @@ action {
}
}
-# Declare build specific Cloud project.
-env_vars: {
- key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
- value: "ucaip-sample-tests"
-}
-
# Specify which tests to run
env_vars: {
key: "RUN_TESTS_SESSION"
value: "py-3.8"
}
-# Run tests located under tests/system
+# Declare build specific Cloud project.
env_vars: {
- key: "RUN_SYSTEM_TESTS"
- value: "true"
+ key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
+ value: "ucaip-sample-tests"
}
env_vars: {
@@ -43,4 +37,4 @@ gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/python-docs-samples"
gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/trampoline"
# Use the trampoline script to run in docker.
-build_file: "python-aiplatform/.kokoro/trampoline.sh"
\ No newline at end of file
+build_file: "python-aiplatform/.kokoro/trampoline_v2.sh"
\ No newline at end of file
diff --git a/.kokoro/samples/python3.8/periodic-head.cfg b/.kokoro/samples/python3.8/periodic-head.cfg
new file mode 100644
index 0000000000..88d5235e34
--- /dev/null
+++ b/.kokoro/samples/python3.8/periodic-head.cfg
@@ -0,0 +1,11 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+env_vars: {
+ key: "INSTALL_LIBRARY_FROM_SOURCE"
+ value: "True"
+}
+
+env_vars: {
+ key: "TRAMPOLINE_BUILD_FILE"
+ value: "github/python-aiplatform/.kokoro/test-samples-against-head.sh"
+}
diff --git a/.kokoro/samples/python3.8/periodic.cfg b/.kokoro/samples/python3.8/periodic.cfg
index 50fec96497..b196817872 100644
--- a/.kokoro/samples/python3.8/periodic.cfg
+++ b/.kokoro/samples/python3.8/periodic.cfg
@@ -2,5 +2,5 @@
env_vars: {
key: "INSTALL_LIBRARY_FROM_SOURCE"
- value: "False"
-}
\ No newline at end of file
+ value: "True"
+}
diff --git a/.kokoro/samples/python3.9/common.cfg b/.kokoro/samples/python3.9/common.cfg
new file mode 100644
index 0000000000..5a549c80fc
--- /dev/null
+++ b/.kokoro/samples/python3.9/common.cfg
@@ -0,0 +1,40 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+# Build logs will be here
+action {
+ define_artifacts {
+ regex: "**/*sponge_log.xml"
+ }
+}
+
+# Specify which tests to run
+env_vars: {
+ key: "RUN_TESTS_SESSION"
+ value: "py-3.9"
+}
+
+# Declare build specific Cloud project.
+env_vars: {
+ key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
+ value: "ucaip-sample-tests"
+}
+
+env_vars: {
+ key: "TRAMPOLINE_BUILD_FILE"
+ value: "github/python-aiplatform/.kokoro/test-samples.sh"
+}
+
+# Configure the docker image for kokoro-trampoline.
+env_vars: {
+ key: "TRAMPOLINE_IMAGE"
+ value: "gcr.io/cloud-devrel-kokoro-resources/python-samples-testing-docker"
+}
+
+# Download secrets for samples
+gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/python-docs-samples"
+
+# Download trampoline resources.
+gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/trampoline"
+
+# Use the trampoline script to run in docker.
+build_file: "python-aiplatform/.kokoro/trampoline_v2.sh"
\ No newline at end of file
diff --git a/.kokoro/samples/python3.9/continuous.cfg b/.kokoro/samples/python3.9/continuous.cfg
new file mode 100644
index 0000000000..a1c8d9759c
--- /dev/null
+++ b/.kokoro/samples/python3.9/continuous.cfg
@@ -0,0 +1,6 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+env_vars: {
+ key: "INSTALL_LIBRARY_FROM_SOURCE"
+ value: "True"
+}
\ No newline at end of file
diff --git a/.kokoro/samples/python3.9/periodic-head.cfg b/.kokoro/samples/python3.9/periodic-head.cfg
new file mode 100644
index 0000000000..88d5235e34
--- /dev/null
+++ b/.kokoro/samples/python3.9/periodic-head.cfg
@@ -0,0 +1,11 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+env_vars: {
+ key: "INSTALL_LIBRARY_FROM_SOURCE"
+ value: "True"
+}
+
+env_vars: {
+ key: "TRAMPOLINE_BUILD_FILE"
+ value: "github/python-aiplatform/.kokoro/test-samples-against-head.sh"
+}
diff --git a/.kokoro/samples/python3.9/periodic.cfg b/.kokoro/samples/python3.9/periodic.cfg
new file mode 100644
index 0000000000..71cd1e597e
--- /dev/null
+++ b/.kokoro/samples/python3.9/periodic.cfg
@@ -0,0 +1,6 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+env_vars: {
+ key: "INSTALL_LIBRARY_FROM_SOURCE"
+ value: "False"
+}
diff --git a/.kokoro/samples/python3.9/presubmit.cfg b/.kokoro/samples/python3.9/presubmit.cfg
new file mode 100644
index 0000000000..a1c8d9759c
--- /dev/null
+++ b/.kokoro/samples/python3.9/presubmit.cfg
@@ -0,0 +1,6 @@
+# Format: //devtools/kokoro/config/proto/build.proto
+
+env_vars: {
+ key: "INSTALL_LIBRARY_FROM_SOURCE"
+ value: "True"
+}
\ No newline at end of file
diff --git a/.kokoro/test-samples-against-head.sh b/.kokoro/test-samples-against-head.sh
index 8f0597f90d..ba3a707b04 100755
--- a/.kokoro/test-samples-against-head.sh
+++ b/.kokoro/test-samples-against-head.sh
@@ -23,6 +23,4 @@ set -eo pipefail
# Enables `**` to include files nested inside sub-folders
shopt -s globstar
-cd github/python-aiplatform
-
exec .kokoro/test-samples-impl.sh
diff --git a/.kokoro/test-samples-impl.sh b/.kokoro/test-samples-impl.sh
index cf5de74c17..8a324c9c7b 100755
--- a/.kokoro/test-samples-impl.sh
+++ b/.kokoro/test-samples-impl.sh
@@ -20,9 +20,9 @@ set -eo pipefail
# Enables `**` to include files nested inside sub-folders
shopt -s globstar
-# Exit early if samples directory doesn't exist
-if [ ! -d "./samples" ]; then
- echo "No tests run. `./samples` not found"
+# Exit early if samples don't exist
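+# (grep -q . exits 0 only if find printed at least one matching path)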
+if ! find samples -name 'requirements.txt' | grep -q .; then
+ echo "No tests run. './samples/**/requirements.txt' not found"
exit 0
fi
@@ -80,7 +80,7 @@ for file in samples/**/requirements.txt; do
EXIT=$?
# If this is a periodic build, send the test log to the FlakyBot.
- # See https://github.com/googleapis/repo-automation-bots/tree/master/packages/flakybot.
+ # See https://github.com/googleapis/repo-automation-bots/tree/main/packages/flakybot.
if [[ $KOKORO_BUILD_ARTIFACTS_SUBDIR = *"periodic"* ]]; then
chmod +x $KOKORO_GFILE_DIR/linux_amd64/flakybot
$KOKORO_GFILE_DIR/linux_amd64/flakybot
diff --git a/.kokoro/test-samples.sh b/.kokoro/test-samples.sh
index 6bb4d5c30b..11c042d342 100755
--- a/.kokoro/test-samples.sh
+++ b/.kokoro/test-samples.sh
@@ -24,8 +24,6 @@ set -eo pipefail
# Enables `**` to include files nested inside sub-folders
shopt -s globstar
-cd github/python-aiplatform
-
# Run periodic samples tests at latest release
if [[ $KOKORO_BUILD_ARTIFACTS_SUBDIR = *"periodic"* ]]; then
# preserving the test runner implementation.
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4f00c7cffc..46d237160f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -16,13 +16,13 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v3.4.0
+ rev: v4.0.1
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- repo: https://github.com/psf/black
- rev: 19.10b0
+ rev: 22.3.0
hooks:
- id: black
- repo: https://gitlab.com/pycqa/flake8
diff --git a/.repo-metadata.json b/.repo-metadata.json
index 46b1493222..d207a35896 100644
--- a/.repo-metadata.json
+++ b/.repo-metadata.json
@@ -2,12 +2,13 @@
"name": "aiplatform",
"name_pretty": "AI Platform",
"product_documentation": "https://cloud.google.com/ai-platform",
- "client_documentation": "https://googleapis.dev/python/aiplatform/latest",
+ "client_documentation": "https://cloud.google.com/python/docs/reference/aiplatform/latest",
"issue_tracker": "https://issuetracker.google.com/savedsearches/559744",
- "release_level": "ga",
+ "release_level": "stable",
"language": "python",
"library_type": "GAPIC_COMBO",
"repo": "googleapis/python-aiplatform",
"distribution_name": "google-cloud-aiplatform",
- "api_id": "aiplatform.googleapis.com"
-}
\ No newline at end of file
+ "api_id": "aiplatform.googleapis.com",
+ "api_shortname": "aiplatform"
+}
diff --git a/.sample_configs/param_handlers/create_batch_prediction_job_tabular_forecasting_sample.py b/.sample_configs/param_handlers/create_batch_prediction_job_tabular_forecasting_sample.py
index d03f13dff1..fab2aafc8c 100644
--- a/.sample_configs/param_handlers/create_batch_prediction_job_tabular_forecasting_sample.py
+++ b/.sample_configs/param_handlers/create_batch_prediction_job_tabular_forecasting_sample.py
@@ -1,3 +1,18 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
def make_parent(parent: str) -> str:
parent = parent
diff --git a/.sample_configs/process_configs.yaml b/.sample_configs/process_configs.yaml
index 4e6608b4fd..91c8b61c8f 100644
--- a/.sample_configs/process_configs.yaml
+++ b/.sample_configs/process_configs.yaml
@@ -266,7 +266,7 @@ predict_image_classification_sample:
instance_dict: predict.instance.ImageClassificationPredictionInstance
parameters_dict: predict.params.ImageClassificationPredictionParams
comments:
- predictions: See gs://google-cloud-aiplatform/schema/predict/prediction/classification.yaml
+ predictions: See gs://google-cloud-aiplatform/schema/predict/prediction/image_classification_1.0.0.yaml
for the format of the predictions.
predict_image_file_sample:
max_depth: 1
@@ -278,7 +278,7 @@ predict_image_object_detection_sample:
instance_dict: predict.instance.ImageObjectDetectionPredictionInstance
parameters_dict: predict.params.ImageObjectDetectionPredictionParams
comments:
- predictions: See gs://google-cloud-aiplatform/schema/predict/prediction/image_object_detection.yaml
+ predictions: See gs://google-cloud-aiplatform/schema/predict/prediction/image_object_detection_1.0.0.yaml
for the format of the predictions.
predict_sample:
max_depth: 1
@@ -290,30 +290,27 @@ predict_tabular_classification_sample:
max_depth: 1
resource_name: endpoint
comments:
- predictions: See gs://google-cloud-aiplatform/schema/predict/prediction/tables_classification.yaml
+ predictions: See gs://google-cloud-aiplatform/schema/predict/prediction/tabular_classification_1.0.0.yaml
for the format of the predictions.
predict_tabular_forecasting_sample: {}
predict_tabular_regression_sample:
max_depth: 1
resource_name: endpoint
comments:
- predictions: See gs://google-cloud-aiplatform/schema/predict/prediction/tables_regression.yaml
+ predictions: See gs://google-cloud-aiplatform/schema/predict/prediction/tabular_regression_1.0.0.yaml
for the format of the predictions.
predict_text_classification_single_label_sample:
max_depth: 1
resource_name: endpoint
schema_types:
instance_dict: predict.instance.TextClassificationPredictionInstance
- comments:
- predictions: See gs://google-cloud-aiplatform/schema/predict/prediction/text_classification.yaml
- for the format of the predictions.
predict_text_entity_extraction_sample:
max_depth: 1
resource_name: endpoint
schema_types:
instance_dict: predict.instance.TextExtractionPredictionInstance
comments:
- predictions: See gs://google-cloud-aiplatform/schema/predict/prediction/text_extraction.yaml
+ predictions: See gs://google-cloud-aiplatform/schema/predict/prediction/text_extraction_1.0.0.yaml
for the format of the predictions.
predict_text_sentiment_analysis_sample:
max_depth: 1
@@ -321,7 +318,7 @@ predict_text_sentiment_analysis_sample:
schema_types:
instance_dict: predict.instance.TextSentimentPredictionInstance
comments:
- predictions: See gs://google-cloud-aiplatform/schema/predict/prediction/text_sentiment.yaml
+ predictions: See gs://google-cloud-aiplatform/schema/predict/prediction/text_sentiment_1.0.0.yaml
for the format of the predictions.
search_migratable_resources_sample: {}
undeploy_model_sample:
diff --git a/.trampolinerc b/.trampolinerc
index 383b6ec89f..0eee72ab62 100644
--- a/.trampolinerc
+++ b/.trampolinerc
@@ -16,15 +16,26 @@
# Add required env vars here.
required_envvars+=(
- "STAGING_BUCKET"
- "V2_STAGING_BUCKET"
)
# Add env vars which are passed down into the container here.
pass_down_envvars+=(
+ "NOX_SESSION"
+ ###############
+ # Docs builds
+ ###############
"STAGING_BUCKET"
"V2_STAGING_BUCKET"
- "NOX_SESSION"
+ ##################
+ # Samples builds
+ ##################
+ "INSTALL_LIBRARY_FROM_SOURCE"
+ "RUN_TESTS_SESSION"
+ "BUILD_SPECIFIC_GCLOUD_PROJECT"
+ # Target directories.
+ "RUN_TESTS_DIRS"
+ # The nox session to run.
+ "RUN_TESTS_SESSION"
)
# Prevent unintentional override on the default image.
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5fa4409fd9..f4c8f96d1d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,612 @@
# Changelog
-### [1.0.1](https://www.github.com/googleapis/python-aiplatform/compare/v1.0.0...v1.0.1) (2021-05-21)
+## [1.14.0](https://github.com/googleapis/python-aiplatform/compare/v1.13.1...v1.14.0) (2022-06-08)
+
+
+### Features
+
+* add a way to easily clone a PipelineJob ([#1239](https://github.com/googleapis/python-aiplatform/issues/1239)) ([efaf6ed](https://github.com/googleapis/python-aiplatform/commit/efaf6edc36262b095aa13d0b40348c20e39b3fc6))
+* add display_name and metadata to ModelEvaluation in aiplatform model_evaluation.proto ([b6bf6dc](https://github.com/googleapis/python-aiplatform/commit/b6bf6dc643274220e6eeca6479b5f9df61b11d16))
+* add Examples to Explanation related messages in aiplatform v1beta1 explanation.proto ([b6bf6dc](https://github.com/googleapis/python-aiplatform/commit/b6bf6dc643274220e6eeca6479b5f9df61b11d16))
+* Add hierarchy and window configs to Vertex Forecasting training job ([#1255](https://github.com/googleapis/python-aiplatform/issues/1255)) ([8560fa8](https://github.com/googleapis/python-aiplatform/commit/8560fa88c8e0fe51f2ae56f68be575e85db3696a))
+* add holiday regions for vertex forecasting ([#1253](https://github.com/googleapis/python-aiplatform/issues/1253)) ([0036ab0](https://github.com/googleapis/python-aiplatform/commit/0036ab07004e0c9ae7806c4c2c25f22d5af4a978))
+* add IAM policy to aiplatform_v1beta1.yaml ([b6bf6dc](https://github.com/googleapis/python-aiplatform/commit/b6bf6dc643274220e6eeca6479b5f9df61b11d16))
+* add latent_space_source to ExplanationMetadata in aiplatform v1 explanation_metadata.proto ([b6bf6dc](https://github.com/googleapis/python-aiplatform/commit/b6bf6dc643274220e6eeca6479b5f9df61b11d16))
+* add latent_space_source to ExplanationMetadata in aiplatform v1beta1 explanation_metadata.proto ([b6bf6dc](https://github.com/googleapis/python-aiplatform/commit/b6bf6dc643274220e6eeca6479b5f9df61b11d16))
+* add preset configuration for example-based explanations in aiplatform v1beta1 explanation.proto ([b6bf6dc](https://github.com/googleapis/python-aiplatform/commit/b6bf6dc643274220e6eeca6479b5f9df61b11d16))
+* add scaling to OnlineServingConfig in aiplatform v1 featurestore.proto ([b6bf6dc](https://github.com/googleapis/python-aiplatform/commit/b6bf6dc643274220e6eeca6479b5f9df61b11d16))
+* add seq2seq forecasting training job ([#1196](https://github.com/googleapis/python-aiplatform/issues/1196)) ([643d335](https://github.com/googleapis/python-aiplatform/commit/643d335693ec57848949ee173401867a1188678b))
+* add successful_forecast_point_count to CompletionStats in completion_stats.proto ([b6bf6dc](https://github.com/googleapis/python-aiplatform/commit/b6bf6dc643274220e6eeca6479b5f9df61b11d16))
+* add template_metadata to PipelineJob in aiplatform v1 pipeline_job.proto ([b6bf6dc](https://github.com/googleapis/python-aiplatform/commit/b6bf6dc643274220e6eeca6479b5f9df61b11d16))
+* Add Vertex Forecasting E2E test. ([#1248](https://github.com/googleapis/python-aiplatform/issues/1248)) ([e82c179](https://github.com/googleapis/python-aiplatform/commit/e82c1792293396045a1032df015a3700fc38609b))
+* Added forecasting snippets and fixed bugs with existing snippets ([#1210](https://github.com/googleapis/python-aiplatform/issues/1210)) ([4e4bff5](https://github.com/googleapis/python-aiplatform/commit/4e4bff5cac3a99e7f55145ab2aee83b20af67060))
+
+
+### Bug Fixes
+
+* change endpoint update method to return resource ([#1409](https://github.com/googleapis/python-aiplatform/issues/1409)) ([44e279b](https://github.com/googleapis/python-aiplatform/commit/44e279b15a1b03bf234111333517153ffdbaf696))
+* Changed system test to use list_models() correctly ([#1397](https://github.com/googleapis/python-aiplatform/issues/1397)) ([a3da19a](https://github.com/googleapis/python-aiplatform/commit/a3da19aac6bdd3fa8d218408582205f7241a4b04))
+* Pinned protobuf to prevent issues with pb files. ([#1398](https://github.com/googleapis/python-aiplatform/issues/1398)) ([7a54637](https://github.com/googleapis/python-aiplatform/commit/7a54637d9b0e7a52ec4648505a6902610c4cc5b7))
+
+
+### Documentation
+
+* fix changelog header to consistent size ([#1404](https://github.com/googleapis/python-aiplatform/issues/1404)) ([f6a7e6f](https://github.com/googleapis/python-aiplatform/commit/f6a7e6f35188d6032fc8b34a3c205b0632029e02))
+
+## [1.13.1](https://github.com/googleapis/python-aiplatform/compare/v1.13.0...v1.13.1) (2022-05-26)
+
+
+### Features
+
+* add batch_size kwarg for batch prediction jobs ([#1194](https://github.com/googleapis/python-aiplatform/issues/1194)) ([50bdb01](https://github.com/googleapis/python-aiplatform/commit/50bdb01504740ed31de788d8a160f3e2be7f55df))
+* add update endpoint ([#1162](https://github.com/googleapis/python-aiplatform/issues/1162)) ([0ecfe1e](https://github.com/googleapis/python-aiplatform/commit/0ecfe1e7ab8687c13cb4267985e8b6ebc7bd2534))
+* support autoscaling metrics when deploying models ([#1197](https://github.com/googleapis/python-aiplatform/issues/1197)) ([095717c](https://github.com/googleapis/python-aiplatform/commit/095717c8b77dc5d66e677413a437ea6ed92e0b1a))
+
+
+### Bug Fixes
+
+* check in service proto file ([#1174](https://github.com/googleapis/python-aiplatform/issues/1174)) ([5fdf151](https://github.com/googleapis/python-aiplatform/commit/5fdf151ee0d0a630c07a75dc8f19906e7ad1aa8a))
+* regenerate pb2 files using grpcio-tools ([#1394](https://github.com/googleapis/python-aiplatform/issues/1394)) ([406c868](https://github.com/googleapis/python-aiplatform/commit/406c868344280d424f4191c98bcbbdeaf947b2d1))
+
+
+### Documentation
+
+* update aiplatform SDK arrangement for Sphinx ([#1163](https://github.com/googleapis/python-aiplatform/issues/1163)) ([e9510ea](https://github.com/googleapis/python-aiplatform/commit/e9510ea6344a296e0c93ddf32280cf4c010ee4f1))
+
+
+### Miscellaneous Chores
+
+* release 1.13.1 ([#1395](https://github.com/googleapis/python-aiplatform/issues/1395)) ([df78407](https://github.com/googleapis/python-aiplatform/commit/df78407b2f14c95c9e84b4b1375a8de5bc9c7bb5))
+
+## [1.13.0](https://github.com/googleapis/python-aiplatform/compare/v1.12.1...v1.13.0) (2022-05-09)
+
+
+### Features
+
+* add ConvexAutomatedStoppingSpec to StudySpec in aiplatform v1 study.proto ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+* add ConvexAutomatedStoppingSpec to StudySpec in aiplatform v1beta1 study.proto ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+* add JOB_STATE_UPDATING to JobState in aiplatform v1 job_state.proto ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+* add JOB_STATE_UPDATING to JobState in aiplatform v1beta1 job_state.proto ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+* add LatestMonitoringPipelineMetadata to ModelDeploymentMonitoringJob in aiplatform v1beta1 model_deployment_monitoring_job.proto ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+* add ListModelVersion, DeleteModelVersion, and MergeVersionAliases rpcs to aiplatform v1beta1 model_service.proto ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+* add MfsMount in aiplatform v1 machine_resources.proto ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+* add MfsMount in aiplatform v1beta1 machine_resources.proto ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+* add model_id and parent_model to TrainingPipeline in aiplatform v1beta1 training_pipeline.proto ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+* add model_version_id to DeployedModel in aiplatform v1beta1 endpoint.proto ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+* add model_version_id to PredictResponse in aiplatform v1beta1 prediction_service.proto ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+* add model_version_id to UploadModelRequest and UploadModelResponse in aiplatform v1beta1 model_service.proto ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+* add nfs_mounts to WorkPoolSpec in aiplatform v1 custom_job.proto ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+* add nfs_mounts to WorkPoolSpec in aiplatform v1beta1 custom_job.proto ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+* add Pandas DataFrame support to TabularDataset ([#1185](https://github.com/googleapis/python-aiplatform/issues/1185)) ([4fe4558](https://github.com/googleapis/python-aiplatform/commit/4fe4558ea0aaf73e3c0e9715ae90cb729a4c5678))
+* add PredictRequestResponseLoggingConfig to aiplatform v1beta1 endpoint.proto ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+* add reserved_ip_ranges to CustomJobSpec in aiplatform v1 custom_job.proto ([#1165](https://github.com/googleapis/python-aiplatform/issues/1165)) ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+* add reserved_ip_ranges to CustomJobSpec in aiplatform v1beta1 custom_job.proto ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+* add template_metadata to PipelineJob in aiplatform v1beta1 pipeline_job.proto ([#1186](https://github.com/googleapis/python-aiplatform/issues/1186)) ([99aca4a](https://github.com/googleapis/python-aiplatform/commit/99aca4a9b0deeefd294cfd64fa3e247cc41e006c))
+* add version_id to Model in aiplatform v1beta1 model.proto ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+* allow creating featurestore without online node ([#1180](https://github.com/googleapis/python-aiplatform/issues/1180)) ([3224ae3](https://github.com/googleapis/python-aiplatform/commit/3224ae3402e9493866dd4958d011a431968b9c2c))
+* Allow users to specify timestamp split for vertex forecasting ([#1187](https://github.com/googleapis/python-aiplatform/issues/1187)) ([ee49e00](https://github.com/googleapis/python-aiplatform/commit/ee49e004c8fbd0c8c27760b525c6e7431057a45e))
+* Make matching engine API public ([#1192](https://github.com/googleapis/python-aiplatform/issues/1192)) ([469db6b](https://github.com/googleapis/python-aiplatform/commit/469db6b08a9aa7fc64d8ea27f7e2e2fb2e9f643b))
+* rename Similarity to Examples, and similarity to examples in ExplanationParameters in aiplatform v1beta1 explanation.proto ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+
+
+### Documentation
+
+* fix type in docstring for map fields ([847ad78](https://github.com/googleapis/python-aiplatform/commit/847ad789e09aec14238a7476a3fa88729ce24d6f))
+
+## [1.12.1](https://github.com/googleapis/python-aiplatform/compare/v1.12.0...v1.12.1) (2022-04-20)
+
+
+### Features
+
+* Add endpoind_id arg to Endpoint#create ([#1168](https://github.com/googleapis/python-aiplatform/issues/1168)) ([4c21993](https://github.com/googleapis/python-aiplatform/commit/4c21993642b84d7595ead7a63424260deafaf43c))
+* add ModelEvaluation support ([#1167](https://github.com/googleapis/python-aiplatform/issues/1167)) ([10f95cd](https://github.com/googleapis/python-aiplatform/commit/10f95cde5e0282a99041ff2108111970f52379f3))
+
+
+### Bug Fixes
+
+* change default for create_request_timeout arg to None ([#1175](https://github.com/googleapis/python-aiplatform/issues/1175)) ([47791f7](https://github.com/googleapis/python-aiplatform/commit/47791f79c56a67be7503b5d5d4eb72dc409b18a0))
+
+
+### Documentation
+
+* endpoint.create => aiplatform.Endpoint.create ([#1153](https://github.com/googleapis/python-aiplatform/issues/1153)) ([1122a26](https://github.com/googleapis/python-aiplatform/commit/1122a26fd01d4c964055ca85a683de0c91867b6f))
+* update changelog headers ([#1164](https://github.com/googleapis/python-aiplatform/issues/1164)) ([c1e899d](https://github.com/googleapis/python-aiplatform/commit/c1e899dba3f57e515b1f1958e962f355276460c4))
+* update model code snippet order in README ([#1154](https://github.com/googleapis/python-aiplatform/issues/1154)) ([404d7f1](https://github.com/googleapis/python-aiplatform/commit/404d7f13d8666ea673743ab54928046eb64ee542))
+
+
+### Miscellaneous Chores
+
+* release 1.12.1 ([#1176](https://github.com/googleapis/python-aiplatform/issues/1176)) ([f98d92e](https://github.com/googleapis/python-aiplatform/commit/f98d92ecf7ad42fdbb095e65f98800bc6b2d3d12))
+
+## [1.12.0](https://github.com/googleapis/python-aiplatform/compare/v1.11.0...v1.12.0) (2022-04-07)
+
+
+### Features
+
+* add categorical_threshold_config to FeaturestoreMonitoringConfig in aiplatform v1 featurestore_monitoring.proto ([38f3711](https://github.com/googleapis/python-aiplatform/commit/38f3711bd76bbcfe4ce48739bb11049e2711d47f))
+* add categorical_threshold_config to FeaturestoreMonitoringConfig in aiplatform v1beta1 featurestore_monitoring.proto ([38f3711](https://github.com/googleapis/python-aiplatform/commit/38f3711bd76bbcfe4ce48739bb11049e2711d47f))
+* add disable_monitoring to Feature in aiplatform v1 feature.proto ([38f3711](https://github.com/googleapis/python-aiplatform/commit/38f3711bd76bbcfe4ce48739bb11049e2711d47f))
+* add disable_monitoring to Feature in aiplatform v1beta1 feature.proto ([38f3711](https://github.com/googleapis/python-aiplatform/commit/38f3711bd76bbcfe4ce48739bb11049e2711d47f))
+* Add done method for pipeline, training, and batch prediction jobs ([#1062](https://github.com/googleapis/python-aiplatform/issues/1062)) ([f3338fc](https://github.com/googleapis/python-aiplatform/commit/f3338fcd4f51072ee86b765ee580cfe3c4b222ce))
+* add import_features_analysis to FeaturestoreMonitoringConfig in aiplatform v1 featurestore_monitoring.proto ([38f3711](https://github.com/googleapis/python-aiplatform/commit/38f3711bd76bbcfe4ce48739bb11049e2711d47f))
+* add import_features_analysis to FeaturestoreMonitoringConfig in aiplatform v1beta1 featurestore_monitoring.proto ([38f3711](https://github.com/googleapis/python-aiplatform/commit/38f3711bd76bbcfe4ce48739bb11049e2711d47f))
+* add ImportModelEvaluation in aiplatform v1 model_service.proto ([#1105](https://github.com/googleapis/python-aiplatform/issues/1105)) ([ef5930c](https://github.com/googleapis/python-aiplatform/commit/ef5930c58838ce51f92ef1acb941f5141c83faad))
+* add monitoring_config to EntityType in aiplatform v1 entity_type.proto ([#1077](https://github.com/googleapis/python-aiplatform/issues/1077)) ([38f3711](https://github.com/googleapis/python-aiplatform/commit/38f3711bd76bbcfe4ce48739bb11049e2711d47f))
+* add monitoring_stats_anomalies to Feature in aiplatform v1 feature.proto ([38f3711](https://github.com/googleapis/python-aiplatform/commit/38f3711bd76bbcfe4ce48739bb11049e2711d47f))
+* add monitoring_stats_anomalies to Feature in aiplatform v1beta1 feature.proto ([38f3711](https://github.com/googleapis/python-aiplatform/commit/38f3711bd76bbcfe4ce48739bb11049e2711d47f))
+* add numerical_threshold_config to FeaturestoreMonitoringConfig in aiplatform v1 featurestore_monitoring.proto ([38f3711](https://github.com/googleapis/python-aiplatform/commit/38f3711bd76bbcfe4ce48739bb11049e2711d47f))
+* add numerical_threshold_config to FeaturestoreMonitoringConfig in aiplatform v1beta1 featurestore_monitoring.proto ([38f3711](https://github.com/googleapis/python-aiplatform/commit/38f3711bd76bbcfe4ce48739bb11049e2711d47f))
+* add objective to MonitoringStatsSpec in aiplatform v1 featurestore_service.proto ([38f3711](https://github.com/googleapis/python-aiplatform/commit/38f3711bd76bbcfe4ce48739bb11049e2711d47f))
+* add objective to MonitoringStatsSpec in aiplatform v1beta1 featurestore_service.proto ([38f3711](https://github.com/googleapis/python-aiplatform/commit/38f3711bd76bbcfe4ce48739bb11049e2711d47f))
+* add PredictRequestResponseLoggingConfig to Endpoint in aiplatform v1 endpoint.proto ([#1072](https://github.com/googleapis/python-aiplatform/issues/1072)) ([be0ccc4](https://github.com/googleapis/python-aiplatform/commit/be0ccc488dac22128be317ca40337d6b93af0906))
+* add staleness_days to SnapshotAnalysis in aiplatform v1 featurestore_monitoring.proto ([38f3711](https://github.com/googleapis/python-aiplatform/commit/38f3711bd76bbcfe4ce48739bb11049e2711d47f))
+* add staleness_days to SnapshotAnalysis in aiplatform v1beta1 featurestore_monitoring.proto ([38f3711](https://github.com/googleapis/python-aiplatform/commit/38f3711bd76bbcfe4ce48739bb11049e2711d47f))
+* Add support for Vertex Tables Q1 regions ([#1065](https://github.com/googleapis/python-aiplatform/issues/1065)) ([6383d4f](https://github.com/googleapis/python-aiplatform/commit/6383d4f20f1ab0a7634c1028cb9f270e91c31d2a))
+* add timeout arg across SDK ([#1099](https://github.com/googleapis/python-aiplatform/issues/1099)) ([184f7f3](https://github.com/googleapis/python-aiplatform/commit/184f7f327aa00b4c8d1acc24dcb1c4c4be6c5bcc))
+* Add timeout arguments to Endpoint.predict and Endpoint.explain ([#1094](https://github.com/googleapis/python-aiplatform/issues/1094)) ([cc59e60](https://github.com/googleapis/python-aiplatform/commit/cc59e60193a72bb57d699cabea03ab7bdd386b0e))
+* Made display_name parameter optional for most calls ([#882](https://github.com/googleapis/python-aiplatform/issues/882)) ([400b760](https://github.com/googleapis/python-aiplatform/commit/400b7608afeaca9a36936cabd402c5322eb9345b))
+* **sdk:** enable loading both JSON and YAML pipelines IR ([#1089](https://github.com/googleapis/python-aiplatform/issues/1089)) ([f2e70b1](https://github.com/googleapis/python-aiplatform/commit/f2e70b1563171b5a92a2c40edf29ae373bbeb175))
+* **v1beta1:** add `service_account` to `BatchPredictionJob` in `batch_prediction_job.proto` ([#1084](https://github.com/googleapis/python-aiplatform/issues/1084)) ([b7a5177](https://github.com/googleapis/python-aiplatform/commit/b7a517731bc8127d4186838bfb88fa883b2be853))
+
+
+### Bug Fixes
+
+* add resource manager utils to get project ID from project number ([#1068](https://github.com/googleapis/python-aiplatform/issues/1068)) ([f10a1d4](https://github.com/googleapis/python-aiplatform/commit/f10a1d4280c3e653c9f4795f0423bf07a23acdf9))
+* add self.wait() in operations after optional_sync supported resource creation ([#1083](https://github.com/googleapis/python-aiplatform/issues/1083)) ([79aeec1](https://github.com/googleapis/python-aiplatform/commit/79aeec1380068318398851b2a7b2fd6ddee7fa8b))
+* Don't throw exception when getting representation of unrun GCA objects ([#1071](https://github.com/googleapis/python-aiplatform/issues/1071)) ([c9ba060](https://github.com/googleapis/python-aiplatform/commit/c9ba0603e6a8e3d772af67367242aad7a18e03c8))
+* Fix import error string showing wrong pip install path ([#1076](https://github.com/googleapis/python-aiplatform/issues/1076)) ([74ffa19](https://github.com/googleapis/python-aiplatform/commit/74ffa19e7d540f6bb5f21d2335c2a5d23cc49ee2))
+* Fixed getting project ID when running on Vertex AI; Fixes [#852](https://github.com/googleapis/python-aiplatform/issues/852) ([#943](https://github.com/googleapis/python-aiplatform/issues/943)) ([876cb33](https://github.com/googleapis/python-aiplatform/commit/876cb33a407cfea5c965e4f348056b147b1d16c3))
+* Give aiplatform logging its own log namespace, let the user configure their own root logger ([#1081](https://github.com/googleapis/python-aiplatform/issues/1081)) ([fb78243](https://github.com/googleapis/python-aiplatform/commit/fb782434d456f41c6c6bd6664b203cebb53131b8))
+* Honoring the model's supported_deployment_resources_types ([#865](https://github.com/googleapis/python-aiplatform/issues/865)) ([db34b85](https://github.com/googleapis/python-aiplatform/commit/db34b85aaf211ca491313d2b8ae2a45253109614))
+* missing reference to logged_web_access_uris ([#1056](https://github.com/googleapis/python-aiplatform/issues/1056)) ([198a1b5](https://github.com/googleapis/python-aiplatform/commit/198a1b5753f509c9137a8d9e9b56d68e6e386563))
+* system test failure from test_upload_and_deploy_xgboost_model ([#1149](https://github.com/googleapis/python-aiplatform/issues/1149)) ([c8422a9](https://github.com/googleapis/python-aiplatform/commit/c8422a9b807e092f2d48e7f3fa8b40c8724cc028))
+
+
+### Documentation
+
+* fix CustomContainerTrainingJob example in docstring ([#1101](https://github.com/googleapis/python-aiplatform/issues/1101)) ([d2fb9db](https://github.com/googleapis/python-aiplatform/commit/d2fb9db095d1acb15894df3d0a5e66128ce8f14e))
+* improve bigquery_destination_prefix docstring ([#1098](https://github.com/googleapis/python-aiplatform/issues/1098)) ([a46df64](https://github.com/googleapis/python-aiplatform/commit/a46df64ab99aee8d7e47b44394a234243dc2a0f8))
+* Include time dependency in documentation for weight, time, and target columns. ([#1102](https://github.com/googleapis/python-aiplatform/issues/1102)) ([52273c2](https://github.com/googleapis/python-aiplatform/commit/52273c2108c9bb24eadab214036f2ef93b847321))
+* **samples:** read, import, batch_serve, batch_create features ([#1046](https://github.com/googleapis/python-aiplatform/issues/1046)) ([80dd40d](https://github.com/googleapis/python-aiplatform/commit/80dd40dcb830ece3b5442d60834357ada6583204))
+* Update AutoML Video docstring ([#987](https://github.com/googleapis/python-aiplatform/issues/987)) ([6002d5d](https://github.com/googleapis/python-aiplatform/commit/6002d5d9bf24542f9f3f844e469bc3f8ad9636ec))
+
+## [1.11.0](https://github.com/googleapis/python-aiplatform/compare/v1.10.0...v1.11.0) (2022-03-03)
+
+
+### Features
+
+* add additional_experiments flag in the tables and forecasting training job ([#979](https://github.com/googleapis/python-aiplatform/issues/979)) ([5fe59a4](https://github.com/googleapis/python-aiplatform/commit/5fe59a4015882d56c22f9973aff888966dd53a2e))
+* add TPU_V2 & TPU_V3 values to AcceleratorType in aiplatform v1/v1beta1 accelerator_type.proto ([#1010](https://github.com/googleapis/python-aiplatform/issues/1010)) ([09c2e8a](https://github.com/googleapis/python-aiplatform/commit/09c2e8a368c6d265d99acfb12addd5ba6f1a50e6))
+* Added scheduling to CustomTrainingJob, CustomPythonPackageTrainingJob, CustomContainerTrainingJob ([#970](https://github.com/googleapis/python-aiplatform/issues/970)) ([89078e0](https://github.com/googleapis/python-aiplatform/commit/89078e0d2a719e2b0d25ae36ecd06c356a5a33c9)) (see the sketch below)
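+
+A minimal sketch of the new scheduling options (illustrative only; the script path and container image are placeholders):
+
+```python
+from google.cloud import aiplatform
+
+aiplatform.init(project="my-project", location="us-central1")
+
+job = aiplatform.CustomTrainingJob(
+    display_name="scheduled-training",
+    script_path="train.py",  # placeholder training script
+    container_uri="us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-8:latest",  # placeholder image
+)
+
+job.run(
+    replica_count=1,
+    timeout=7200,  # wall-clock limit for the job, in seconds
+    restart_job_on_worker_restart=True,  # retry when a worker VM is restarted
+)
+```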
+
+
+### Bug Fixes
+
+* **deps:** allow google-cloud-storage < 3.0.0dev ([#1008](https://github.com/googleapis/python-aiplatform/issues/1008)) ([1c34154](https://github.com/googleapis/python-aiplatform/commit/1c341544e9bd94c6ff0ee41177565c8c078673a3))
+* **deps:** require google-api-core>=1.31.5, >=2.3.2 ([#1050](https://github.com/googleapis/python-aiplatform/issues/1050)) ([dfbd68a](https://github.com/googleapis/python-aiplatform/commit/dfbd68a79f1c892c4380405dd900deb6ac6574a6))
+* **deps:** require proto-plus>=1.15.0 ([dfbd68a](https://github.com/googleapis/python-aiplatform/commit/dfbd68a79f1c892c4380405dd900deb6ac6574a6))
+* enforce bq SchemaField field_type and mode using feature value_type ([#1019](https://github.com/googleapis/python-aiplatform/issues/1019)) ([095bea2](https://github.com/googleapis/python-aiplatform/commit/095bea23bc15a490ddbb1a8edac7f5db626bc659))
+* Fix create_lit_model_from_endpoint not accepting models that don't return a dictionary. ([#1020](https://github.com/googleapis/python-aiplatform/issues/1020)) ([b9a057d](https://github.com/googleapis/python-aiplatform/commit/b9a057d001deb8727cb725d44bb5528dce330653))
+* loosen assertions for system test featurestore ([#1040](https://github.com/googleapis/python-aiplatform/issues/1040)) ([2ba404f](https://github.com/googleapis/python-aiplatform/commit/2ba404f8bfbccd7a18ef613417912ed94882c4bd))
+* remove empty scripts kwarg in setup.py ([#1014](https://github.com/googleapis/python-aiplatform/issues/1014)) ([ef3fcc8](https://github.com/googleapis/python-aiplatform/commit/ef3fcc86fb3808b37706470c8c49903ec3a302fb))
+* show logs when TFX pipelines are submitted ([#976](https://github.com/googleapis/python-aiplatform/issues/976)) ([c10923b](https://github.com/googleapis/python-aiplatform/commit/c10923b47b9b9941d14ae2c5398348d971a23f9d))
+* update system test_model_upload to use BUILD_SPECIFIC_GCP_PROJECT ([#1043](https://github.com/googleapis/python-aiplatform/issues/1043)) ([e7d2719](https://github.com/googleapis/python-aiplatform/commit/e7d27193f323f88f4238206ecb380d746d98df31))
+
+
+### Documentation
+
+* **samples:** add samples to create/delete featurestore ([#980](https://github.com/googleapis/python-aiplatform/issues/980)) ([5ee6354](https://github.com/googleapis/python-aiplatform/commit/5ee6354a12c6422015acb81caef32d6d2f52c838))
+* **samples:** added create feature and create entity type samples and tests ([#984](https://github.com/googleapis/python-aiplatform/issues/984)) ([d221e6b](https://github.com/googleapis/python-aiplatform/commit/d221e6bebd7fb98a8c6e3f3b8ae507f2f214128f))
+
+## [1.10.0](https://github.com/googleapis/python-aiplatform/compare/v1.9.0...v1.10.0) (2022-02-07)
+
+
+### Features
+
+* _TrainingScriptPythonPackager to support folders ([#812](https://github.com/googleapis/python-aiplatform/issues/812)) ([3aec6a7](https://github.com/googleapis/python-aiplatform/commit/3aec6a7b8f26ef2a5b378a6224d6402e3b42c917))
+* add dedicated_resources to DeployedIndex in aiplatform v1beta1 index_endpoint.proto feat: add Scaling to OnlineServingConfig in aiplatform v1beta1 featurestore.proto chore: sort imports ([#991](https://github.com/googleapis/python-aiplatform/issues/991)) ([7a7f0d4](https://github.com/googleapis/python-aiplatform/commit/7a7f0d45f3d08c93b11fcd2c5a265a8db4b0c890))
+* add dedicated_resources to DeployedIndex message in aiplatform v1 index_endpoint.proto chore: sort imports ([#990](https://github.com/googleapis/python-aiplatform/issues/990)) ([a814923](https://github.com/googleapis/python-aiplatform/commit/a8149233bcd857e75700c6ec7d29c0aabf1687c1))
+* Add XAI SDK integration to TensorFlow models with LIT integration ([#917](https://github.com/googleapis/python-aiplatform/issues/917)) ([ea2b5cf](https://github.com/googleapis/python-aiplatform/commit/ea2b5cfbcafead1c63009fda10bd44a00d560efb))
+* Added `aiplatform.Model.update` method ([#952](https://github.com/googleapis/python-aiplatform/issues/952)) ([44e208a](https://github.com/googleapis/python-aiplatform/commit/44e208a8dbf082e770373d58c31b3ad3e8b39f4f)) (see the sketch after this list)
+* Enable europe-west6 and northamerica-northeast2 regions ([0f6b670](https://github.com/googleapis/python-aiplatform/commit/0f6b6701e96fb0ec345e81560d03059a7900160f))
+* enable feature store batch serve to BigQuery and GCS for csv and tfrecord ([#919](https://github.com/googleapis/python-aiplatform/issues/919)) ([c840728](https://github.com/googleapis/python-aiplatform/commit/c840728e503eea3300e9629405978e28c6aafec7))
+* enable feature store batch serve to Pandas DataFrame; fix: read instances uri for batch serve ([#983](https://github.com/googleapis/python-aiplatform/issues/983)) ([e0fec36](https://github.com/googleapis/python-aiplatform/commit/e0fec36686e373c13acca3203372572c760c7af4))
+* enable feature store online serving ([#918](https://github.com/googleapis/python-aiplatform/issues/918)) ([b8f5f82](https://github.com/googleapis/python-aiplatform/commit/b8f5f82ae43edfb933305a074c315e2f3239b4b1))
+* enable ingest from pd.DataFrame ([#977](https://github.com/googleapis/python-aiplatform/issues/977)) ([9289f2d](https://github.com/googleapis/python-aiplatform/commit/9289f2d3ce424f3f9754a3dd23883e25dec1300f))
+* Open LIT with a deployed model ([#963](https://github.com/googleapis/python-aiplatform/issues/963)) ([ea16849](https://github.com/googleapis/python-aiplatform/commit/ea16849f936d7a2e8402fd235decefe5972685ed))
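+
+A minimal sketch of `Model.update` (illustrative; the model ID is a placeholder). Only the fields that are passed get changed:
+
+```python
+from google.cloud import aiplatform
+
+aiplatform.init(project="my-project", location="us-central1")
+
+model = aiplatform.Model("1234567890")  # placeholder model ID
+model = model.update(
+    display_name="churn-model-v2",
+    description="Retrained on the latest snapshot",
+    labels={"team": "growth"},
+)
+```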
+
+
+### Bug Fixes
+
+* Fixed BigQuery datasets that have colon in URI ([#855](https://github.com/googleapis/python-aiplatform/issues/855)) ([153578f](https://github.com/googleapis/python-aiplatform/commit/153578f19d57db96e3674b2d797c5352c107f936))
+* Fixed integration test for model.upload ([#975](https://github.com/googleapis/python-aiplatform/issues/975)) ([0ca3747](https://github.com/googleapis/python-aiplatform/commit/0ca374769f922fd427c5b6f58c9ce1ab40f18d18))
+* rename teardown fixture ([#1004](https://github.com/googleapis/python-aiplatform/issues/1004)) ([fcd0096](https://github.com/googleapis/python-aiplatform/commit/fcd00969dbbbf06887dfdbaa6bc65b135c24f95f))
+
+
+### Documentation
+
+* **samples:** replace deprecated fields in create_training_pipeline_tabular_forecasting_sample.py ([#981](https://github.com/googleapis/python-aiplatform/issues/981)) ([9ebc972](https://github.com/googleapis/python-aiplatform/commit/9ebc972bba972b1e1920db422ed28a721e90329d))
+
+## [1.9.0](https://www.github.com/googleapis/python-aiplatform/compare/v1.8.1...v1.9.0) (2021-12-29)
+
+
+### Features
+
+* add create in Featurestore, EntityType, Feature; add create_entity_type in Featurestore; add create_feature, batch_create_features in EntityType; add ingest_from_* for bq and gcs in EntityType; add and update delete with force delete nested resources ([#872](https://www.github.com/googleapis/python-aiplatform/issues/872)) ([ba11c3d](https://www.github.com/googleapis/python-aiplatform/commit/ba11c3d3cd8d3869e2deb3207a8698fa7ce284ec)) (see the sketch after this list)
+* Add LIT methods for Pandas DataFrame and TensorFlow saved model. ([#874](https://www.github.com/googleapis/python-aiplatform/issues/874)) ([03cf301](https://www.github.com/googleapis/python-aiplatform/commit/03cf301989a5802b122803eac7a2d03f2d1769fb))
+* Add support to create TensorboardExperiment ([#909](https://www.github.com/googleapis/python-aiplatform/issues/909)) ([96ce738](https://www.github.com/googleapis/python-aiplatform/commit/96ce7387ac58e0ec7cb6a7f6d6a6e422eae5da96))
+* Add support to create TensorboardRun ([#912](https://www.github.com/googleapis/python-aiplatform/issues/912)) ([8df74a2](https://www.github.com/googleapis/python-aiplatform/commit/8df74a29df0adb95fff5500fcc9d7a025012ab5e))
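+
+A minimal sketch of the new create methods, chaining Featurestore -> EntityType -> Feature (illustrative; all IDs are placeholders):
+
+```python
+from google.cloud import aiplatform
+
+aiplatform.init(project="my-project", location="us-central1")
+
+fs = aiplatform.Featurestore.create(
+    featurestore_id="movie_predictions",
+    online_store_fixed_node_count=1,
+)
+users = fs.create_entity_type(entity_type_id="users")
+users.create_feature(feature_id="age", value_type="INT64")
+users.batch_create_features(
+    feature_configs={
+        "gender": {"value_type": "STRING"},
+        "liked_genres": {"value_type": "STRING_ARRAY"},
+    }
+)
+```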
+
+
+### Bug Fixes
+
+* Fix timestamp proto util to default to timestamp at call time. ([#933](https://www.github.com/googleapis/python-aiplatform/issues/933)) ([d72a254](https://www.github.com/googleapis/python-aiplatform/commit/d72a254e97cf74f3fdd55a32a4af86737243593a))
+* Improve handling of undeploying model without redistributing remaining traffic ([#898](https://www.github.com/googleapis/python-aiplatform/issues/898)) ([8a8a4fa](https://www.github.com/googleapis/python-aiplatform/commit/8a8a4faa667bde2a4df04afa23a6dd5b1856f958))
+* issues/192254729 ([#914](https://www.github.com/googleapis/python-aiplatform/issues/914)) ([3ec620c](https://www.github.com/googleapis/python-aiplatform/commit/3ec620c64bd60ceb5b89918200e11e3fbff67370))
+* issues/192254729 ([#915](https://www.github.com/googleapis/python-aiplatform/issues/915)) ([0f22ff6](https://www.github.com/googleapis/python-aiplatform/commit/0f22ff61460a3f2bd55d2c10c4ee06e582f03944))
+* use open_in_new_tab in the render method. ([#926](https://www.github.com/googleapis/python-aiplatform/issues/926)) ([04618e0](https://www.github.com/googleapis/python-aiplatform/commit/04618e0563b8588eec2ccd8342c6085ca08b5adb))
+
+## [1.8.1](https://www.github.com/googleapis/python-aiplatform/compare/v1.8.0...v1.8.1) (2021-12-14)
+
+
+### Bug Fixes
+
+* add clarity to param model_name ([#888](https://www.github.com/googleapis/python-aiplatform/issues/888)) ([1d81783](https://www.github.com/googleapis/python-aiplatform/commit/1d81783b2f914dd7606ee884ca31c1a594e5135f))
+* add clarity to parameters per user feedback ([#886](https://www.github.com/googleapis/python-aiplatform/issues/886)) ([37ee0a1](https://www.github.com/googleapis/python-aiplatform/commit/37ee0a1dc6e0105e19aca18f44995a352bfc40cb))
+* add param for multi-label per user's feedback ([#887](https://www.github.com/googleapis/python-aiplatform/issues/887)) ([fda942f](https://www.github.com/googleapis/python-aiplatform/commit/fda942ffbe009077b47f36aad1c29603a451e38b))
+* add support for API base path overriding ([#908](https://www.github.com/googleapis/python-aiplatform/issues/908)) ([45c4086](https://www.github.com/googleapis/python-aiplatform/commit/45c4086dd07dd7d3d3b7417196ff61a7107d8a1a))
+* Import the correct constants and use v1 for tensorboard experiments ([#905](https://www.github.com/googleapis/python-aiplatform/issues/905)) ([48c2bf1](https://www.github.com/googleapis/python-aiplatform/commit/48c2bf1ea2fa42afea1b5d419527bfb8e49e0ac0))
+* incorrect uri for IOD yaml ([#889](https://www.github.com/googleapis/python-aiplatform/issues/889)) ([e108ef8](https://www.github.com/googleapis/python-aiplatform/commit/e108ef8250c77c8a8edeccb6b601cbe0b0380c89))
+* Minor docstring and snippet fixes ([#873](https://www.github.com/googleapis/python-aiplatform/issues/873)) ([578e06d](https://www.github.com/googleapis/python-aiplatform/commit/578e06df481c3d60074a7b8e9365f8361b04e32b))
+
+
+### Documentation
+
+* Update references to containers and notebook samples. ([#890](https://www.github.com/googleapis/python-aiplatform/issues/890)) ([67fa1f1](https://www.github.com/googleapis/python-aiplatform/commit/67fa1f179af66686339d797e5b368e96816ed1c5))
+* Updated docstrings with exception error classes ([#894](https://www.github.com/googleapis/python-aiplatform/issues/894)) ([f9aecd2](https://www.github.com/googleapis/python-aiplatform/commit/f9aecd22fe08a97e45187b4d11c755ac3b9dfadd))
+
+## [1.8.0](https://www.github.com/googleapis/python-aiplatform/compare/v1.7.1...v1.8.0) (2021-12-03)
+
+
+### Features
+
+* Add cloud profiler to training_utils ([6d5c7c4](https://www.github.com/googleapis/python-aiplatform/commit/6d5c7c42d1c352f161e4738c6dbbf540a032017b)) (see the sketch after this list)
+* add enable_private_service_connect field to Endpoint feat: add id field to DeployedModel feat: add service_attachment field to PrivateEndpoints feat: add endpoint_id to CreateEndpointRequest and method signature to CreateEndpoint feat: add method... ([#878](https://www.github.com/googleapis/python-aiplatform/issues/878)) ([ca813be](https://www.github.com/googleapis/python-aiplatform/commit/ca813be08ec2620380b5a12b0d6cdc079e27ba79))
+* add enable_private_service_connect field to Endpoint feat: add id field to DeployedModel feat: add service_attachment field to PrivateEndpoints feat: add endpoint_id to CreateEndpointRequest and method signature to CreateEndpoint feat: add method... ([#879](https://www.github.com/googleapis/python-aiplatform/issues/879)) ([47e93b2](https://www.github.com/googleapis/python-aiplatform/commit/47e93b20843f30805b73cd6db214c8743f8bfc97))
+* add featurestore module including Featurestore, EntityType, and Feature classes; add get, update, delete, list methods in all featurestore classes; add search method in Feature class ([#850](https://www.github.com/googleapis/python-aiplatform/issues/850)) ([66745a6](https://www.github.com/googleapis/python-aiplatform/commit/66745a6ce13fb8b32dd7fbf3eb86e71bd291869b))
+* Add prediction container URI builder method ([#805](https://www.github.com/googleapis/python-aiplatform/issues/805)) ([91dd3c0](https://www.github.com/googleapis/python-aiplatform/commit/91dd3c0d5de72fac5b1dc8a9bc23d6cb431061a4))
+* default to custom job display name if experiment name looks like a custom job ID ([#833](https://www.github.com/googleapis/python-aiplatform/issues/833)) ([8b9376e](https://www.github.com/googleapis/python-aiplatform/commit/8b9376e9c961a751799f5b80d1b19917c8c353f8))
+* Support uploading local models ([#779](https://www.github.com/googleapis/python-aiplatform/issues/779)) ([bffbd9d](https://www.github.com/googleapis/python-aiplatform/commit/bffbd9d359edb099e661736a0c77269bb3a0c746))
+* Tensorboard v1 protos release ([#847](https://www.github.com/googleapis/python-aiplatform/issues/847)) ([e0fc3d9](https://www.github.com/googleapis/python-aiplatform/commit/e0fc3d9e4e8a7911f21671ea49818c5f84798d12))
+* updating Tensorboard related code to use v1 ([#851](https://www.github.com/googleapis/python-aiplatform/issues/851)) ([b613b26](https://www.github.com/googleapis/python-aiplatform/commit/b613b264524aaab2cb65e63a5487770736faa7c8))
+* Upgrade Tensorboard from v1beta1 to v1 ([#849](https://www.github.com/googleapis/python-aiplatform/issues/849)) ([c40ec85](https://www.github.com/googleapis/python-aiplatform/commit/c40ec85e1fca2bee6813f52712d063f96264ec2c))
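+
+For the cloud profiler, a minimal sketch of what a training script might add (illustrative; assumes the custom job was started with a Tensorboard instance and a service account attached):
+
+```python
+# Inside the training script of the custom job.
+from google.cloud.aiplatform.training_utils import cloud_profiler
+
+cloud_profiler.init()
+
+# ...build and fit the model as usual; profiling sessions can then be
+# captured from the "Profile" tab of the job's Vertex AI Tensorboard.
+```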
+
+
+### Bug Fixes
+
+* Import error for cloud_profiler ([#869](https://www.github.com/googleapis/python-aiplatform/issues/869)) ([0f124e9](https://www.github.com/googleapis/python-aiplatform/commit/0f124e93a1ddead16c0018970f34e45c73d5ed81))
+* Support multiple instances in custom predict sample ([#857](https://www.github.com/googleapis/python-aiplatform/issues/857)) ([8cb4839](https://www.github.com/googleapis/python-aiplatform/commit/8cb483918bdbaeae34935eef2b3cd997c1ae89a3))
+
+
+### Documentation
+
+* Added comment for evaluation_id to python examples ([#860](https://www.github.com/googleapis/python-aiplatform/issues/860)) ([004bf5f](https://www.github.com/googleapis/python-aiplatform/commit/004bf5fa4cb2d66e36de7ec52dee8e2c8dd438ee))
+* Reverted IDs in model_service snippets test ([#871](https://www.github.com/googleapis/python-aiplatform/issues/871)) ([da747b5](https://www.github.com/googleapis/python-aiplatform/commit/da747b5ffca3c12b8d64bc80bfe93da5afde0d43))
+* Update name of BQ source parameter in samples ([#859](https://www.github.com/googleapis/python-aiplatform/issues/859)) ([f11b598](https://www.github.com/googleapis/python-aiplatform/commit/f11b598f9069f77e86631ada53941876aea010bc))
+
+## [1.7.1](https://www.github.com/googleapis/python-aiplatform/compare/v1.7.0...v1.7.1) (2021-11-16)
+
+
+### Features
+
+* Add support for new Vertex regions ([#811](https://www.github.com/googleapis/python-aiplatform/issues/811)) ([8d04138](https://www.github.com/googleapis/python-aiplatform/commit/8d0413880486d03314ecab80347a713318c6944a))
+
+
+### Bug Fixes
+
+* add parameters_value in PipelineJob for schema > 2.0.0 ([#817](https://www.github.com/googleapis/python-aiplatform/issues/817)) ([900a449](https://www.github.com/googleapis/python-aiplatform/commit/900a44962ac85608dbcb3d23049db160d49d842a))
+* exclude support for python 3.10 ([#831](https://www.github.com/googleapis/python-aiplatform/issues/831)) ([0301a1d](https://www.github.com/googleapis/python-aiplatform/commit/0301a1de5719031c6c826fe4887ff5fb6bcfa956))
+
+
+### Miscellaneous Chores
+
+* release 1.7.1 ([#845](https://www.github.com/googleapis/python-aiplatform/issues/845)) ([ca04de6](https://www.github.com/googleapis/python-aiplatform/commit/ca04de6a95f8b22d0161e250d8d4314a35becfab))
+
+## [1.7.0](https://www.github.com/googleapis/python-aiplatform/compare/v1.6.2...v1.7.0) (2021-11-06)
+
+
+### Features
+
+* Adds support for `google.protobuf.Value` pipeline parameters in the `parameter_values` field ([#807](https://www.github.com/googleapis/python-aiplatform/issues/807)) ([c97199d](https://www.github.com/googleapis/python-aiplatform/commit/c97199dd2cb712ef436ee9cbf6b8add27b42b174)) (see the sketch after this list)
+* Adds support for `google.protobuf.Value` pipeline parameters in the `parameter_values` field ([#808](https://www.github.com/googleapis/python-aiplatform/issues/808)) ([726b620](https://www.github.com/googleapis/python-aiplatform/commit/726b620bea1223c80225c9a3c2b54342e9c14052))
+* PipelineJob switch to v1 API from v1beta1 API ([#750](https://www.github.com/googleapis/python-aiplatform/issues/750)) ([8db7e0c](https://www.github.com/googleapis/python-aiplatform/commit/8db7e0ca4e796fea47c1bdf4c0fccd514f2dd8c2))
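+
+With `google.protobuf.Value` parameters, `parameter_values` is no longer limited to str/int/float. A minimal sketch (illustrative; template path and bucket are placeholders):
+
+```python
+from google.cloud import aiplatform
+
+job = aiplatform.PipelineJob(
+    display_name="my-pipeline",
+    template_path="pipeline.json",
+    pipeline_root="gs://my-bucket/pipeline-root",
+    parameter_values={
+        "learning_rate": 0.01,          # float
+        "use_gpu": True,                # bool
+        "hidden_layers": [128, 64],     # list
+        "optimizer": {"name": "adam"},  # dict
+    },
+)
+```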
+
+
+### Bug Fixes
+
+* Correct PipelineJob credentials description ([#816](https://www.github.com/googleapis/python-aiplatform/issues/816)) ([49aaa87](https://www.github.com/googleapis/python-aiplatform/commit/49aaa8719a3daabf7e0d23fa1cd1d64c19159a83))
+* Fixed docstrings for Dataset in AutoMLForecastingTrainingJob ([760887b](https://www.github.com/googleapis/python-aiplatform/commit/760887b196884707473896def9e8b69c9fc77423))
+
+
+### Documentation
+
+* Fix pydocs README to be consistent with repo README ([#821](https://www.github.com/googleapis/python-aiplatform/issues/821)) ([95dbd60](https://www.github.com/googleapis/python-aiplatform/commit/95dbd6020ee8f3037b0834eb39312b5d7e5fd8e1))
+* Update sample with feedback from b/191251050 ([#818](https://www.github.com/googleapis/python-aiplatform/issues/818)) ([6b2d938](https://www.github.com/googleapis/python-aiplatform/commit/6b2d93834734b6789c13ef3782b1b3632f5c6133))
+
+## [1.6.2](https://www.github.com/googleapis/python-aiplatform/compare/v1.6.1...v1.6.2) (2021-11-01)
+
+
+### Features
+
+* Add PipelineJob.submit to create PipelineJob without monitoring its completion. ([#798](https://www.github.com/googleapis/python-aiplatform/issues/798)) ([7ab05d5](https://www.github.com/googleapis/python-aiplatform/commit/7ab05d5e127636d96365b7ea408974ccd6c2f0fe)) (see the sketch below)
+* support new protobuf value param types for Pipeline Job client ([#797](https://www.github.com/googleapis/python-aiplatform/issues/797)) ([2fc05ca](https://www.github.com/googleapis/python-aiplatform/commit/2fc05cab03a2c7f8462b234b02d43bc7581ba845))
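+
+A minimal sketch of `submit` versus the blocking `run` (illustrative; paths are placeholders):
+
+```python
+from google.cloud import aiplatform
+
+job = aiplatform.PipelineJob(
+    display_name="my-pipeline",
+    template_path="pipeline.json",
+    pipeline_root="gs://my-bucket/pipeline-root",
+)
+
+job.submit()      # creates the PipelineJob and returns without waiting
+print(job.state)  # poll later; job.wait() would block until completion
+```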
+
+
+### Bug Fixes
+
+* Add retries when polling during monitoring runs ([#786](https://www.github.com/googleapis/python-aiplatform/issues/786)) ([45401c0](https://www.github.com/googleapis/python-aiplatform/commit/45401c09f23ed616a7ca84b3d7f53b8a1db21c7c))
+* use version.py for versioning ([#804](https://www.github.com/googleapis/python-aiplatform/issues/804)) ([514031f](https://www.github.com/googleapis/python-aiplatform/commit/514031fce90b6e4606279d4903dc93b0f18b9f2a))
+* Widen system test timeout, handle tearing down failed training pipelines ([#791](https://www.github.com/googleapis/python-aiplatform/issues/791)) ([78879e2](https://www.github.com/googleapis/python-aiplatform/commit/78879e2482cac7ef5520f1d7fe900768147b948e))
+
+
+### Miscellaneous Chores
+
+* release 1.6.2 ([#809](https://www.github.com/googleapis/python-aiplatform/issues/809)) ([e50b049](https://www.github.com/googleapis/python-aiplatform/commit/e50b0497574411a9c7462d76dca489281ee48d83))
+
+## [1.6.1](https://www.github.com/googleapis/python-aiplatform/compare/v1.6.0...v1.6.1) (2021-10-25)
+
+
+### Features
+
+* Add debugging terminal support for CustomJob, HyperparameterTun… ([#699](https://www.github.com/googleapis/python-aiplatform/issues/699)) ([2deb505](https://www.github.com/googleapis/python-aiplatform/commit/2deb50502ae2bb8ba3f97d69b06b72b7625639a4)) (see the sketch after this list)
+* add support for python 3.10 ([#769](https://www.github.com/googleapis/python-aiplatform/issues/769)) ([8344804](https://www.github.com/googleapis/python-aiplatform/commit/83448044508f5feb052ae7fc5a5a7ca917cee0d1))
+* Add training_utils folder and environment_variables for training ([141c008](https://www.github.com/googleapis/python-aiplatform/commit/141c008759aefe56a41e1eac654739c509d9754d))
+* enable reduction server ([#741](https://www.github.com/googleapis/python-aiplatform/issues/741)) ([8ef0ded](https://www.github.com/googleapis/python-aiplatform/commit/8ef0ded034db797adb4d458eba43537992d822bd))
+* enabling AutoML Forecasting training response to include BigQuery location of exported evaluated examples ([#657](https://www.github.com/googleapis/python-aiplatform/issues/657)) ([c1c2326](https://www.github.com/googleapis/python-aiplatform/commit/c1c2326b2342ab1b6f4c4ce3852e63376eae740d))
+* **PipelineJob:** allow PipelineSpec as param ([#774](https://www.github.com/googleapis/python-aiplatform/issues/774)) ([f90a1bd](https://www.github.com/googleapis/python-aiplatform/commit/f90a1bd775daa0892e16fd82fc1738fa9a912ec7))
+* pre-batch creating TensorboardRuns and TensorboardTimeSeries in one_shot mode to speed up uploading ([#772](https://www.github.com/googleapis/python-aiplatform/issues/772)) ([c9f68c6](https://www.github.com/googleapis/python-aiplatform/commit/c9f68c6e840ba3cda04080623dfbcba6945d53e8))
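+
+A minimal sketch of the debugging terminal (illustrative; the machine type and image URI are placeholders):
+
+```python
+from google.cloud import aiplatform
+
+worker_pool_specs = [
+    {
+        "machine_spec": {"machine_type": "n1-standard-4"},
+        "replica_count": 1,
+        "container_spec": {"image_uri": "us-docker.pkg.dev/my-project/my-repo/trainer:latest"},  # placeholder
+    }
+]
+
+job = aiplatform.CustomJob(
+    display_name="debuggable-job",
+    worker_pool_specs=worker_pool_specs,
+)
+
+job.run(enable_web_access=True, sync=False)
+job.wait_for_resource_creation()
+# URIs for the interactive shell on each worker:
+print(job.web_access_uris)
+```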
+
+
+### Bug Fixes
+
+* cast resource labels to dict type ([#783](https://www.github.com/googleapis/python-aiplatform/issues/783)) ([255edc9](https://www.github.com/googleapis/python-aiplatform/commit/255edc92dc897619ddd705463aefb8a1723ae8cd))
+* Remove sync parameter from create_endpoint_sample ([#695](https://www.github.com/googleapis/python-aiplatform/issues/695)) ([0477f5a](https://www.github.com/googleapis/python-aiplatform/commit/0477f5a10ba1048e64c11fc3d7e1e375b19a10fe))
+
+
+### Miscellaneous Chores
+
+* release 1.6.1 ([#789](https://www.github.com/googleapis/python-aiplatform/issues/789)) ([4520d35](https://www.github.com/googleapis/python-aiplatform/commit/4520d350beb756549304de60d62ff637bb1807c5))
+
+## [1.6.0](https://www.github.com/googleapis/python-aiplatform/compare/v1.5.0...v1.6.0) (2021-10-12)
+
+
+### Features
+
+* add featurestore service to aiplatform v1 ([#765](https://www.github.com/googleapis/python-aiplatform/issues/765)) ([68c88e4](https://www.github.com/googleapis/python-aiplatform/commit/68c88e48f62d5c2ff561862ba810a48389f7e41a))
+* Add one shot profile uploads to tensorboard uploader. ([#704](https://www.github.com/googleapis/python-aiplatform/issues/704)) ([a83f253](https://www.github.com/googleapis/python-aiplatform/commit/a83f2535b31e2aaff0306c7290265b864b9ddb40))
+* Added column_specs, training_encryption_spec_key_name, model_encryption_spec_key_name to AutoMLForecastingTrainingJob.init and various split methods to AutoMLForecastingTrainingJob.run ([#647](https://www.github.com/googleapis/python-aiplatform/issues/647)) ([7cb6976](https://www.github.com/googleapis/python-aiplatform/commit/7cb69764e0f9be9ca0fcb1641f4dc90e3b306bed))
+* Lazy load Endpoint class ([#655](https://www.github.com/googleapis/python-aiplatform/issues/655)) ([c795c6f](https://www.github.com/googleapis/python-aiplatform/commit/c795c6fbb87c4f71845cfbd2647c1adbc029bcef))
+
+## [1.5.0](https://www.github.com/googleapis/python-aiplatform/compare/v1.4.3...v1.5.0) (2021-09-30)
+
+
+### Features
+
+* Add data plane code snippets for feature store service ([#713](https://www.github.com/googleapis/python-aiplatform/issues/713)) ([e3ea683](https://www.github.com/googleapis/python-aiplatform/commit/e3ea683bf754832340853a15bdb0a0662500a70f))
+* add flaky test diagnostic script ([#734](https://www.github.com/googleapis/python-aiplatform/issues/734)) ([09e48de](https://www.github.com/googleapis/python-aiplatform/commit/09e48de8b79fb5d601169663c9a8e1c33883f1cf))
+* add vizier service to aiplatform v1 BUILD.bazel ([#731](https://www.github.com/googleapis/python-aiplatform/issues/731)) ([1a580ae](https://www.github.com/googleapis/python-aiplatform/commit/1a580aec158b5e25b94f27a6a9daa3943124c485))
+* code snippets for feature store control plane ([#709](https://www.github.com/googleapis/python-aiplatform/issues/709)) ([8e06ced](https://www.github.com/googleapis/python-aiplatform/commit/8e06ced83ed2cc480d869318c4debef9c28ad214))
+* Updating the Tensorboard uploader to use the new batch write API so it runs more efficiently ([#710](https://www.github.com/googleapis/python-aiplatform/issues/710)) ([9d1b01a](https://www.github.com/googleapis/python-aiplatform/commit/9d1b01a91dc077bfe8edf023216b65b826d67d5f))
+
+
+### Bug Fixes
+
+* [#677](https://www.github.com/googleapis/python-aiplatform/issues/677) ([#728](https://www.github.com/googleapis/python-aiplatform/issues/728)) ([7f548e4](https://www.github.com/googleapis/python-aiplatform/commit/7f548e4b5322055a3c2befcdc9d4eef1bc2278ca))
+* **PipelineJob:** use name as output only field ([#719](https://www.github.com/googleapis/python-aiplatform/issues/719)) ([1c84464](https://www.github.com/googleapis/python-aiplatform/commit/1c84464e3130f9db81cd341306b334f9a490587f))
+* use the project id from BQ dataset instead of the default project id ([#717](https://www.github.com/googleapis/python-aiplatform/issues/717)) ([e87a255](https://www.github.com/googleapis/python-aiplatform/commit/e87a255705a5d04ade79f12c706dc842c0228118))
+
+## [1.4.3](https://www.github.com/googleapis/python-aiplatform/compare/v1.4.2...v1.4.3) (2021-09-17)
+
+
+### Features
+
+* **PipelineJob:** support dict, list, bool typed input parameters fr… ([#693](https://www.github.com/googleapis/python-aiplatform/issues/693)) ([243b75c](https://www.github.com/googleapis/python-aiplatform/commit/243b75c2655beeef47848410a40d86a072428ac3))
+
+
+### Bug Fixes
+
+* Update milli node_hours for image training ([#663](https://www.github.com/googleapis/python-aiplatform/issues/663)) ([64768c3](https://www.github.com/googleapis/python-aiplatform/commit/64768c3591f648932e023713d2a728ce5318bb8b))
+* XAI Metadata compatibility with Model.upload ([#705](https://www.github.com/googleapis/python-aiplatform/issues/705)) ([f0570cb](https://www.github.com/googleapis/python-aiplatform/commit/f0570cb999f024ca96e7daaa102c81b681c2a575))
+
+
+### Miscellaneous Chores
+
+* release 1.4.3 ([#715](https://www.github.com/googleapis/python-aiplatform/issues/715)) ([b610486](https://www.github.com/googleapis/python-aiplatform/commit/b6104868161a236fc5585855b5948a5e3294aea2))
+
+## [1.4.2](https://www.github.com/googleapis/python-aiplatform/compare/v1.4.1...v1.4.2) (2021-09-10)
+
+
+### Features
+
+* add explanation metadata `get_metadata_protobuf` for reuse ([#672](https://www.github.com/googleapis/python-aiplatform/issues/672)) ([efb6d18](https://www.github.com/googleapis/python-aiplatform/commit/efb6d18f868086bc53aceab60942eb837ced65b7))
+
+## [1.4.1](https://www.github.com/googleapis/python-aiplatform/compare/v1.4.0...v1.4.1) (2021-09-07)
+
+
+### Features
+
+* add prediction service RPC RawPredict to aiplatform_v1beta1 feat: add tensorboard service RPCs to aiplatform_v1beta1: BatchCreateTensorboardRuns, BatchCreateTensorboardTimeSeries, WriteTensorboardExperimentData feat: add model_deployment_monitori... ([#670](https://www.github.com/googleapis/python-aiplatform/issues/670)) ([b73cd94](https://www.github.com/googleapis/python-aiplatform/commit/b73cd9485f8713ac42e7efa9bfd952f67368b778))
+* add Vizier service to aiplatform v1 ([#671](https://www.github.com/googleapis/python-aiplatform/issues/671)) ([179150a](https://www.github.com/googleapis/python-aiplatform/commit/179150aed80d1386993a07870fe34f2b637ded18))
+* add XAI, model monitoring, and index services to aiplatform v1 ([#668](https://www.github.com/googleapis/python-aiplatform/issues/668)) ([1fbce55](https://www.github.com/googleapis/python-aiplatform/commit/1fbce55fd846f473f41c16c1185be893e2376bdd))
+* Update tensorboard uploader to use Dispatcher for handling different event types ([#651](https://www.github.com/googleapis/python-aiplatform/issues/651)) ([d20b520](https://www.github.com/googleapis/python-aiplatform/commit/d20b520ea936a6554a24099beb0e044f237ff741)), closes [#519](https://www.github.com/googleapis/python-aiplatform/issues/519)
+
+
+### Documentation
+
+* Add code sample for Pipelines ([#684](https://www.github.com/googleapis/python-aiplatform/issues/684)) ([4f0c18e](https://www.github.com/googleapis/python-aiplatform/commit/4f0c18e8989cf353019876a73aa57457332e88fb))
+
+## [1.4.0](https://www.github.com/googleapis/python-aiplatform/compare/v1.3.0...v1.4.0) (2021-08-30)
+
+
+### Features
+
+* add filter and timestamp splits ([#627](https://www.github.com/googleapis/python-aiplatform/issues/627)) ([1a13577](https://www.github.com/googleapis/python-aiplatform/commit/1a135775966c8a2303ded529eba514dcf9db7205))
+* add labels to all resource creation apis ([#601](https://www.github.com/googleapis/python-aiplatform/issues/601)) ([4e7666a](https://www.github.com/googleapis/python-aiplatform/commit/4e7666a30b4472698ed980d9d746ba85ad4142d8))
+* add PipelineJob.list ([a58ea82](https://www.github.com/googleapis/python-aiplatform/commit/a58ea826c575b9b0c8cb69e47fc2f07a98bb285b))
+* add support for export_evaluated_data_items_config in AutoMLTab… ([#583](https://www.github.com/googleapis/python-aiplatform/issues/583)) ([2a6b0a3](https://www.github.com/googleapis/python-aiplatform/commit/2a6b0a369296698f79d75e93007e4c7319f3523c))
+* add util functions to get URLs for Tensorboard web app. ([#635](https://www.github.com/googleapis/python-aiplatform/issues/635)) ([8d88c00](https://www.github.com/googleapis/python-aiplatform/commit/8d88c006c5586b28d340448382a9292543448fd6))
+* Add wait_for_resource_creation to BatchPredictionJob and unblock async creation when model is pending creation. ([#660](https://www.github.com/googleapis/python-aiplatform/issues/660)) ([db580ad](https://www.github.com/googleapis/python-aiplatform/commit/db580ad43e97e0d877c29c0e8c077c37dee33ff3)) (see the sketch after this list)
+* Added the VertexAiResourceNoun.to_dict() method ([#588](https://www.github.com/googleapis/python-aiplatform/issues/588)) ([b478075](https://www.github.com/googleapis/python-aiplatform/commit/b478075efb05553760514256fee9a63126a9916f))
+* expose base_output_dir for custom job ([#586](https://www.github.com/googleapis/python-aiplatform/issues/586)) ([2f138d1](https://www.github.com/googleapis/python-aiplatform/commit/2f138d1dfe4959d1b5f53a9dfef90a18de9908ec))
+* expose boot disk type and size for CustomTrainingJob, CustomPythonPackageTrainingJob, and CustomContainerTrainingJob ([#602](https://www.github.com/googleapis/python-aiplatform/issues/602)) ([355ea24](https://www.github.com/googleapis/python-aiplatform/commit/355ea24c6dd9b061ae0933df4dd07dd5b8c2232b))
+* split GAPIC samples by service ([#599](https://www.github.com/googleapis/python-aiplatform/issues/599)) ([5f15b4f](https://www.github.com/googleapis/python-aiplatform/commit/5f15b4f9a4bad2c9447747a8bdebaa99eab00b75))
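+
+A minimal sketch of the non-blocking creation flow for batch prediction (illustrative; names, IDs, and URIs are placeholders):
+
+```python
+from google.cloud import aiplatform
+
+batch_job = aiplatform.BatchPredictionJob.create(
+    job_display_name="nightly-scoring",
+    model_name="1234567890",  # placeholder model ID
+    gcs_source="gs://my-bucket/input.jsonl",
+    gcs_destination_prefix="gs://my-bucket/output",
+    sync=False,  # return immediately instead of blocking on completion
+)
+
+batch_job.wait_for_resource_creation()  # blocks only until the job resource exists
+print(batch_job.resource_name)
+```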
+
+
+### Bug Fixes
+
+* Fixed bug in TabularDataset.column_names ([#590](https://www.github.com/googleapis/python-aiplatform/issues/590)) ([0fbcd59](https://www.github.com/googleapis/python-aiplatform/commit/0fbcd592cd7e9c4b0a131d777fa84e592a43a21c))
+* pipeline none values ([#649](https://www.github.com/googleapis/python-aiplatform/issues/649)) ([2f89343](https://www.github.com/googleapis/python-aiplatform/commit/2f89343adbd69610fc5cacc7121119fc7279186e))
+* Populate service_account and network in PipelineJob instead of pipeline_spec ([#658](https://www.github.com/googleapis/python-aiplatform/issues/658)) ([8fde2ce](https://www.github.com/googleapis/python-aiplatform/commit/8fde2ce4441139784bc0fdd62c88d4b833018765))
+* re-remove extra TB dependencies introduced due to merge conflict ([#593](https://www.github.com/googleapis/python-aiplatform/issues/593)) ([433b94a](https://www.github.com/googleapis/python-aiplatform/commit/433b94a78004de6d3a4726317d8bac32c358ace8))
+* Update BatchPredictionJob.iter_outputs() and BQ docstrings ([#631](https://www.github.com/googleapis/python-aiplatform/issues/631)) ([28f32fd](https://www.github.com/googleapis/python-aiplatform/commit/28f32fd11470ad86d2f103346b3e6be8f1adc2d8))
+
+## [1.3.0](https://www.github.com/googleapis/python-aiplatform/compare/v1.2.0...v1.3.0) (2021-07-30)
+
+
+### Features
+
+* add get method for PipelineJob ([#561](https://www.github.com/googleapis/python-aiplatform/issues/561)) ([fe5e6e4](https://www.github.com/googleapis/python-aiplatform/commit/fe5e6e4576a6c8c73549effae99bced709e29402))
+* add Samples section to CONTRIBUTING.rst ([#558](https://www.github.com/googleapis/python-aiplatform/issues/558)) ([d35c466](https://www.github.com/googleapis/python-aiplatform/commit/d35c466e19ac9fa43b5668ce18520090b5e4edd9))
+* add tensorboard resource management ([#539](https://www.github.com/googleapis/python-aiplatform/issues/539)) ([6f8d3d1](https://www.github.com/googleapis/python-aiplatform/commit/6f8d3d1ed89f0aa6f2f0418ae752185104196c63))
+* add tf1 metadata builder ([#526](https://www.github.com/googleapis/python-aiplatform/issues/526)) ([918998c](https://www.github.com/googleapis/python-aiplatform/commit/918998c0bdc25b6a39d359a34f892dac1ca4efac))
+* add wait for creation and more informative exception when properties are not available ([#566](https://www.github.com/googleapis/python-aiplatform/issues/566)) ([e346117](https://www.github.com/googleapis/python-aiplatform/commit/e346117d5453358a32a1d6e584613ace5c2251d9))
+* Adds a new API method FindMostStableBuild ([6a99b12](https://www.github.com/googleapis/python-aiplatform/commit/6a99b125922b8fca7c997150b81b6925376e9d1d))
+* Adds attribution_score_drift_threshold field ([6a99b12](https://www.github.com/googleapis/python-aiplatform/commit/6a99b125922b8fca7c997150b81b6925376e9d1d))
+* Adds attribution_score_skew_thresholds field ([6a99b12](https://www.github.com/googleapis/python-aiplatform/commit/6a99b125922b8fca7c997150b81b6925376e9d1d))
+* Adds BigQuery output table field to batch prediction job output config ([6a99b12](https://www.github.com/googleapis/python-aiplatform/commit/6a99b125922b8fca7c997150b81b6925376e9d1d))
+* Adds CustomJob.enable_web_access field ([6a99b12](https://www.github.com/googleapis/python-aiplatform/commit/6a99b125922b8fca7c997150b81b6925376e9d1d))
+* Adds CustomJob.web_access_uris field ([6a99b12](https://www.github.com/googleapis/python-aiplatform/commit/6a99b125922b8fca7c997150b81b6925376e9d1d))
+* Adds Endpoint.network, Endpoint.private_endpoints fields and PrivateEndpoints message ([6a99b12](https://www.github.com/googleapis/python-aiplatform/commit/6a99b125922b8fca7c997150b81b6925376e9d1d))
+* Adds Execution.State constants: CACHED and CANCELLED ([6a99b12](https://www.github.com/googleapis/python-aiplatform/commit/6a99b125922b8fca7c997150b81b6925376e9d1d))
+* Adds Feature Store features ([6a99b12](https://www.github.com/googleapis/python-aiplatform/commit/6a99b125922b8fca7c997150b81b6925376e9d1d))
+* Adds fields to Study message ([6a99b12](https://www.github.com/googleapis/python-aiplatform/commit/6a99b125922b8fca7c997150b81b6925376e9d1d))
+* Adds IndexEndpoint.private_ip_ranges field ([6a99b12](https://www.github.com/googleapis/python-aiplatform/commit/6a99b125922b8fca7c997150b81b6925376e9d1d))
+* Adds IndexEndpointService.deployed_index_id field ([6a99b12](https://www.github.com/googleapis/python-aiplatform/commit/6a99b125922b8fca7c997150b81b6925376e9d1d))
+* Adds MetadataService.DeleteArtifact and DeleteExecution methods ([6a99b12](https://www.github.com/googleapis/python-aiplatform/commit/6a99b125922b8fca7c997150b81b6925376e9d1d))
+* Adds ModelMonitoringObjectiveConfig.explanation_config field ([6a99b12](https://www.github.com/googleapis/python-aiplatform/commit/6a99b125922b8fca7c997150b81b6925376e9d1d))
+* Adds ModelMonitoringObjectiveConfig.ExplanationConfig message field ([6a99b12](https://www.github.com/googleapis/python-aiplatform/commit/6a99b125922b8fca7c997150b81b6925376e9d1d))
+* column specs for tabular transformation ([#466](https://www.github.com/googleapis/python-aiplatform/issues/466)) ([71d0bd4](https://www.github.com/googleapis/python-aiplatform/commit/71d0bd4615b436eaa3ec3eade4445934552f1cb3))
+* enable_caching in PipelineJob to override compile-time settings ([#557](https://www.github.com/googleapis/python-aiplatform/issues/557)) ([c9da662](https://www.github.com/googleapis/python-aiplatform/commit/c9da662ec24709622bcc4a9e85d1938bead91923)) (see the sketch after this list)
+* Removes breaking change from v1 version of AI Platform protos ([6a99b12](https://www.github.com/googleapis/python-aiplatform/commit/6a99b125922b8fca7c997150b81b6925376e9d1d))
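+
+A minimal sketch of `enable_caching`, which overrides the compile-time caching setting for every task in the pipeline (illustrative; paths are placeholders):
+
+```python
+from google.cloud import aiplatform
+
+job = aiplatform.PipelineJob(
+    display_name="my-pipeline",
+    template_path="pipeline.json",
+    pipeline_root="gs://my-bucket/pipeline-root",
+    enable_caching=False,  # None keeps whatever was set at compile time
+)
+job.run()
+```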
+
+
+### Bug Fixes
+
+* change default replica count to 1 for custom training job classes ([#579](https://www.github.com/googleapis/python-aiplatform/issues/579)) ([c24251f](https://www.github.com/googleapis/python-aiplatform/commit/c24251fdd230e73c2aadb4369266b78979a31015))
+* create pipeline job with user-specified job id ([#567](https://www.github.com/googleapis/python-aiplatform/issues/567)) ([df68ec3](https://www.github.com/googleapis/python-aiplatform/commit/df68ec3441eeb7670531f50aaed00df6f7e2a1a3))
+* **deps:** pin 'google-{api,cloud}-core', 'google-auth' to allow 2.x versions ([#556](https://www.github.com/googleapis/python-aiplatform/issues/556)) ([5d79795](https://www.github.com/googleapis/python-aiplatform/commit/5d797956737f2d0d4afa4d28fe1fa2f835992991))
+* enable self signed jwt for grpc ([#570](https://www.github.com/googleapis/python-aiplatform/issues/570)) ([6a99b12](https://www.github.com/googleapis/python-aiplatform/commit/6a99b125922b8fca7c997150b81b6925376e9d1d))
+
+
+### Documentation
+
+* fix spelling ([#565](https://www.github.com/googleapis/python-aiplatform/issues/565)) ([fe5c702](https://www.github.com/googleapis/python-aiplatform/commit/fe5c7020040fb0b3b558643b8bc3e12e76f4055f))
+
+## [1.2.0](https://www.github.com/googleapis/python-aiplatform/compare/v1.1.1...v1.2.0) (2021-07-14)
+
+
+### Features
+
+* Add additional_experiments field to AutoMlTablesInputs ([#540](https://www.github.com/googleapis/python-aiplatform/issues/540)) ([96ee726](https://www.github.com/googleapis/python-aiplatform/commit/96ee7261d5c3ffac5598c618b7c7499fad34ab12))
+* add always_use_jwt_access ([#498](https://www.github.com/googleapis/python-aiplatform/issues/498)) ([6df4866](https://www.github.com/googleapis/python-aiplatform/commit/6df48663286db10b1b88f947fc5873a18084cf37))
+* add explain get_metadata function for tf2. ([#507](https://www.github.com/googleapis/python-aiplatform/issues/507)) ([f6f9a97](https://www.github.com/googleapis/python-aiplatform/commit/f6f9a97bb178d9859b8d43166a43792d88e57710))
+* Add structure for XAI explain (from XAI SDK) ([#502](https://www.github.com/googleapis/python-aiplatform/issues/502)) ([cb9ef11](https://www.github.com/googleapis/python-aiplatform/commit/cb9ef1115e58c230f3d009397a6e6a27fd376bed))
+* Add two new ModelType constants for Video Action Recognition training jobs ([96ee726](https://www.github.com/googleapis/python-aiplatform/commit/96ee7261d5c3ffac5598c618b7c7499fad34ab12))
+* Adds AcceleratorType.NVIDIA_TESLA_A100 constant ([f3a3d03](https://www.github.com/googleapis/python-aiplatform/commit/f3a3d03c8509dc49c24139155a572dacbe954f66))
+* Adds additional_experiments field to AutoMlForecastingInputs ([8077b3d](https://www.github.com/googleapis/python-aiplatform/commit/8077b3d728b6e168c8aad41291dd90144ab75633))
+* Adds additional_experiments field to AutoMlTablesInputs ([#544](https://www.github.com/googleapis/python-aiplatform/issues/544)) ([8077b3d](https://www.github.com/googleapis/python-aiplatform/commit/8077b3d728b6e168c8aad41291dd90144ab75633))
+* Adds AutoscalingMetricSpec message ([f3a3d03](https://www.github.com/googleapis/python-aiplatform/commit/f3a3d03c8509dc49c24139155a572dacbe954f66))
+* Adds BigQuery output table field to batch prediction job output config ([f3a3d03](https://www.github.com/googleapis/python-aiplatform/commit/f3a3d03c8509dc49c24139155a572dacbe954f66))
+* Adds fields to Study message ([f3a3d03](https://www.github.com/googleapis/python-aiplatform/commit/f3a3d03c8509dc49c24139155a572dacbe954f66))
+* Adds JobState.JOB_STATE_EXPIRED constant ([f3a3d03](https://www.github.com/googleapis/python-aiplatform/commit/f3a3d03c8509dc49c24139155a572dacbe954f66))
+* Adds PipelineService methods for Create, Get, List, Delete, Cancel ([f3a3d03](https://www.github.com/googleapis/python-aiplatform/commit/f3a3d03c8509dc49c24139155a572dacbe954f66))
+* Adds two new ModelType constants for Video Action Recognition training jobs ([8077b3d](https://www.github.com/googleapis/python-aiplatform/commit/8077b3d728b6e168c8aad41291dd90144ab75633))
+* Removes AcceleratorType.TPU_V2 and TPU_V3 constants ([#543](https://www.github.com/googleapis/python-aiplatform/issues/543)) ([f3a3d03](https://www.github.com/googleapis/python-aiplatform/commit/f3a3d03c8509dc49c24139155a572dacbe954f66))
+
+
+### Bug Fixes
+
+* Handle nested fields from BigQuery source when getting default column_names ([#522](https://www.github.com/googleapis/python-aiplatform/issues/522)) ([3fc1d44](https://www.github.com/googleapis/python-aiplatform/commit/3fc1d44ac0acbb4f58088e7eeb16d85818af1125))
+* log pipeline completion and raise pipeline failures ([#523](https://www.github.com/googleapis/python-aiplatform/issues/523)) ([2508fe9](https://www.github.com/googleapis/python-aiplatform/commit/2508fe9d8a75ac8b05f06a78589c657313bd1d3d))
+* making the uploader depend on tensorflow-proper ([#499](https://www.github.com/googleapis/python-aiplatform/issues/499)) ([b95e040](https://www.github.com/googleapis/python-aiplatform/commit/b95e0406566879e8f71cefda72b41dc6fe4e578f))
+* Set prediction client when listing Endpoints ([#512](https://www.github.com/googleapis/python-aiplatform/issues/512)) ([95639ee](https://www.github.com/googleapis/python-aiplatform/commit/95639ee1c2c9cb66624265383d4d27bed3ff7dbd))
+
+## [1.1.1](https://www.github.com/googleapis/python-aiplatform/compare/v1.1.0...v1.1.1) (2021-06-22)
+
+
+### Features
+
+* add cancel method to pipeline client ([#488](https://www.github.com/googleapis/python-aiplatform/issues/488)) ([3b19fff](https://www.github.com/googleapis/python-aiplatform/commit/3b19fff399b85c92e661eb83a48a4c6636423518))
+
+
+### Bug Fixes
+
+* check if training_task_metadata is populated before logging backingCustomJob ([#494](https://www.github.com/googleapis/python-aiplatform/issues/494)) ([2e627f8](https://www.github.com/googleapis/python-aiplatform/commit/2e627f876e1d7dd03e5d6bd2e81e6234e361a9df))
+
+
+### Documentation
+
+* omit mention of Python 2.7 in 'CONTRIBUTING.rst' ([#1127](https://www.github.com/googleapis/python-aiplatform/issues/1127)) ([#489](https://www.github.com/googleapis/python-aiplatform/issues/489)) ([cbc47f8](https://www.github.com/googleapis/python-aiplatform/commit/cbc47f862f291b00b85718498571e0c737cb26a6))
+
+
+### Miscellaneous Chores
+
+* release 1.1.1 ([1a38ce2](https://www.github.com/googleapis/python-aiplatform/commit/1a38ce2f9879e1c42c0c6b25b72bd4836e3a1f73))
+
+## [1.1.0](https://www.github.com/googleapis/python-aiplatform/compare/v1.0.1...v1.1.0) (2021-06-17)
+
+
+### Features
+
+* add aiplatform API Vizier service ([fdc968f](https://www.github.com/googleapis/python-aiplatform/commit/fdc968f49e89a5c7ca14692080c0ae7e8b6e0865))
+* add featurestore, index, metadata, monitoring, pipeline, and tensorboard services to aiplatform v1beta1 ([fdc968f](https://www.github.com/googleapis/python-aiplatform/commit/fdc968f49e89a5c7ca14692080c0ae7e8b6e0865))
+* add invalid_row_count to ImportFeatureValuesResponse and ImportFeatureValuesOperationMetadata ([fdc968f](https://www.github.com/googleapis/python-aiplatform/commit/fdc968f49e89a5c7ca14692080c0ae7e8b6e0865))
+* add pipeline client init and run to vertex AI ([1f1226f](https://www.github.com/googleapis/python-aiplatform/commit/1f1226fd8c745a7cd86c299fa0cfc2291947f3e7))
+* add tensorboard support for CustomTrainingJob, CustomContainerTrainingJob, CustomPythonPackageTrainingJob ([#462](https://www.github.com/googleapis/python-aiplatform/issues/462)) ([8cfd611](https://www.github.com/googleapis/python-aiplatform/commit/8cfd61179af06232173b91b4d9fd633028823624)) (see the sketch after this list)
+* adds enhanced protos for time series forecasting ([fdc968f](https://www.github.com/googleapis/python-aiplatform/commit/fdc968f49e89a5c7ca14692080c0ae7e8b6e0865))
+* adds enhanced protos for time series forecasting ([#374](https://www.github.com/googleapis/python-aiplatform/issues/374)) ([fdc968f](https://www.github.com/googleapis/python-aiplatform/commit/fdc968f49e89a5c7ca14692080c0ae7e8b6e0865))
+* allow the prediction endpoint to be overridden ([#461](https://www.github.com/googleapis/python-aiplatform/issues/461)) ([c2cf612](https://www.github.com/googleapis/python-aiplatform/commit/c2cf61288326cad28ab474064b887687bc649d76))
+* AutoMlImageSegmentationInputs.ModelType adds MOBILE_TF_LOW_LATENCY constant ([fdc968f](https://www.github.com/googleapis/python-aiplatform/commit/fdc968f49e89a5c7ca14692080c0ae7e8b6e0865))
+* AutoMlVideoClassificationInputs.ModelType adds MOBILE_JETSON_VERSATILE_1 constant ([fdc968f](https://www.github.com/googleapis/python-aiplatform/commit/fdc968f49e89a5c7ca14692080c0ae7e8b6e0865))
+* Expose additional attributes into Vertex SDK to close gap with GAPIC ([#477](https://www.github.com/googleapis/python-aiplatform/issues/477)) ([572a27c](https://www.github.com/googleapis/python-aiplatform/commit/572a27c7929e5686b61950e09e17134564987d50))
+* ImageSegmentationPredictionResult.category_mask field changed to string data type ([fdc968f](https://www.github.com/googleapis/python-aiplatform/commit/fdc968f49e89a5c7ca14692080c0ae7e8b6e0865))
+* remove unsupported accelerator types ([fdc968f](https://www.github.com/googleapis/python-aiplatform/commit/fdc968f49e89a5c7ca14692080c0ae7e8b6e0865))
+* removes forecasting (time_series_forecasting proto) from public v1beta1 protos ([fdc968f](https://www.github.com/googleapis/python-aiplatform/commit/fdc968f49e89a5c7ca14692080c0ae7e8b6e0865))
+* removes unused protos from schema/ folders: schema/io_format.proto, schema/saved_query_metadata.proto ([fdc968f](https://www.github.com/googleapis/python-aiplatform/commit/fdc968f49e89a5c7ca14692080c0ae7e8b6e0865))
+* support additional_experiments for AutoML Tables and AutoML Forecasting ([#428](https://www.github.com/googleapis/python-aiplatform/issues/428)) ([b4211f2](https://www.github.com/googleapis/python-aiplatform/commit/b4211f2f60aead88107c08a18d30b0800b019593))
+* support self-signed JWT flow for service accounts ([fdc968f](https://www.github.com/googleapis/python-aiplatform/commit/fdc968f49e89a5c7ca14692080c0ae7e8b6e0865))
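+
+A minimal sketch of attaching a Tensorboard to a custom training job (illustrative; resource names, image, and service account are placeholders):
+
+```python
+from google.cloud import aiplatform
+
+job = aiplatform.CustomTrainingJob(
+    display_name="tb-enabled-training",
+    script_path="train.py",  # placeholder training script
+    container_uri="us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-8:latest",  # placeholder image
+)
+
+job.run(
+    tensorboard="projects/my-project/locations/us-central1/tensorboards/1234567890",
+    service_account="trainer@my-project.iam.gserviceaccount.com",  # required when tensorboard is set
+)
+```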
+
+
+### Bug Fixes
+
+* add async client to %name_%version/init.py ([fdc968f](https://www.github.com/googleapis/python-aiplatform/commit/fdc968f49e89a5c7ca14692080c0ae7e8b6e0865))
+* add target_column docstring ([#473](https://www.github.com/googleapis/python-aiplatform/issues/473)) ([c0543cd](https://www.github.com/googleapis/python-aiplatform/commit/c0543cdd1e9ba0efd18d7d1a442906938fc6db9a))
+* configuring timeouts for aiplatform v1 methods ([fdc968f](https://www.github.com/googleapis/python-aiplatform/commit/fdc968f49e89a5c7ca14692080c0ae7e8b6e0865))
+* Enable MetadataStore to use credentials when aiplatform.init is passed an experiment and credentials. ([#460](https://www.github.com/googleapis/python-aiplatform/issues/460)) ([e7bf0d8](https://www.github.com/googleapis/python-aiplatform/commit/e7bf0d83d8bb0849a9bce886c958d13f5cbe5fab))
+* exclude docs and tests from package ([#481](https://www.github.com/googleapis/python-aiplatform/issues/481)) ([b209904](https://www.github.com/googleapis/python-aiplatform/commit/b2099049484f66f4348ddd4448c676feecb0b46e))
+* pass credentials to BQ and GCS clients ([#469](https://www.github.com/googleapis/python-aiplatform/issues/469)) ([481d172](https://www.github.com/googleapis/python-aiplatform/commit/481d172542ffd80e18f4fab5b01945be17d5e18c))
+* remove display_name from FeatureStore ([fdc968f](https://www.github.com/googleapis/python-aiplatform/commit/fdc968f49e89a5c7ca14692080c0ae7e8b6e0865))
+* Remove URI attribute from Endpoint sample ([#478](https://www.github.com/googleapis/python-aiplatform/issues/478)) ([e3cbdd8](https://www.github.com/googleapis/python-aiplatform/commit/e3cbdd8322c854f526c8564f8bb61fb6525598d7))
+
+
+### Documentation
+
+* changes product name to Vertex AI ([fdc968f](https://www.github.com/googleapis/python-aiplatform/commit/fdc968f49e89a5c7ca14692080c0ae7e8b6e0865))
+* correct link to fieldmask ([fdc968f](https://www.github.com/googleapis/python-aiplatform/commit/fdc968f49e89a5c7ca14692080c0ae7e8b6e0865))
+* removes tinyurl links ([fdc968f](https://www.github.com/googleapis/python-aiplatform/commit/fdc968f49e89a5c7ca14692080c0ae7e8b6e0865))
+
+## [1.0.1](https://www.github.com/googleapis/python-aiplatform/compare/v1.0.0...v1.0.1) (2021-05-21)
### Bug Fixes
@@ -70,7 +676,7 @@
* env formatting ([#379](https://www.github.com/googleapis/python-aiplatform/issues/379)) ([6bc4c61](https://www.github.com/googleapis/python-aiplatform/commit/6bc4c612d5471911f82ee5ada9fb3a9307ee836f))
* remove Optional type hint on deploy ([#345](https://www.github.com/googleapis/python-aiplatform/issues/345)) ([79b0ab1](https://www.github.com/googleapis/python-aiplatform/commit/79b0ab13e6d08a12ac0a0971a8001e9ddb8baf56))
-### [0.7.1](https://www.github.com/googleapis/python-aiplatform/compare/v0.7.0...v0.7.1) (2021-04-14)
+## [0.7.1](https://www.github.com/googleapis/python-aiplatform/compare/v0.7.0...v0.7.1) (2021-04-14)
### Bug Fixes
@@ -108,7 +714,7 @@
* skip create data labeling job sample tests ([#254](https://www.github.com/googleapis/python-aiplatform/issues/254)) ([116a29b](https://www.github.com/googleapis/python-aiplatform/commit/116a29b1efcebb15bad14c3c36d3591c09ef10be))
-### [0.5.1](https://www.github.com/googleapis/python-aiplatform/compare/v0.5.0...v0.5.1) (2021-03-01)
+## [0.5.1](https://www.github.com/googleapis/python-aiplatform/compare/v0.5.0...v0.5.1) (2021-03-01)
### Bug Fixes
@@ -172,7 +778,7 @@
* update readme ([#81](https://www.github.com/googleapis/python-aiplatform/issues/81)) ([19dc31a](https://www.github.com/googleapis/python-aiplatform/commit/19dc31a7e63ec112e9d0dc72e22db04910137d07))
-### [0.3.1](https://www.github.com/googleapis/python-aiplatform/compare/v0.3.0...v0.3.1) (2020-11-13)
+## [0.3.1](https://www.github.com/googleapis/python-aiplatform/compare/v0.3.0...v0.3.1) (2020-11-13)
### Features
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index f865e3769d..bdf18d174f 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -22,7 +22,7 @@ In order to add a feature:
documentation.
- The feature must work fully on the following CPython versions:
- 3.6, 3.7, 3.8 and 3.9 on both UNIX and Windows.
+ 3.7, 3.8 and 3.9 on both UNIX and Windows.
- The feature must not add unnecessary dependencies (where
"unnecessary" is of course subjective, but new dependencies should
@@ -50,9 +50,9 @@ You'll have to create a development environment using a Git checkout:
# Configure remotes such that you can pull changes from the googleapis/python-aiplatform
# repository into your local repository.
$ git remote add upstream git@github.com:googleapis/python-aiplatform.git
- # fetch and merge changes from upstream into master
+ # fetch and merge changes from upstream into main
$ git fetch upstream
- $ git merge upstream/master
+ $ git merge upstream/main
Now your local repo is set up such that you will push changes to your GitHub
repo, from which you can submit a pull request.
@@ -68,15 +68,12 @@ Using ``nox``
We use `nox <https://nox.thea.codes/en/latest/>`__ to instrument our tests.
- To test your changes, run unit tests with ``nox``::
+ $ nox -s unit
- $ nox -s unit-2.7
- $ nox -s unit-3.8
- $ ...
+- To run a single unit test::
-- Args to pytest can be passed through the nox command separated by a `--`. For
- example, to run a single test::
+ $ nox -s unit-3.9 -- -k <name of test>
- $ nox -s unit-3.8 -- -k <name of test>
.. note::
@@ -113,12 +110,12 @@ Coding Style
variables::
export GOOGLE_CLOUD_TESTING_REMOTE="upstream"
- export GOOGLE_CLOUD_TESTING_BRANCH="master"
+ export GOOGLE_CLOUD_TESTING_BRANCH="main"
By doing this, you are specifying the location of the most up-to-date
- version of ``python-aiplatform``. The the suggested remote name ``upstream``
- should point to the official ``googleapis`` checkout and the
- the branch should be the main branch on that remote (``master``).
+ version of ``python-aiplatform``. The
+ remote name ``upstream`` should point to the official ``googleapis``
+ checkout and the branch should be the default branch on that remote (``main``).
- This repository contains configuration for the
`pre-commit `__ tool, which automates checking
@@ -143,8 +140,7 @@ Running System Tests
- To run system tests, you can execute::
# Run all system tests
- $ nox -s system-3.8
- $ nox -s system-2.7
+ $ nox -s system
# Run a single system test
$ nox -s system-3.8 -- -k <name of system test>
@@ -152,9 +148,8 @@ Running System Tests
.. note::
- System tests are only configured to run under Python 2.7 and
- Python 3.8. For expediency, we do not run them in older versions
- of Python 3.
+ System tests are only configured to run under Python 3.8.
+ For expediency, we do not run them in older versions of Python 3.
This alone will not run the tests. You'll need to change some local
auth settings and change some configuration in your project to
@@ -182,6 +177,30 @@ Build the docs via:
$ nox -s docs
+*************************
+Samples and code snippets
+*************************
+
+Code samples and snippets live in the `samples/` directory. Feel free to
+provide more examples, but make sure to write tests for those examples.
+Each folder containing example code requires its own `noxfile.py` script
+which automates testing. If you decide to create a new folder, you can
+base it on the `samples/snippets` folder (providing `noxfile.py` and
+the requirements files).
+
+The tests will run against a real Google Cloud Project, so you should
+configure them just like the System Tests.
+
+- To run sample tests, you can execute::
+
+ # Run all tests in a folder
+ $ cd samples/snippets
+ $ nox -s py-3.8
+
+ # Run a single sample test
+ $ cd samples/snippets
+   $ nox -s py-3.8 -- -k <name of test>
+
********************************************
Note About ``README`` as it pertains to PyPI
********************************************
@@ -190,7 +209,7 @@ The `description on PyPI`_ for the project comes directly from the
``README``. Due to the reStructuredText (``rst``) parser used by
PyPI, relative links which will work on GitHub (e.g. ``CONTRIBUTING.rst``
instead of
-``https://github.com/googleapis/python-aiplatform/blob/master/CONTRIBUTING.rst``)
+``https://github.com/googleapis/python-aiplatform/blob/main/CONTRIBUTING.rst``)
may cause problems creating links or rendering the description.
.. _description on PyPI: https://pypi.org/project/google-cloud-aiplatform
@@ -202,12 +221,10 @@ Supported Python Versions
We support:
-- `Python 3.6`_
- `Python 3.7`_
- `Python 3.8`_
- `Python 3.9`_
-.. _Python 3.6: https://docs.python.org/3.6/
.. _Python 3.7: https://docs.python.org/3.7/
.. _Python 3.8: https://docs.python.org/3.8/
.. _Python 3.9: https://docs.python.org/3.9/
@@ -215,11 +232,11 @@ We support:
Supported versions can be found in our ``noxfile.py`` `config`_.
-.. _config: https://github.com/googleapis/python-aiplatform/blob/master/noxfile.py
+.. _config: https://github.com/googleapis/python-aiplatform/blob/main/noxfile.py
-We also explicitly decided to support Python 3 beginning with version
-3.6. Reasons for this include:
+We also explicitly decided to support Python 3 beginning with version 3.7.
+Reasons for this include:
- Encouraging use of newest versions of Python 3
- Taking the lead of `prominent`_ open-source `projects`_
diff --git a/README.rst b/README.rst
index 57ead60fea..0b959ee04f 100644
--- a/README.rst
+++ b/README.rst
@@ -1,7 +1,7 @@
-Vertex SDK for Python
+Vertex AI SDK for Python
=================================================
-|GA| |pypi| |versions|
+|GA| |pypi| |versions| |unit-tests| |system-tests| |sample-tests|
`Vertex AI`_: Google Vertex AI is an integrated suite of machine learning tools and services for building and using ML models with AutoML or custom code. It offers both novices and experts the best workbench for the entire machine learning development lifecycle.
@@ -10,13 +10,19 @@ Vertex SDK for Python
- `Product Documentation`_
.. |GA| image:: https://img.shields.io/badge/support-ga-gold.svg
- :target: https://github.com/googleapis/google-cloud-python/blob/master/README.rst#general-availability
+ :target: https://github.com/googleapis/google-cloud-python/blob/main/README.rst#general-availability
.. |pypi| image:: https://img.shields.io/pypi/v/google-cloud-aiplatform.svg
:target: https://pypi.org/project/google-cloud-aiplatform/
.. |versions| image:: https://img.shields.io/pypi/pyversions/google-cloud-aiplatform.svg
:target: https://pypi.org/project/google-cloud-aiplatform/
+.. |unit-tests| image:: https://storage.googleapis.com/cloud-devrel-public/python-aiplatform/badges/sdk-unit-tests.svg
+ :target: https://storage.googleapis.com/cloud-devrel-public/python-aiplatform/badges/sdk-unit-tests.html
+.. |system-tests| image:: https://storage.googleapis.com/cloud-devrel-public/python-aiplatform/badges/sdk-system-tests.svg
+ :target: https://storage.googleapis.com/cloud-devrel-public/python-aiplatform/badges/sdk-system-tests.html
+.. |sample-tests| image:: https://storage.googleapis.com/cloud-devrel-public/python-aiplatform/badges/sdk-sample-tests.svg
+ :target: https://storage.googleapis.com/cloud-devrel-public/python-aiplatform/badges/sdk-sample-tests.html
.. _Vertex AI: https://cloud.google.com/vertex-ai/docs
-.. _Client Library Documentation: https://googleapis.dev/python/aiplatform/latest
+.. _Client Library Documentation: https://cloud.google.com/python/docs/reference/aiplatform/latest
.. _Product Documentation: https://cloud.google.com/vertex-ai/docs
Quick Start
@@ -70,11 +76,34 @@ Windows
\Scripts\pip.exe install google-cloud-aiplatform
+Supported Python Versions
+^^^^^^^^^^^^^^^^^^^^^^^^^
+Python >= 3.7
+
+Deprecated Python Versions
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+Python == 3.6.
+
+The last version of this library compatible with Python 3.6 is google-cloud-aiplatform==1.12.1.
+
Overview
~~~~~~~~
-This section provides a brief overview of the Vertex SDK for Python. You can also reference the notebooks in `vertex-ai-samples`_ for examples.
+This section provides a brief overview of the Vertex AI SDK for Python. You can also reference the notebooks in `vertex-ai-samples`_ for examples.
+
+.. _vertex-ai-samples: https://github.com/GoogleCloudPlatform/vertex-ai-samples/tree/main/notebooks/community/sdk
+
+All publicly available SDK features can be found in the :code:`google/cloud/aiplatform` directory.
+Under the hood, the Vertex AI SDK builds on top of GAPIC, which stands for Google API CodeGen.
+The GAPIC library code sits in :code:`google/cloud/aiplatform_v1` and :code:`google/cloud/aiplatform_v1beta1`,
+and it is auto-generated from Google's service proto files.
-.. _vertex-ai-samples: https://github.com/GoogleCloudPlatform/ai-platform-samples/tree/master/ai-platform-unified/notebooks/unofficial/sdk
+For most programmatic needs, these steps help you figure out which libraries to import:
+
+1. Look through :code:`google/cloud/aiplatform` first -- the Vertex AI SDK's APIs will almost always be easier to use and more concise compared with GAPIC
+2. If the feature you are looking for cannot be found there, look through :code:`aiplatform_v1` to see if it's available in GAPIC
+3. If the feature is still in beta, it will be available in :code:`aiplatform_v1beta1`
+
+If none of the above helps you find the right tools for your task, please feel free to open a GitHub issue and send us a feature request.
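+
+As a quick illustration, a minimal sketch of the three import surfaces (all
+three packages ship in this library):
+
+.. code-block:: Python
+
+    # high-level Vertex AI SDK
+    from google.cloud import aiplatform
+
+    # stable auto-generated GAPIC clients
+    from google.cloud import aiplatform_v1
+
+    # beta auto-generated GAPIC clients
+    from google.cloud import aiplatform_v1beta1
+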
Importing
^^^^^^^^^
@@ -100,12 +129,12 @@ Initialize the SDK to store common configurations that you use with the SDK.
# defaults to us-central1
location='us-central1',
- # Googlge Cloud Stoage bucket in same region as location
+ # Google Cloud Storage bucket in same region as location
# used to stage artifacts
staging_bucket='gs://my_staging_bucket',
# custom google.auth.credentials.Credentials
- # environment default creds used if not set
+ # environment default credentials used if not set
credentials=my_credentials,
# customer managed encryption key resource name
@@ -117,7 +146,7 @@ Initialize the SDK to store common configurations that you use with the SDK.
experiment='my-experiment',
# description of the experiment above
- experiment_description='my experiment decsription'
+ experiment_description='my experiment description'
)
Datasets
@@ -149,7 +178,7 @@ You can also create and import a dataset in separate steps:
To get a previously created Dataset:
.. code-block:: Python
-
+
dataset = aiplatform.ImageDataset('projects/my-project/location/us-central1/datasets/{DATASET_ID}')
Vertex AI supports a variety of dataset schemas. References to these schemas are available under the
@@ -160,7 +189,7 @@ Vertex AI supports a variety of dataset schemas. References to these schemas are
Training
^^^^^^^^
-The Vertex SDK for Python allows you train Custom and AutoML Models.
+The Vertex AI SDK for Python allows you to train Custom and AutoML Models.
You can train custom models using a custom Python script, custom Python package, or container.
@@ -173,7 +202,7 @@ It must read datasets from the environment variables populated by the training s
.. code-block:: Python
- os.environ['AIP_DATA_FORMAT'] # provides format of data
+ os.environ['AIP_DATA_FORMAT'] # provides format of data
os.environ['AIP_TRAINING_DATA_URI'] # uri to training split
os.environ['AIP_VALIDATION_DATA_URI'] # uri to validation split
os.environ['AIP_TEST_DATA_URI'] # uri to test split
@@ -182,9 +211,9 @@ Please visit `Using a managed dataset in a custom training application`_ for a d
.. _Using a managed dataset in a custom training application: https://cloud.google.com/vertex-ai/docs/training/using-managed-datasets
-It must write the model artifact to the environment variable populated by the traing service:
+It must write the model artifact to the environment variable populated by the training service:
-.. code-block:: Python
+.. code-block:: Python
os.environ['AIP_MODEL_DIR']
@@ -195,9 +224,9 @@ It must write the model artifact to the environment variable populated by the tr
job = aiplatform.CustomTrainingJob(
display_name="my-training-job",
script_path="training_script.py",
- container_uri="gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest",
+ container_uri="us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-2:latest",
requirements=["gcsfs==0.7.1"],
- model_serving_container_image_uri="gcr.io/cloud-aiplatform/prediction/tf2-cpu.2-2:latest",
+ model_serving_container_image_uri="us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-2:latest",
)
model = job.run(my_dataset,
@@ -211,7 +240,7 @@ In the code block above `my_dataset` is managed dataset created in the `Dataset`
AutoMLs
-------
-The Vertex SDK for Python supports AutoML tabular, image, text, video, and forecasting.
+The Vertex AI SDK for Python supports AutoML tabular, image, text, video, and forecasting.
To train an AutoML tabular model:
@@ -239,6 +268,26 @@ To train an AutoML tabular model:
Models
------
+To get a model:
+
+
+.. code-block:: Python
+
+ model = aiplatform.Model('/projects/my-project/locations/us-central1/models/{MODEL_ID}')
+
+
+
+To upload a model:
+
+.. code-block:: Python
+
+ model = aiplatform.Model.upload(
+ display_name='my-model',
+    artifact_uri="gs://path/to/my/model/dir",
+ serving_container_image_uri="us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-2:latest",
+ )
+
+
To deploy a model:
@@ -253,49 +302,100 @@ To deploy a model:
accelerator_count=1)
-To upload a model:
+Please visit `Importing models to Vertex AI`_ for a detailed overview:
+
+.. _Importing models to Vertex AI: https://cloud.google.com/vertex-ai/docs/general/import-model
+
+Model Evaluation
+----------------
+
+The Vertex AI SDK for Python currently supports getting model evaluation metrics for all AutoML models.
+
+To list all model evaluations for a model:
.. code-block:: Python
- model = aiplatform.Model.upload(
- display_name='my-model',
- artifact_uri="gs://python/to/my/model/dir",
- serving_container_image_uri="gcr.io/cloud-aiplatform/prediction/tf2-cpu.2-2:latest",
- )
+ model = aiplatform.Model('/projects/my-project/locations/us-central1/models/{MODEL_ID}')
-To get a model:
+ evaluations = model.list_model_evaluations()
+
+
+To get the model evaluation resource for a given model:
.. code-block:: Python
model = aiplatform.Model('/projects/my-project/locations/us-central1/models/{MODEL_ID}')
-Please visit `Importing models to Vertex AI`_ for a detailed overview:
+    # with no arguments, this returns the first evaluation; you can also pass the evaluation ID
+ evaluation = model.get_model_evaluation()
-.. _Importing models to Vertex AI: https://cloud.google.com/vertex-ai/docs/general/import-model
+ eval_metrics = evaluation.metrics
-Endpoints
----------
+You can also create a reference to your model evaluation directly by passing in the resource name of the model evaluation:
-To get predictions from endpoints:
+.. code-block:: Python
+
+ evaluation = aiplatform.ModelEvaluation(
+ evaluation_name='/projects/my-project/locations/us-central1/models/{MODEL_ID}/evaluations/{EVALUATION_ID}')
+
+Alternatively, you can create a reference to your evaluation by passing in the model and evaluation IDs:
.. code-block:: Python
- endpoint.predict(instances=[[6.7, 3.1, 4.7, 1.5], [4.6, 3.1, 1.5, 0.2]])
+ evaluation = aiplatform.ModelEvaluation(
+    evaluation_name='{EVALUATION_ID}',
+    model_id='{MODEL_ID}')
+
+
+Batch Prediction
+----------------
+
+To create a batch prediction job:
+
+.. code-block:: Python
+
+ model = aiplatform.Model('/projects/my-project/locations/us-central1/models/{MODEL_ID}')
+
+ batch_prediction_job = model.batch_predict(
+ job_display_name='my-batch-prediction-job',
+    instances_format='csv',
+    machine_type='n1-standard-4',
+    gcs_source=['gs://path/to/my/file.csv'],
+    gcs_destination_prefix='gs://path/to/my/batch_prediction/results/'
+ )
+
+You can also create a batch prediction job asynchronously by including the `sync=False` argument:
+
+.. code-block:: Python
+
+ batch_prediction_job = model.batch_predict(..., sync=False)
+
+ # wait for resource to be created
+ batch_prediction_job.wait_for_resource_creation()
+
+ # get the state
+ batch_prediction_job.state
+ # block until job is complete
+ batch_prediction_job.wait()
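+
+Once the job has finished, a hedged sketch of reading the results
+(``iter_outputs`` assumes the job wrote its output to GCS or BigQuery):
+
+.. code-block:: Python
+
+    # each output is a GCS blob (or BigQuery row) containing predictions
+    for output in batch_prediction_job.iter_outputs():
+        print(output)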
-To create an endpoint
+
+Endpoints
+---------
+
+To create an endpoint:
.. code-block:: Python
- endpoint = endpoint.create(display_name='my-endpoint')
+ endpoint = aiplatform.Endpoint.create(display_name='my-endpoint')
To deploy a model to a created endpoint:
.. code-block:: Python
model = aiplatform.Model('/projects/my-project/locations/us-central1/models/{MODEL_ID}')
-
+
endpoint.deploy(model,
min_replica_count=1,
max_replica_count=5
@@ -303,6 +403,12 @@ To deploy a model to a created endpoint:
accelerator_type='NVIDIA_TESLA_K80',
accelerator_count=1)
+To get predictions from endpoints:
+
+.. code-block:: Python
+
+ endpoint.predict(instances=[[6.7, 3.1, 4.7, 1.5], [4.6, 3.1, 1.5, 0.2]])
+
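+The call returns a ``Prediction`` object; a small usage sketch (assuming the
+deployed model accepts these instances):
+
+.. code-block:: Python
+
+    prediction = endpoint.predict(instances=[[6.7, 3.1, 4.7, 1.5]])
+
+    print(prediction.predictions)
+    print(prediction.deployed_model_id)
+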
To undeploy models from an endpoint:
.. code-block:: Python
@@ -312,10 +418,137 @@ To undeploy models from an endpoint:
To delete an endpoint:
.. code-block:: Python
-
+
endpoint.delete()
+Pipelines
+---------
+
+To create a Vertex AI Pipeline run and monitor until completion:
+
+.. code-block:: Python
+
+ # Instantiate PipelineJob object
+  pl = aiplatform.PipelineJob(
+ display_name="My first pipeline",
+
+ # Whether or not to enable caching
+ # True = always cache pipeline step result
+ # False = never cache pipeline step result
+ # None = defer to cache option for each pipeline component in the pipeline definition
+ enable_caching=False,
+
+ # Local or GCS path to a compiled pipeline definition
+ template_path="pipeline.json",
+
+ # Dictionary containing input parameters for your pipeline
+ parameter_values=parameter_values,
+
+ # GCS path to act as the pipeline root
+ pipeline_root=pipeline_root,
+ )
+
+ # Execute pipeline in Vertex AI and monitor until completion
+ pl.run(
+ # Email address of service account to use for the pipeline run
+ # You must have iam.serviceAccounts.actAs permission on the service account to use it
+ service_account=service_account,
+
+ # Whether this function call should be synchronous (wait for pipeline run to finish before terminating)
+ # or asynchronous (return immediately)
+ sync=True
+ )
+
+To create a Vertex AI Pipeline run without monitoring it until completion, use `submit` instead of `run`:
+
+.. code-block:: Python
+
+ # Instantiate PipelineJob object
+  pl = aiplatform.PipelineJob(
+ display_name="My first pipeline",
+
+ # Whether or not to enable caching
+ # True = always cache pipeline step result
+ # False = never cache pipeline step result
+ # None = defer to cache option for each pipeline component in the pipeline definition
+ enable_caching=False,
+
+ # Local or GCS path to a compiled pipeline definition
+ template_path="pipeline.json",
+
+ # Dictionary containing input parameters for your pipeline
+ parameter_values=parameter_values,
+
+ # GCS path to act as the pipeline root
+ pipeline_root=pipeline_root,
+ )
+
+ # Submit the Pipeline to Vertex AI
+ pl.submit(
+ # Email address of service account to use for the pipeline run
+ # You must have iam.serviceAccounts.actAs permission on the service account to use it
+ service_account=service_account,
+ )
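+
+If you later need to block until the submitted run completes, a one-line
+sketch (assuming the ``submit`` call above succeeded):
+
+.. code-block:: Python
+
+    pl.wait()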
+
+
+Explainable AI: Get Metadata
+----------------------------
+
+To get metadata in dictionary format from TensorFlow 1 models:
+
+.. code-block:: Python
+
+  import tensorflow as tf  # assumed: a TensorFlow 1.x runtime for tag_constants
+
+  from google.cloud.aiplatform.explain.metadata.tf.v1 import saved_model_metadata_builder
+
+  builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
+    'gs://path/to/my/model/dir', tags=[tf.saved_model.tag_constants.SERVING]
+ )
+ generated_md = builder.get_metadata()
+
+To get metadata in dictionary format from TensorFlow 2 models:
+
+.. code-block:: Python
+
+ from google.cloud.aiplatform.explain.metadata.tf.v2 import saved_model_metadata_builder
+
+  builder = saved_model_metadata_builder.SavedModelMetadataBuilder('gs://path/to/my/model/dir')
+ generated_md = builder.get_metadata()
+
+To use Explanation Metadata in endpoint deployment and model upload:
+
+.. code-block:: Python
+
+ explanation_metadata = builder.get_metadata_protobuf()
+
+ # To deploy a model to an endpoint with explanation
+ model.deploy(..., explanation_metadata=explanation_metadata)
+
+ # To deploy a model to a created endpoint with explanation
+ endpoint.deploy(..., explanation_metadata=explanation_metadata)
+
+ # To upload a model with explanation
+ aiplatform.Model.upload(..., explanation_metadata=explanation_metadata)
+
+
+Cloud Profiler
+----------------------------
+
+Cloud Profiler allows you to profile your remote Vertex AI training jobs on demand and visualize the results in Vertex AI TensorBoard.
+
+To start using the profiler with TensorFlow, update your training script to include the following:
+
+.. code-block:: Python
+
+ from google.cloud.aiplatform.training_utils import cloud_profiler
+ ...
+ cloud_profiler.init()
+
+Next, run the job with a Vertex AI TensorBoard instance. For full details on how to do this, visit https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-overview
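+
+A hedged sketch of launching such a job (the TensorBoard resource name and
+service account below are placeholders you must create beforehand):
+
+.. code-block:: Python
+
+    job = aiplatform.CustomTrainingJob(
+        display_name="my-profiled-job",
+        script_path="training_script.py",
+        container_uri="us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-2:latest",
+    )
+
+    job.run(
+        # resource name of an existing Vertex AI TensorBoard instance
+        tensorboard="projects/my-project/locations/us-central1/tensorboards/{TENSORBOARD_ID}",
+        # service account the training job runs as
+        service_account='my-sa@my-project.iam.gserviceaccount.com',
+    )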
+
+Finally, visit your TensorBoard in the Google Cloud console, navigate to the "Profile" tab, and click the `Capture Profile` button to capture profiling statistics for your running jobs.
+
+
Next Steps
~~~~~~~~~~
@@ -327,4 +560,4 @@ Next Steps
APIs that we cover.
.. _Vertex AI API Product documentation: https://cloud.google.com/vertex-ai/docs
-.. _README: https://github.com/googleapis/google-cloud-python/blob/master/README.rst
\ No newline at end of file
+.. _README: https://github.com/googleapis/google-cloud-python/blob/main/README.rst
diff --git a/docs/README.rst b/docs/README.rst
index 391fa89e8f..f1c894550c 100644
--- a/docs/README.rst
+++ b/docs/README.rst
@@ -1,27 +1,24 @@
-Python Client for Cloud AI Platform
+Vertex AI SDK for Python
=================================================
-|beta| |pypi| |versions|
+|GA| |pypi| |versions|
-`Cloud AI Platform`_: Cloud AI Platform is a suite of machine learning tools that enables
- developers to train high-quality models specific to their business needs.
- It offers both novices and experts the best workbench for machine learning
- development by leveraging Google's state-of-the-art transfer learning and
- Neural Architecture Search technology.
+`Vertex AI`_: Google Vertex AI is an integrated suite of machine learning tools and services for building and using ML models with AutoML or custom code. It offers both novices and experts the best workbench for the entire machine learning development lifecycle.
- `Client Library Documentation`_
- `Product Documentation`_
-.. |beta| image:: https://img.shields.io/badge/support-beta-orange.svg
- :target: https://github.com/googleapis/google-cloud-python/blob/master/README.rst#beta-support
+.. |GA| image:: https://img.shields.io/badge/support-ga-gold.svg
+ :target: https://github.com/googleapis/google-cloud-python/blob/main/README.rst#general-availability
.. |pypi| image:: https://img.shields.io/pypi/v/google-cloud-aiplatform.svg
:target: https://pypi.org/project/google-cloud-aiplatform/
.. |versions| image:: https://img.shields.io/pypi/pyversions/google-cloud-aiplatform.svg
:target: https://pypi.org/project/google-cloud-aiplatform/
-.. _Cloud AI Platform: https://cloud.google.com/ai-platform/docs
+.. _Vertex AI: https://cloud.google.com/vertex-ai/docs
.. _Client Library Documentation: https://googleapis.dev/python/aiplatform/latest
.. _Product Documentation: https://cloud.google.com/ai-platform/docs
+
Quick Start
-----------
@@ -72,15 +69,403 @@ Windows
\Scripts\activate
\Scripts\pip.exe install google-cloud-aiplatform
+Overview
+~~~~~~~~
+This section provides a brief overview of the Vertex AI SDK for Python. You can also reference the notebooks in `vertex-ai-samples`_ for examples.
+
+.. _vertex-ai-samples: https://github.com/GoogleCloudPlatform/vertex-ai-samples/tree/main/notebooks/community/sdk
+
+Importing
+^^^^^^^^^
+SDK functionality can be used from the root of the package:
+
+.. code-block:: Python
+
+ from google.cloud import aiplatform
+
+
+Initialization
+^^^^^^^^^^^^^^
+Initialize the SDK to store common configurations that you use with the SDK.
+
+.. code-block:: Python
+
+ aiplatform.init(
+ # your Google Cloud Project ID or number
+        # environment default used if not set
+ project='my-project',
+
+ # the Vertex AI region you will use
+ # defaults to us-central1
+ location='us-central1',
+
+ # Google Cloud Storage bucket in same region as location
+ # used to stage artifacts
+ staging_bucket='gs://my_staging_bucket',
+
+ # custom google.auth.credentials.Credentials
+        # environment default credentials used if not set
+ credentials=my_credentials,
+
+ # customer managed encryption key resource name
+ # will be applied to all Vertex AI resources if set
+ encryption_spec_key_name=my_encryption_key_name,
+
+ # the name of the experiment to use to track
+ # logged metrics and parameters
+ experiment='my-experiment',
+
+ # description of the experiment above
+        experiment_description='my experiment description'
+ )
+
+Datasets
+^^^^^^^^
+Vertex AI provides managed tabular, text, image, and video datasets. In the SDK, datasets can be used downstream to
+train models.
+
+To create a tabular dataset:
+
+.. code-block:: Python
+
+ my_dataset = aiplatform.TabularDataset.create(
+ display_name="my-dataset", gcs_source=['gs://path/to/my/dataset.csv'])
+
+You can also create and import a dataset in separate steps:
+
+.. code-block:: Python
+
+ from google.cloud import aiplatform
+
+ my_dataset = aiplatform.TextDataset.create(
+ display_name="my-dataset")
+
+    my_dataset.import_data(
+        gcs_source=['gs://path/to/my/dataset.csv'],
+        import_schema_uri=aiplatform.schema.dataset.ioformat.text.multi_label_classification
+ )
+
+To get a previously created Dataset:
+
+.. code-block:: Python
+
+ dataset = aiplatform.ImageDataset('projects/my-project/location/us-central1/datasets/{DATASET_ID}')
+
+Vertex AI supports a variety of dataset schemas. References to these schemas are available under the
+:code:`aiplatform.schema.dataset` namespace. For more information on the supported dataset schemas please refer to the
+`Preparing data docs`_.
+
+.. _Preparing data docs: https://cloud.google.com/ai-platform-unified/docs/datasets/prepare
+
+Training
+^^^^^^^^
+The Vertex AI SDK for Python allows you to train Custom and AutoML Models.
+
+You can train custom models using a custom Python script, custom Python package, or container.
+
+**Preparing Your Custom Code**
+
+Vertex AI custom training enables you to train on Vertex AI datasets and produce Vertex AI models. To do so your
+script must adhere to the following contract:
+
+It must read datasets from the environment variables populated by the training service:
+
+.. code-block:: Python
+
+ os.environ['AIP_DATA_FORMAT'] # provides format of data
+ os.environ['AIP_TRAINING_DATA_URI'] # uri to training split
+ os.environ['AIP_VALIDATION_DATA_URI'] # uri to validation split
+ os.environ['AIP_TEST_DATA_URI'] # uri to test split
+
+Please visit `Using a managed dataset in a custom training application`_ for a detailed overview.
+
+.. _Using a managed dataset in a custom training application: https://cloud.google.com/vertex-ai/docs/training/using-managed-datasets
+
+It must write the model artifact to the environment variable populated by the training service:
+
+.. code-block:: Python
+
+ os.environ['AIP_MODEL_DIR']
+
+**Running Training**
+
+.. code-block:: Python
+
+ job = aiplatform.CustomTrainingJob(
+ display_name="my-training-job",
+ script_path="training_script.py",
+        container_uri="us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-2:latest",
+        requirements=["gcsfs==0.7.1"],
+        model_serving_container_image_uri="us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-2:latest",
+ )
+
+ model = job.run(my_dataset,
+ replica_count=1,
+ machine_type="n1-standard-4",
+ accelerator_type='NVIDIA_TESLA_K80',
+ accelerator_count=1)
+
+In the code block above `my_dataset` is a managed dataset created in the `Dataset` section above. The `model` variable is a managed Vertex AI model that can be deployed or exported.
+
+
+AutoMLs
+-------
+The Vertex AI SDK for Python supports AutoML tabular, image, text, video, and forecasting.
+
+To train an AutoML tabular model:
+
+.. code-block:: Python
+
+ dataset = aiplatform.TabularDataset('projects/my-project/location/us-central1/datasets/{DATASET_ID}')
+
+ job = aiplatform.AutoMLTabularTrainingJob(
+ display_name="train-automl",
+ optimization_prediction_type="regression",
+ optimization_objective="minimize-rmse",
+ )
+
+ model = job.run(
+ dataset=dataset,
+ target_column="target_column_name",
+ training_fraction_split=0.6,
+ validation_fraction_split=0.2,
+ test_fraction_split=0.2,
+ budget_milli_node_hours=1000,
+ model_display_name="my-automl-model",
+ disable_early_stopping=False,
+ )
+
+
+Models
+------
+
+To deploy a model:
+
+
+.. code-block:: Python
+
+  endpoint = model.deploy(machine_type="n1-standard-4",
+                          min_replica_count=1,
+                          max_replica_count=5,
+                          accelerator_type='NVIDIA_TESLA_K80',
+                          accelerator_count=1)
+
+
+To upload a model:
+
+.. code-block:: Python
+
+ model = aiplatform.Model.upload(
+ display_name='my-model',
+        artifact_uri="gs://path/to/my/model/dir",
+        serving_container_image_uri="us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-2:latest",
+ )
+
+To get a model:
+
+.. code-block:: Python
+
+ model = aiplatform.Model('/projects/my-project/locations/us-central1/models/{MODEL_ID}')
+
+Please visit `Importing models to Vertex AI`_ for a detailed overview:
+
+.. _Importing models to Vertex AI: https://cloud.google.com/vertex-ai/docs/general/import-model
+
+
+Batch Prediction
+----------------
+
+To create a batch prediction job:
+
+.. code-block:: Python
+
+ model = aiplatform.Model('/projects/my-project/locations/us-central1/models/{MODEL_ID}')
+
+ batch_prediction_job = model.batch_predict(
+ job_display_name='my-batch-prediction-job',
+    instances_format='csv',
+    machine_type='n1-standard-4',
+    gcs_source=['gs://path/to/my/file.csv'],
+    gcs_destination_prefix='gs://path/to/my/batch_prediction/results/'
+ )
+
+You can also create a batch prediction job asynchronously by including the `sync=False` argument:
+
+.. code-block:: Python
+
+ batch_prediction_job = model.batch_predict(..., sync=False)
+
+ # wait for resource to be created
+ batch_prediction_job.wait_for_resource_creation()
+
+ # get the state
+ batch_prediction_job.state
+
+ # block until job is complete
+ batch_prediction_job.wait()
+
+
+Endpoints
+---------
+
+To get predictions from endpoints:
+
+.. code-block:: Python
+
+ endpoint.predict(instances=[[6.7, 3.1, 4.7, 1.5], [4.6, 3.1, 1.5, 0.2]])
+
+
+To create an endpoint:
+
+.. code-block:: Python
+
+    endpoint = aiplatform.Endpoint.create(display_name='my-endpoint')
+
+To deploy a model to a created endpoint:
+
+.. code-block:: Python
+
+ model = aiplatform.Model('/projects/my-project/locations/us-central1/models/{MODEL_ID}')
+
+ endpoint.deploy(model,
+ min_replica_count=1,
+                    max_replica_count=5,
+ machine_type='n1-standard-4',
+ accelerator_type='NVIDIA_TESLA_K80',
+ accelerator_count=1)
+
+To undeploy models from an endpoint:
+
+.. code-block:: Python
+
+ endpoint.undeploy_all()
+
+To delete an endpoint:
+
+.. code-block:: Python
+
+ endpoint.delete()
+
+
+Pipelines
+---------
+
+To create a Vertex AI Pipeline run and monitor until completion:
+
+.. code-block:: Python
+
+ # Instantiate PipelineJob object
+  pl = aiplatform.PipelineJob(
+ display_name="My first pipeline",
+
+ # Whether or not to enable caching
+ # True = always cache pipeline step result
+ # False = never cache pipeline step result
+ # None = defer to cache option for each pipeline component in the pipeline definition
+ enable_caching=False,
+
+ # Local or GCS path to a compiled pipeline definition
+ template_path="pipeline.json",
+
+ # Dictionary containing input parameters for your pipeline
+ parameter_values=parameter_values,
+
+ # GCS path to act as the pipeline root
+ pipeline_root=pipeline_root,
+ )
+
+ # Execute pipeline in Vertex AI and monitor until completion
+ pl.run(
+ # Email address of service account to use for the pipeline run
+ # You must have iam.serviceAccounts.actAs permission on the service account to use it
+ service_account=service_account,
+
+ # Whether this function call should be synchronous (wait for pipeline run to finish before terminating)
+ # or asynchronous (return immediately)
+ sync=True
+ )
+
+To create a Vertex AI Pipeline run without monitoring it until completion, use `submit` instead of `run`:
+
+.. code-block:: Python
+
+ # Instantiate PipelineJob object
+  pl = aiplatform.PipelineJob(
+ display_name="My first pipeline",
+
+ # Whether or not to enable caching
+ # True = always cache pipeline step result
+ # False = never cache pipeline step result
+ # None = defer to cache option for each pipeline component in the pipeline definition
+ enable_caching=False,
+
+ # Local or GCS path to a compiled pipeline definition
+ template_path="pipeline.json",
+
+ # Dictionary containing input parameters for your pipeline
+ parameter_values=parameter_values,
+
+ # GCS path to act as the pipeline root
+ pipeline_root=pipeline_root,
+ )
+
+ # Submit the Pipeline to Vertex AI
+ pl.submit(
+ # Email address of service account to use for the pipeline run
+ # You must have iam.serviceAccounts.actAs permission on the service account to use it
+ service_account=service_account,
+ )
+
+
+Explainable AI: Get Metadata
+----------------------------
+
+To get metadata in dictionary format from TensorFlow 1 models:
+
+.. code-block:: Python
+
+  import tensorflow as tf  # assumed: a TensorFlow 1.x runtime for tag_constants
+
+  from google.cloud.aiplatform.explain.metadata.tf.v1 import saved_model_metadata_builder
+
+  builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
+    'gs://path/to/my/model/dir', tags=[tf.saved_model.tag_constants.SERVING]
+ )
+ generated_md = builder.get_metadata()
+
+To get metadata in dictionary format from TensorFlow 2 models:
+
+.. code-block:: Python
+
+ from google.cloud.aiplatform.explain.metadata.tf.v2 import saved_model_metadata_builder
+
+  builder = saved_model_metadata_builder.SavedModelMetadataBuilder('gs://path/to/my/model/dir')
+ generated_md = builder.get_metadata()
+
+To use Explanation Metadata in endpoint deployment and model upload:
+
+.. code-block:: Python
+
+ explanation_metadata = builder.get_metadata_protobuf()
+
+ # To deploy a model to an endpoint with explanation
+ model.deploy(..., explanation_metadata=explanation_metadata)
+
+ # To deploy a model to a created endpoint with explanation
+ endpoint.deploy(..., explanation_metadata=explanation_metadata)
+
+ # To upload a model with explanation
+ aiplatform.Model.upload(..., explanation_metadata=explanation_metadata)
+
+
Next Steps
~~~~~~~~~~
-- Read the `Client Library Documentation`_ for Cloud AI Platform
+- Read the `Client Library Documentation`_ for Vertex AI
API to see other available methods on the client.
-- Read the `Cloud AI Platform API Product documentation`_ to learn
+- Read the `Vertex AI API Product documentation`_ to learn
more about the product and see How-to Guides.
- View this `README`_ to see the full list of Cloud
APIs that we cover.
-.. _Cloud AI Platform API Product documentation: https://cloud.google.com/ai-platform/docs
-.. _README: https://github.com/googleapis/google-cloud-python/blob/master/README.rst
\ No newline at end of file
+.. _Vertex AI API Product documentation: https://cloud.google.com/vertex-ai/docs
+.. _README: https://github.com/googleapis/google-cloud-python/blob/main/README.rst
diff --git a/docs/aiplatform/definition_v1/types.rst b/docs/aiplatform/definition_v1/types.rst
new file mode 100644
index 0000000000..a1df2bce25
--- /dev/null
+++ b/docs/aiplatform/definition_v1/types.rst
@@ -0,0 +1,7 @@
+Types for Google Cloud Aiplatform V1 Schema Trainingjob Definition v1 API
+=========================================================================
+
+.. automodule:: google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/aiplatform/definition_v1beta1/types.rst b/docs/aiplatform/definition_v1beta1/types.rst
new file mode 100644
index 0000000000..f4fe7a5301
--- /dev/null
+++ b/docs/aiplatform/definition_v1beta1/types.rst
@@ -0,0 +1,7 @@
+Types for Google Cloud Aiplatform V1beta1 Schema Trainingjob Definition v1beta1 API
+===================================================================================
+
+.. automodule:: google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/aiplatform/instance_v1/types.rst b/docs/aiplatform/instance_v1/types.rst
new file mode 100644
index 0000000000..564ab013ee
--- /dev/null
+++ b/docs/aiplatform/instance_v1/types.rst
@@ -0,0 +1,7 @@
+Types for Google Cloud Aiplatform V1 Schema Predict Instance v1 API
+===================================================================
+
+.. automodule:: google.cloud.aiplatform.v1.schema.predict.instance_v1.types
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/aiplatform/instance_v1beta1/types.rst b/docs/aiplatform/instance_v1beta1/types.rst
new file mode 100644
index 0000000000..7caa088065
--- /dev/null
+++ b/docs/aiplatform/instance_v1beta1/types.rst
@@ -0,0 +1,7 @@
+Types for Google Cloud Aiplatform V1beta1 Schema Predict Instance v1beta1 API
+=============================================================================
+
+.. automodule:: google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/aiplatform/params_v1/types.rst b/docs/aiplatform/params_v1/types.rst
new file mode 100644
index 0000000000..956ef5224d
--- /dev/null
+++ b/docs/aiplatform/params_v1/types.rst
@@ -0,0 +1,7 @@
+Types for Google Cloud Aiplatform V1 Schema Predict Params v1 API
+=================================================================
+
+.. automodule:: google.cloud.aiplatform.v1.schema.predict.params_v1.types
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/aiplatform/params_v1beta1/types.rst b/docs/aiplatform/params_v1beta1/types.rst
new file mode 100644
index 0000000000..722a1d8ba0
--- /dev/null
+++ b/docs/aiplatform/params_v1beta1/types.rst
@@ -0,0 +1,7 @@
+Types for Google Cloud Aiplatform V1beta1 Schema Predict Params v1beta1 API
+===========================================================================
+
+.. automodule:: google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/aiplatform/prediction_v1/types.rst b/docs/aiplatform/prediction_v1/types.rst
new file mode 100644
index 0000000000..a97faf34de
--- /dev/null
+++ b/docs/aiplatform/prediction_v1/types.rst
@@ -0,0 +1,7 @@
+Types for Google Cloud Aiplatform V1 Schema Predict Prediction v1 API
+=====================================================================
+
+.. automodule:: google.cloud.aiplatform.v1.schema.predict.prediction_v1.types
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/aiplatform/prediction_v1beta1/types.rst b/docs/aiplatform/prediction_v1beta1/types.rst
new file mode 100644
index 0000000000..b14182d6d7
--- /dev/null
+++ b/docs/aiplatform/prediction_v1beta1/types.rst
@@ -0,0 +1,7 @@
+Types for Google Cloud Aiplatform V1beta1 Schema Predict Prediction v1beta1 API
+===============================================================================
+
+.. automodule:: google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/aiplatform.rst b/docs/aiplatform/services.rst
similarity index 84%
rename from docs/aiplatform.rst
rename to docs/aiplatform/services.rst
index bf5cd4625b..0d21fe6bd1 100644
--- a/docs/aiplatform.rst
+++ b/docs/aiplatform/services.rst
@@ -3,4 +3,4 @@ Google Cloud Aiplatform SDK
.. automodule:: google.cloud.aiplatform
:members:
- :show-inheritance:
\ No newline at end of file
+ :show-inheritance:
diff --git a/docs/aiplatform/types.rst b/docs/aiplatform/types.rst
new file mode 100644
index 0000000000..119f762bca
--- /dev/null
+++ b/docs/aiplatform/types.rst
@@ -0,0 +1,13 @@
+Types for Google Cloud Aiplatform SDK API
+===========================================
+.. toctree::
+ :maxdepth: 2
+
+ instance_v1
+ instance_v1beta1
+ params_v1
+ params_v1beta1
+ prediction_v1
+ prediction_v1beta1
+ definition_v1
+ definition_v1beta1
diff --git a/docs/aiplatform_v1/featurestore_online_serving_service.rst b/docs/aiplatform_v1/featurestore_online_serving_service.rst
new file mode 100644
index 0000000000..ace5b9dd1a
--- /dev/null
+++ b/docs/aiplatform_v1/featurestore_online_serving_service.rst
@@ -0,0 +1,6 @@
+FeaturestoreOnlineServingService
+--------------------------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1.services.featurestore_online_serving_service
+ :members:
+ :inherited-members:
diff --git a/docs/aiplatform_v1/featurestore_service.rst b/docs/aiplatform_v1/featurestore_service.rst
new file mode 100644
index 0000000000..90a303a4c4
--- /dev/null
+++ b/docs/aiplatform_v1/featurestore_service.rst
@@ -0,0 +1,10 @@
+FeaturestoreService
+-------------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1.services.featurestore_service
+ :members:
+ :inherited-members:
+
+.. automodule:: google.cloud.aiplatform_v1.services.featurestore_service.pagers
+ :members:
+ :inherited-members:
diff --git a/docs/aiplatform_v1/index_endpoint_service.rst b/docs/aiplatform_v1/index_endpoint_service.rst
new file mode 100644
index 0000000000..9a87b81082
--- /dev/null
+++ b/docs/aiplatform_v1/index_endpoint_service.rst
@@ -0,0 +1,10 @@
+IndexEndpointService
+--------------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1.services.index_endpoint_service
+ :members:
+ :inherited-members:
+
+.. automodule:: google.cloud.aiplatform_v1.services.index_endpoint_service.pagers
+ :members:
+ :inherited-members:
diff --git a/docs/aiplatform_v1/index_service.rst b/docs/aiplatform_v1/index_service.rst
new file mode 100644
index 0000000000..b07b444c23
--- /dev/null
+++ b/docs/aiplatform_v1/index_service.rst
@@ -0,0 +1,10 @@
+IndexService
+------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1.services.index_service
+ :members:
+ :inherited-members:
+
+.. automodule:: google.cloud.aiplatform_v1.services.index_service.pagers
+ :members:
+ :inherited-members:
diff --git a/docs/aiplatform_v1/metadata_service.rst b/docs/aiplatform_v1/metadata_service.rst
new file mode 100644
index 0000000000..419fd0a850
--- /dev/null
+++ b/docs/aiplatform_v1/metadata_service.rst
@@ -0,0 +1,10 @@
+MetadataService
+---------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1.services.metadata_service
+ :members:
+ :inherited-members:
+
+.. automodule:: google.cloud.aiplatform_v1.services.metadata_service.pagers
+ :members:
+ :inherited-members:
diff --git a/docs/aiplatform_v1/services.rst b/docs/aiplatform_v1/services.rst
index fd5a8c9aa7..0a6443a972 100644
--- a/docs/aiplatform_v1/services.rst
+++ b/docs/aiplatform_v1/services.rst
@@ -5,9 +5,16 @@ Services for Google Cloud Aiplatform v1 API
dataset_service
endpoint_service
+ featurestore_online_serving_service
+ featurestore_service
+ index_endpoint_service
+ index_service
job_service
+ metadata_service
migration_service
model_service
pipeline_service
prediction_service
specialist_pool_service
+ tensorboard_service
+ vizier_service
diff --git a/docs/aiplatform_v1/tensorboard_service.rst b/docs/aiplatform_v1/tensorboard_service.rst
new file mode 100644
index 0000000000..0fa17e10b8
--- /dev/null
+++ b/docs/aiplatform_v1/tensorboard_service.rst
@@ -0,0 +1,10 @@
+TensorboardService
+------------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1.services.tensorboard_service
+ :members:
+ :inherited-members:
+
+.. automodule:: google.cloud.aiplatform_v1.services.tensorboard_service.pagers
+ :members:
+ :inherited-members:
diff --git a/docs/aiplatform_v1/vizier_service.rst b/docs/aiplatform_v1/vizier_service.rst
new file mode 100644
index 0000000000..efdbafe3c8
--- /dev/null
+++ b/docs/aiplatform_v1/vizier_service.rst
@@ -0,0 +1,10 @@
+VizierService
+-------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1.services.vizier_service
+ :members:
+ :inherited-members:
+
+.. automodule:: google.cloud.aiplatform_v1.services.vizier_service.pagers
+ :members:
+ :inherited-members:
diff --git a/docs/conf.py b/docs/conf.py
index cd484b1e23..9be57dd7f7 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -76,13 +76,13 @@
# The encoding of source files.
# source_encoding = 'utf-8-sig'
-# The master toctree document.
-master_doc = "index"
+# The root toctree document.
+root_doc = "index"
# General information about the project.
-project = u"google-cloud-aiplatform"
-copyright = u"2019, Google"
-author = u"Google APIs"
+project = "google-cloud-aiplatform"
+copyright = "2019, Google"
+author = "Google APIs"
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
@@ -110,6 +110,7 @@
# directories to ignore when looking for source files.
exclude_patterns = [
"_build",
+ "**/.nox/**/*",
"samples/AUTHORING_GUIDE.md",
"samples/CONTRIBUTING.md",
"samples/snippets/README.rst",
@@ -279,9 +280,9 @@
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(
- master_doc,
+ root_doc,
"google-cloud-aiplatform.tex",
- u"google-cloud-aiplatform Documentation",
+ "google-cloud-aiplatform Documentation",
author,
"manual",
)
@@ -314,9 +315,9 @@
# (source start file, name, description, authors, manual section).
man_pages = [
(
- master_doc,
+ root_doc,
"google-cloud-aiplatform",
- u"google-cloud-aiplatform Documentation",
+ "google-cloud-aiplatform Documentation",
[author],
1,
)
@@ -333,9 +334,9 @@
# dir menu entry, description, category)
texinfo_documents = [
(
- master_doc,
+ root_doc,
"google-cloud-aiplatform",
- u"google-cloud-aiplatform Documentation",
+ "google-cloud-aiplatform Documentation",
author,
"google-cloud-aiplatform",
"google-cloud-aiplatform Library",
@@ -360,9 +361,13 @@
intersphinx_mapping = {
"python": ("https://python.readthedocs.org/en/latest/", None),
"google-auth": ("https://googleapis.dev/python/google-auth/latest/", None),
- "google.api_core": ("https://googleapis.dev/python/google-api-core/latest/", None,),
+ "google.api_core": (
+ "https://googleapis.dev/python/google-api-core/latest/",
+ None,
+ ),
"grpc": ("https://grpc.github.io/grpc/python/", None),
"proto-plus": ("https://proto-plus-python.readthedocs.io/en/latest/", None),
+ "protobuf": ("https://googleapis.dev/python/protobuf/latest/", None),
}
diff --git a/docs/index.rst b/docs/index.rst
index 031271a261..6094720bd8 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -7,7 +7,9 @@ API Reference
.. toctree::
:maxdepth: 2
- aiplatform
+ aiplatform/services
+ aiplatform/types
+
aiplatform_v1/services
aiplatform_v1/types
@@ -22,4 +24,4 @@ For a list of all ``google-cloud-aiplatform`` releases:
.. toctree::
:maxdepth: 2
- changelog
\ No newline at end of file
+ changelog
diff --git a/docs/multiprocessing.rst b/docs/multiprocessing.rst
index 1cb29d4ca9..536d17b2ea 100644
--- a/docs/multiprocessing.rst
+++ b/docs/multiprocessing.rst
@@ -1,7 +1,7 @@
.. note::
- Because this client uses :mod:`grpcio` library, it is safe to
+  Because this client uses the :mod:`grpc` library, it is safe to
share instances across threads. In multiprocessing scenarios, the best
practice is to create client instances *after* the invocation of
- :func:`os.fork` by :class:`multiprocessing.Pool` or
+ :func:`os.fork` by :class:`multiprocessing.pool.Pool` or
:class:`multiprocessing.Process`.
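+
+  A minimal sketch of that pattern (the client class and pool size are
+  illustrative):
+
+  .. code-block:: Python
+
+      import multiprocessing
+
+      from google.cloud import aiplatform_v1
+
+      client = None
+
+      def init_worker():
+          # create the gRPC client *after* os.fork, once per worker process
+          global client
+          client = aiplatform_v1.ModelServiceClient()
+
+      def get_model(name):
+          return client.get_model(name=name)
+
+      if __name__ == "__main__":
+          names = ["projects/my-project/locations/us-central1/models/123"]
+          with multiprocessing.Pool(2, initializer=init_worker) as pool:
+              models = pool.map(get_model, names)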
diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py
index 6aa8f64161..31f459d3f7 100644
--- a/google/cloud/aiplatform/__init__.py
+++ b/google/cloud/aiplatform/__init__.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,10 +15,14 @@
# limitations under the License.
#
-from google.cloud.aiplatform import gapic
-from google.cloud.aiplatform import explain
+
+from google.cloud.aiplatform import version as aiplatform_version
+
+__version__ = aiplatform_version.__version__
+
from google.cloud.aiplatform import initializer
+
from google.cloud.aiplatform.datasets import (
ImageDataset,
TabularDataset,
@@ -26,25 +30,46 @@
TimeSeriesDataset,
VideoDataset,
)
+from google.cloud.aiplatform import explain
+from google.cloud.aiplatform import gapic
from google.cloud.aiplatform import hyperparameter_tuning
-from google.cloud.aiplatform.metadata import metadata
+from google.cloud.aiplatform.featurestore import (
+ EntityType,
+ Feature,
+ Featurestore,
+)
+from google.cloud.aiplatform.matching_engine import (
+ MatchingEngineIndex,
+ MatchingEngineIndexEndpoint,
+)
+from google.cloud.aiplatform import metadata
from google.cloud.aiplatform.models import Endpoint
from google.cloud.aiplatform.models import Model
+from google.cloud.aiplatform.model_evaluation import ModelEvaluation
from google.cloud.aiplatform.jobs import (
BatchPredictionJob,
CustomJob,
HyperparameterTuningJob,
)
+from google.cloud.aiplatform.pipeline_jobs import PipelineJob
+from google.cloud.aiplatform.tensorboard import (
+ Tensorboard,
+ TensorboardExperiment,
+ TensorboardRun,
+ TensorboardTimeSeries,
+)
from google.cloud.aiplatform.training_jobs import (
CustomTrainingJob,
CustomContainerTrainingJob,
CustomPythonPackageTrainingJob,
AutoMLTabularTrainingJob,
AutoMLForecastingTrainingJob,
+ SequenceToSequencePlusForecastingTrainingJob,
AutoMLImageTrainingJob,
AutoMLTextTrainingJob,
AutoMLVideoTrainingJob,
)
+from google.cloud.aiplatform import helpers
"""
Usage:
@@ -54,23 +79,39 @@
"""
init = initializer.global_config.init
-log_params = metadata.metadata_service.log_params
-log_metrics = metadata.metadata_service.log_metrics
-get_experiment_df = metadata.metadata_service.get_experiment_df
-get_pipeline_df = metadata.metadata_service.get_pipeline_df
-start_run = metadata.metadata_service.start_run
+get_pipeline_df = metadata.metadata._LegacyExperimentService.get_pipeline_df
+
+log_params = metadata.metadata._experiment_tracker.log_params
+log_metrics = metadata.metadata._experiment_tracker.log_metrics
+get_experiment_df = metadata.metadata._experiment_tracker.get_experiment_df
+start_run = metadata.metadata._experiment_tracker.start_run
+start_execution = metadata.metadata._experiment_tracker.start_execution
+log = metadata.metadata._experiment_tracker.log
+log_time_series_metrics = metadata.metadata._experiment_tracker.log_time_series_metrics
+end_run = metadata.metadata._experiment_tracker.end_run
+
+Experiment = metadata.experiment_resources.Experiment
+ExperimentRun = metadata.experiment_run_resource.ExperimentRun
+Artifact = metadata.artifact.Artifact
+Execution = metadata.execution.Execution
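+
+# A hedged usage sketch of the experiment-tracking aliases above
+# (the experiment/run names and values are placeholders):
+#
+#   aiplatform.init(experiment='my-experiment')
+#   aiplatform.start_run('run-1')
+#   aiplatform.log_params({'learning_rate': 0.01})
+#   aiplatform.log_metrics({'accuracy': 0.95})
+#   aiplatform.end_run()
+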
__all__ = (
+ "end_run",
"explain",
"gapic",
"init",
+ "helpers",
"hyperparameter_tuning",
+ "log",
"log_params",
"log_metrics",
+ "log_time_series_metrics",
"get_experiment_df",
"get_pipeline_df",
"start_run",
+ "start_execution",
+ "Artifact",
"AutoMLImageTrainingJob",
"AutoMLTabularTrainingJob",
"AutoMLForecastingTrainingJob",
@@ -82,10 +123,25 @@
"CustomContainerTrainingJob",
"CustomPythonPackageTrainingJob",
"Endpoint",
+ "EntityType",
+ "Execution",
+ "Experiment",
+ "ExperimentRun",
+ "Feature",
+ "Featurestore",
+ "MatchingEngineIndex",
+ "MatchingEngineIndexEndpoint",
"ImageDataset",
"HyperparameterTuningJob",
"Model",
+ "ModelEvaluation",
+ "PipelineJob",
+ "SequenceToSequencePlusForecastingTrainingJob",
"TabularDataset",
+ "Tensorboard",
+ "TensorboardExperiment",
+ "TensorboardRun",
+ "TensorboardTimeSeries",
"TextDataset",
"TimeSeriesDataset",
"VideoDataset",
diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py
index 07e4c2fe4a..ceb9287322 100644
--- a/google/cloud/aiplatform/base.py
+++ b/google/cloud/aiplatform/base.py
@@ -23,6 +23,7 @@
import logging
import sys
import threading
+import time
from typing import (
Any,
Callable,
@@ -38,25 +39,34 @@
import proto
+from google.api_core import retry
from google.api_core import operation
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils
+from google.cloud.aiplatform.compat.types import encryption_spec as gca_encryption_spec
+from google.protobuf import json_format
-
-logging.basicConfig(level=logging.INFO, stream=sys.stdout)
+# This is the default retry callback to be used with get methods.
+_DEFAULT_RETRY = retry.Retry()
class Logger:
"""Logging wrapper class with high level helper methods."""
- def __init__(self, name: str = ""):
- """Initializes logger with name.
+ def __init__(self, name: str):
+        """Initializes logger with the given name.
Args:
name (str): Name to associate with logger.
"""
self._logger = logging.getLogger(name)
+ self._logger.setLevel(logging.INFO)
+
+ handler = logging.StreamHandler(sys.stdout)
+ handler.setLevel(logging.INFO)
+
+ self._logger.addHandler(handler)
def log_create_with_lro(
self,
@@ -92,7 +102,7 @@ def log_create_complete(
cls (VertexAiResourceNoun):
Vertex AI Resource Noun class that is being created.
resource (proto.Message):
- Vertex AI Resourc proto.Message
+ Vertex AI Resource proto.Message
variable_name (str): Name of variable to use for code snippet
"""
self._logger.info(f"{cls.__name__} created. Resource name: {resource.name}")
@@ -115,7 +125,7 @@ def log_create_complete_with_getter(
cls (VertexAiResourceNoun):
Vertex AI Resource Noun class that is being created.
resource (proto.Message):
- Vertex AI Resourc proto.Message
+ Vertex AI Resource proto.Message
variable_name (str): Name of variable to use for code snippet
"""
self._logger.info(f"{cls.__name__} created. Resource name: {resource.name}")
@@ -391,7 +401,6 @@ class VertexAiResourceNoun(metaclass=abc.ABCMeta):
Subclasses require two class attributes:
client_class: The client to instantiate to interact with this resource noun.
- _is_client_prediction_client: Flag to indicate if the client requires a prediction endpoint.
Subclass is required to populate private attribute _gca_resource which is the
service representation of the resource noun.
@@ -408,29 +417,43 @@ def client_class(cls) -> Type[utils.VertexAiServiceClientWithOverride]:
@property
@classmethod
@abc.abstractmethod
- def _is_client_prediction_client(cls) -> bool:
- """Flag to indicate whether to use prediction endpoint with client."""
- pass
-
- @property
- @abc.abstractmethod
def _getter_method(cls) -> str:
"""Name of getter method of client class for retrieving the
resource."""
pass
@property
+ @classmethod
@abc.abstractmethod
def _delete_method(cls) -> str:
"""Name of delete method of client class for deleting the resource."""
pass
@property
+ @classmethod
@abc.abstractmethod
def _resource_noun(cls) -> str:
"""Resource noun."""
pass
+ @property
+ @classmethod
+ @abc.abstractmethod
+ def _parse_resource_name_method(cls) -> str:
+ """Method name on GAPIC client to parse a resource name."""
+ pass
+
+ @property
+ @classmethod
+ @abc.abstractmethod
+ def _format_resource_name_method(self) -> str:
+ """Method name on GAPIC client to format a resource name."""
+ pass
+
+ # Override this value with staticmethod
+ # to use custom resource id validators per resource
+ _resource_id_validator: Optional[Callable[[str], None]] = None
+
def __init__(
self,
project: Optional[str] = None,
@@ -443,7 +466,7 @@ def __init__(
Args:
project(str): Project of the resource noun.
location(str): The location of the resource noun.
- credentials(google.auth.crendentials.Crendentials): Optional custom
+ credentials(google.auth.credentials.Credentials): Optional custom
credentials to use when accessing interacting with resource noun.
resource_name(str): A fully-qualified resource name or ID.
"""
@@ -480,15 +503,48 @@ def _instantiate_client(
client_class=cls.client_class,
credentials=credentials,
location_override=location,
- prediction_client=cls._is_client_prediction_client,
)
+ @classmethod
+ def _parse_resource_name(cls, resource_name: str) -> Dict[str, str]:
+ """
+ Parses resource name into its component segments.
+
+ Args:
+ resource_name: Resource name of this resource.
+ Returns:
+ Dictionary of component segments.
+ """
+ # gets the underlying wrapped gapic client class
+ return getattr(
+ cls.client_class.get_gapic_client_class(), cls._parse_resource_name_method
+ )(resource_name)
+
+ @classmethod
+ def _format_resource_name(cls, **kwargs: str) -> str:
+ """
+ Formats a resource name using its component segments.
+
+ Args:
+            **kwargs: Resource name parts. Singular and snake case, e.g.:
+ format_resource_name(
+ project='my-project',
+ location='us-central1'
+ )
+ Returns:
+ Resource name.
+ """
+ # gets the underlying wrapped gapic client class
+ return getattr(
+ cls.client_class.get_gapic_client_class(), cls._format_resource_name_method
+ )(**kwargs)
+
def _get_and_validate_project_location(
self,
resource_name: str,
project: Optional[str] = None,
location: Optional[str] = None,
- ) -> Tuple:
+ ) -> Tuple[str, str]:
"""Validate the project and location for the resource.
@@ -498,39 +554,50 @@ def _get_and_validate_project_location(
location(str): The location of the resource noun.
Raises:
- RuntimeError if location is different from resource location
+ RuntimeError: If location is different from resource location
"""
- fields = utils.extract_fields_from_resource_name(
- resource_name, self._resource_noun
- )
+ fields = self._parse_resource_name(resource_name)
+
if not fields:
return project, location
- if location and fields.location != location:
+ if location and fields["location"] != location:
raise RuntimeError(
f"location {location} is provided, but different from "
- f"the resource location {fields.location}"
+ f"the resource location {fields['location']}"
)
- return fields.project, fields.location
+ return fields["project"], fields["location"]
+
+ def _get_gca_resource(
+ self,
+ resource_name: str,
+ parent_resource_name_fields: Optional[Dict[str, str]] = None,
+ ) -> proto.Message:
+ """Returns GAPIC service representation of client class resource.
- def _get_gca_resource(self, resource_name: str) -> proto.Message:
- """Returns GAPIC service representation of client class resource."""
- """
Args:
- resource_name (str):
- Required. A fully-qualified resource name or ID.
+ resource_name (str): Required. A fully-qualified resource name or ID.
+ parent_resource_name_fields (Dict[str,str]):
+ Optional. Mapping of parent resource name key to values. These
+ will be used to compose the resource name if only resource ID is given.
+ Should not include project and location.
"""
-
resource_name = utils.full_resource_name(
resource_name=resource_name,
resource_noun=self._resource_noun,
+ parse_resource_name_method=self._parse_resource_name,
+ format_resource_name_method=self._format_resource_name,
project=self.project,
location=self.location,
+ parent_resource_name_fields=parent_resource_name_fields,
+ resource_id_validator=self._resource_id_validator,
)
- return getattr(self.api_client, self._getter_method)(name=resource_name)
+ return getattr(self.api_client, self._getter_method)(
+ name=resource_name, retry=_DEFAULT_RETRY
+ )
def _sync_gca_resource(self):
"""Sync GAPIC service representation of client class resource."""
@@ -540,21 +607,44 @@ def _sync_gca_resource(self):
@property
def name(self) -> str:
"""Name of this resource."""
+ self._assert_gca_resource_is_available()
return self._gca_resource.name.split("/")[-1]
+ @property
+ def _project_tuple(self) -> Tuple[Optional[str], Optional[str]]:
+        """Returns the tuple of project ID and project number inferred from the local instance.
+
+        Another option is to use resource_manager_utils, but that requires the
+        caller to have the resource manager get role.
+ """
+        # we may not have the project ID if the project was inferred from the resource name
+ maybe_project_id = self.project
+ if self._gca_resource is not None and self._gca_resource.name:
+ project_no = self._parse_resource_name(self._gca_resource.name)["project"]
+ else:
+ project_no = None
+
+ if maybe_project_id == project_no:
+ return (None, project_no)
+ else:
+ return (maybe_project_id, project_no)
+
@property
def resource_name(self) -> str:
"""Full qualified resource name."""
+ self._assert_gca_resource_is_available()
return self._gca_resource.name
@property
def display_name(self) -> str:
"""Display name of this resource."""
+ self._assert_gca_resource_is_available()
return self._gca_resource.display_name
@property
def create_time(self) -> datetime.datetime:
"""Time this resource was created."""
+ self._assert_gca_resource_is_available()
return self._gca_resource.create_time
@property
@@ -563,14 +653,65 @@ def update_time(self) -> datetime.datetime:
self._sync_gca_resource()
return self._gca_resource.update_time
+ @property
+ def encryption_spec(self) -> Optional[gca_encryption_spec.EncryptionSpec]:
+ """Customer-managed encryption key options for this Vertex AI resource.
+
+ If this is set, then all resources created by this Vertex AI resource will
+ be encrypted with the provided encryption key.
+ """
+ self._assert_gca_resource_is_available()
+ return getattr(self._gca_resource, "encryption_spec")
+
+ @property
+ def labels(self) -> Dict[str, str]:
+ """User-defined labels containing metadata about this resource.
+
+ Read more about labels at https://goo.gl/xmQnxf
+ """
+ self._assert_gca_resource_is_available()
+ return dict(self._gca_resource.labels)
+
@property
def gca_resource(self) -> proto.Message:
- """The underlying resource proto represenation."""
+ """The underlying resource proto representation."""
+ self._assert_gca_resource_is_available()
return self._gca_resource
+ @property
+ def _resource_is_available(self) -> bool:
+ """Returns True if GCA resource has been created and is available, otherwise False"""
+ try:
+ self._assert_gca_resource_is_available()
+ return True
+ except RuntimeError:
+ return False
+
+ def _assert_gca_resource_is_available(self) -> None:
+ """Helper method to raise when property is not accessible.
+
+ Raises:
+ RuntimeError: If _gca_resource has not been created.
+ """
+ if self._gca_resource is None:
+ raise RuntimeError(
+ f"{self.__class__.__name__} resource has not been created"
+ )
+
def __repr__(self) -> str:
return f"{object.__repr__(self)} \nresource name: {self.resource_name}"
+ def to_dict(self) -> Dict[str, Any]:
+ """Returns the resource proto as a dictionary."""
+ return json_format.MessageToDict(self.gca_resource._pb)
+
+ @classmethod
+ def _generate_display_name(cls, prefix: Optional[str] = None) -> str:
+ """Returns a display name containing class name and time string."""
+ if not prefix:
+ prefix = cls.__name__
+ return prefix + " " + datetime.datetime.now().isoformat(sep=" ")
+
def optional_sync(
construct_object_on_arg: Optional[str] = None,
@@ -624,7 +765,7 @@ def wrapper(*args, **kwargs):
# if sync then wait for any Futures to complete and execute
if sync:
if self:
- self.wait()
+ VertexAiResourceNounWithFutureManager.wait(self)
return method(*args, **kwargs)
# callbacks to call within the Future (in same Thread)
@@ -639,17 +780,21 @@ def wrapper(*args, **kwargs):
inspect.getfullargspec(method).annotations["return"]
)
+ # object produced by the method
+ returned_object = bound_args.arguments.get(return_input_arg)
+
# is a classmethod that creates the object and returns it
if args and inspect.isclass(args[0]):
- # assumes classmethod is our resource noun
- returned_object = args[0]._empty_constructor()
+
+ # assumes class in classmethod is the resource noun
+ returned_object = (
+ args[0]._empty_constructor()
+ if not returned_object
+ else returned_object
+ )
self = returned_object
else: # instance method
-
- # object produced by the method
- returned_object = bound_args.arguments.get(return_input_arg)
-
# if we're returning an input object
if returned_object and returned_object is not self:
@@ -727,7 +872,7 @@ def __init__(
Args:
project (str): Optional. Project of the resource noun.
location (str): Optional. The location of the resource noun.
- credentials(google.auth.crendentials.Crendentials):
+ credentials(google.auth.credentials.Credentials):
Optional. Custom credentials to use when interacting with the
resource noun.
resource_name(str): A fully-qualified resource name or ID.
@@ -757,7 +902,7 @@ def _empty_constructor(
Args:
project (str): Optional. Project of the resource noun.
location (str): Optional. The location of the resource noun.
- credentials(google.auth.crendentials.Crendentials):
+ credentials(google.auth.credentials.Credentials):
Optional. Custom credentials to use when interacting with the
resource noun.
resource_name(str): A fully-qualified resource name or ID.
@@ -802,8 +947,9 @@ def _sync_object_with_future_result(
if value:
setattr(self, attribute, value)
+ @classmethod
def _construct_sdk_resource_from_gapic(
- self,
+ cls,
gapic_resource: proto.Message,
project: Optional[str] = None,
location: Optional[str] = None,
@@ -813,7 +959,7 @@ def _construct_sdk_resource_from_gapic(
Args:
gapic_resource (proto.Message):
- A GAPIC representation of an Vertex AI resource, usually
+ A GAPIC representation of a Vertex AI resource, usually
retrieved by a get_* or in a list_* API call.
project (str):
Optional. Project to construct SDK object from. If not set,
@@ -829,7 +975,7 @@ def _construct_sdk_resource_from_gapic(
VertexAiResourceNoun:
An initialized SDK object that represents GAPIC type.
"""
- sdk_resource = self._empty_constructor(
+ sdk_resource = cls._empty_constructor(
project=project, location=location, credentials=credentials
)
sdk_resource._gca_resource = gapic_resource
@@ -846,6 +992,7 @@ def _list(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
+ parent: Optional[str] = None,
) -> List[VertexAiResourceNoun]:
"""Private method to list all instances of this Vertex AI Resource,
takes a `cls_filter` arg to filter to a particular SDK resource
@@ -873,21 +1020,24 @@ def _list(
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to retrieve list. Overrides
credentials set in aiplatform.init.
+ parent (str):
+ Optional. The parent resource name, if any, to retrieve the resource list from.
Returns:
List[VertexAiResourceNoun] - A list of SDK resource objects
"""
- self = cls._empty_constructor(
+ resource = cls._empty_constructor(
project=project, location=location, credentials=credentials
)
# Fetch credentials once and re-use for all `_empty_constructor()` calls
- creds = initializer.global_config.credentials
+ creds = resource.credentials
- resource_list_method = getattr(self.api_client, self._list_method)
+ resource_list_method = getattr(resource.api_client, resource._list_method)
list_request = {
- "parent": initializer.global_config.common_location_path(
+ "parent": parent
+ or initializer.global_config.common_location_path(
project=project, location=location
),
"filter": filter,
@@ -899,7 +1049,7 @@ def _list(
resource_list = resource_list_method(request=list_request) or []
return [
- self._construct_sdk_resource_from_gapic(
+ cls._construct_sdk_resource_from_gapic(
gapic_resource, project=project, location=location, credentials=creds
)
for gapic_resource in resource_list
@@ -977,6 +1127,7 @@ def list(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
+ parent: Optional[str] = None,
) -> List[VertexAiResourceNoun]:
"""List all instances of this Vertex AI Resource.
@@ -1005,6 +1156,8 @@ def list(
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to retrieve list. Overrides
credentials set in aiplatform.init.
+ parent (str):
+ Optional. The parent resource name, if any, to retrieve the list from.
Returns:
List[VertexAiResourceNoun] - A list of SDK resource objects
@@ -1016,12 +1169,13 @@ def list(
project=project,
location=location,
credentials=credentials,
+ parent=parent,
)
@optional_sync()
def delete(self, sync: bool = True) -> None:
"""Deletes this Vertex AI resource. WARNING: This deletion is
- permament.
+ permanent.
Args:
sync (bool):
@@ -1038,11 +1192,61 @@ def delete(self, sync: bool = True) -> None:
_LOGGER.log_action_completed_against_resource("deleted.", "", self)
def __repr__(self) -> str:
- if self._gca_resource:
+ if self._gca_resource and self._resource_is_available:
return VertexAiResourceNoun.__repr__(self)
return FutureManager.__repr__(self)
+ def _wait_for_resource_creation(self) -> None:
+ """Wait until underlying resource is created.
+
+ Currently this should only be used on subclasses that implement the construct then
+ `run` pattern because the underlying sync=False implementation will not update
+ downstream resource noun object's _gca_resource until the entire invoked method is complete.
+
+ Ex:
+ job = CustomTrainingJob()
+ job.run(sync=False, ...)
+ job._wait_for_resource_creation()
+ Raises:
+ RuntimeError: If the resource has not been scheduled to be created.
+ """
+
+ # If the user calls this but didn't actually invoke an API to create
+ if self._are_futures_done() and not getattr(self._gca_resource, "name", None):
+ self._raise_future_exception()
+ raise RuntimeError(
+ f"{self.__class__.__name__} resource is not scheduled to be created."
+ )
+
+ while not getattr(self._gca_resource, "name", None):
+ # breaks out of loop if creation has failed async
+ if self._are_futures_done() and not getattr(
+ self._gca_resource, "name", None
+ ):
+ self._raise_future_exception()
+
+ time.sleep(1)
+
+ def _assert_gca_resource_is_available(self) -> None:
+ """Helper method to raise when accessing properties that do not exist.
+
+ Overrides VertexAiResourceNoun to provide a more informative exception if
+ resource creation has failed asynchronously.
+
+ Raises:
+ RuntimeError: When resource has not been created.
+ """
+ if not getattr(self._gca_resource, "name", None):
+ raise RuntimeError(
+ f"{self.__class__.__name__} resource has not been created."
+ + (
+ f" Resource failed with: {self._exception}"
+ if self._exception
+ else ""
+ )
+ )
+
def get_annotation_class(annotation: type) -> type:
"""Helper method to retrieve type annotation.
@@ -1055,3 +1259,58 @@ def get_annotation_class(annotation: type) -> type:
return annotation.__args__[0]
else:
return annotation
+
+
+class DoneMixin(abc.ABC):
+ """An abstract class for implementing a done method, indicating
+ whether a job has completed.
+ """
+
+ @abc.abstractmethod
+ def done(self) -> bool:
+ """Method indicating whether a job has completed."""
+ pass
+
+
+class StatefulResource(DoneMixin):
+ """Extends DoneMixin to check whether a job returning a stateful resource has compted."""
+
+ @property
+ @abc.abstractmethod
+ def state(self):
+ """The current state of the job."""
+ pass
+
+ @property
+ @classmethod
+ @abc.abstractmethod
+ def _valid_done_states(cls):
+ """A set() containing all job states associated with a completed job."""
+ pass
+
+ def done(self) -> bool:
+ """Method indicating whether a job has completed.
+
+ Returns:
+ True if the job has completed.
+ """
+ return self.state in self._valid_done_states
+
+
+class VertexAiStatefulResource(VertexAiResourceNounWithFutureManager, StatefulResource):
+ """Extends StatefulResource to include a check for self._gca_resource."""
+
+ def done(self) -> bool:
+ """Method indicating whether a job has completed.
+
+ Returns:
+ True if the job has completed.
+ """
+ if self._gca_resource and self._gca_resource.name:
+ return super().done()
+ else:
+ return False
diff --git a/google/cloud/aiplatform/compat/__init__.py b/google/cloud/aiplatform/compat/__init__.py
index 55a72fea16..02e66ec494 100644
--- a/google/cloud/aiplatform/compat/__init__.py
+++ b/google/cloud/aiplatform/compat/__init__.py
@@ -27,6 +27,10 @@
services.dataset_service_client = services.dataset_service_client_v1beta1
services.endpoint_service_client = services.endpoint_service_client_v1beta1
+ services.featurestore_online_serving_service_client = (
+ services.featurestore_online_serving_service_client_v1beta1
+ )
+ services.featurestore_service_client = services.featurestore_service_client_v1beta1
services.job_service_client = services.job_service_client_v1beta1
services.model_service_client = services.model_service_client_v1beta1
services.pipeline_service_client = services.pipeline_service_client_v1beta1
@@ -36,12 +40,18 @@
)
services.metadata_service_client = services.metadata_service_client_v1beta1
services.tensorboard_service_client = services.tensorboard_service_client_v1beta1
+ services.index_service_client = services.index_service_client_v1beta1
+ services.index_endpoint_service_client = (
+ services.index_endpoint_service_client_v1beta1
+ )
types.accelerator_type = types.accelerator_type_v1beta1
types.annotation = types.annotation_v1beta1
types.annotation_spec = types.annotation_spec_v1beta1
+ types.artifact = types.artifact_v1beta1
types.batch_prediction_job = types.batch_prediction_job_v1beta1
types.completion_stats = types.completion_stats_v1beta1
+ types.context = types.context_v1beta1
types.custom_job = types.custom_job_v1beta1
types.data_item = types.data_item_v1beta1
types.data_labeling_job = types.data_labeling_job_v1beta1
@@ -51,50 +61,83 @@
types.encryption_spec = types.encryption_spec_v1beta1
types.endpoint = types.endpoint_v1beta1
types.endpoint_service = types.endpoint_service_v1beta1
+ types.entity_type = types.entity_type_v1beta1
types.env_var = types.env_var_v1beta1
+ types.event = types.event_v1beta1
+ types.execution = types.execution_v1beta1
types.explanation = types.explanation_v1beta1
types.explanation_metadata = types.explanation_metadata_v1beta1
+ types.feature = types.feature_v1beta1
+ types.feature_monitoring_stats = types.feature_monitoring_stats_v1beta1
+ types.feature_selector = types.feature_selector_v1beta1
+ types.featurestore = types.featurestore_v1beta1
+ types.featurestore_monitoring = types.featurestore_monitoring_v1beta1
+ types.featurestore_online_service = types.featurestore_online_service_v1beta1
+ types.featurestore_service = types.featurestore_service_v1beta1
types.hyperparameter_tuning_job = types.hyperparameter_tuning_job_v1beta1
+ types.index = types.index_v1beta1
+ types.index_endpoint = types.index_endpoint_v1beta1
types.io = types.io_v1beta1
types.job_service = types.job_service_v1beta1
types.job_state = types.job_state_v1beta1
+ types.lineage_subgraph = types.lineage_subgraph_v1beta1
types.machine_resources = types.machine_resources_v1beta1
types.manual_batch_tuning_parameters = types.manual_batch_tuning_parameters_v1beta1
+ types.matching_engine_deployed_index_ref = (
+ types.matching_engine_deployed_index_ref_v1beta1
+ )
+ types.matching_engine_index = types.index_v1beta1
+ types.matching_engine_index_endpoint = types.index_endpoint_v1beta1
+ types.metadata_service = types.metadata_service_v1beta1
+ types.metadata_schema = types.metadata_schema_v1beta1
+ types.metadata_store = types.metadata_store_v1beta1
types.model = types.model_v1beta1
types.model_evaluation = types.model_evaluation_v1beta1
types.model_evaluation_slice = types.model_evaluation_slice_v1beta1
types.model_service = types.model_service_v1beta1
types.operation = types.operation_v1beta1
+ types.pipeline_failure_policy = types.pipeline_failure_policy_v1beta1
+ types.pipeline_job = types.pipeline_job_v1beta1
types.pipeline_service = types.pipeline_service_v1beta1
types.pipeline_state = types.pipeline_state_v1beta1
types.prediction_service = types.prediction_service_v1beta1
types.specialist_pool = types.specialist_pool_v1beta1
types.specialist_pool_service = types.specialist_pool_service_v1beta1
types.study = types.study_v1beta1
- types.training_pipeline = types.training_pipeline_v1beta1
- types.metadata_service = types.metadata_service_v1beta1
+ types.tensorboard = types.tensorboard_v1beta1
types.tensorboard_service = types.tensorboard_service_v1beta1
types.tensorboard_data = types.tensorboard_data_v1beta1
types.tensorboard_experiment = types.tensorboard_experiment_v1beta1
types.tensorboard_run = types.tensorboard_run_v1beta1
types.tensorboard_service = types.tensorboard_service_v1beta1
types.tensorboard_time_series = types.tensorboard_time_series_v1beta1
+ types.training_pipeline = types.training_pipeline_v1beta1
+ types.types = types.types_v1beta1
if DEFAULT_VERSION == V1:
services.dataset_service_client = services.dataset_service_client_v1
services.endpoint_service_client = services.endpoint_service_client_v1
+ services.featurestore_online_serving_service_client = (
+ services.featurestore_online_serving_service_client_v1
+ )
+ services.featurestore_service_client = services.featurestore_service_client_v1
services.job_service_client = services.job_service_client_v1
services.model_service_client = services.model_service_client_v1
services.pipeline_service_client = services.pipeline_service_client_v1
services.prediction_service_client = services.prediction_service_client_v1
services.specialist_pool_service_client = services.specialist_pool_service_client_v1
+ services.tensorboard_service_client = services.tensorboard_service_client_v1
+ services.index_service_client = services.index_service_client_v1
+ services.index_endpoint_service_client = services.index_endpoint_service_client_v1
types.accelerator_type = types.accelerator_type_v1
types.annotation = types.annotation_v1
types.annotation_spec = types.annotation_spec_v1
+ types.artifact = types.artifact_v1
types.batch_prediction_job = types.batch_prediction_job_v1
types.completion_stats = types.completion_stats_v1
+ types.context = types.context_v1
types.custom_job = types.custom_job_v1
types.data_item = types.data_item_v1
types.data_labeling_job = types.data_labeling_job_v1
@@ -104,25 +147,57 @@
types.encryption_spec = types.encryption_spec_v1
types.endpoint = types.endpoint_v1
types.endpoint_service = types.endpoint_service_v1
+ types.entity_type = types.entity_type_v1
types.env_var = types.env_var_v1
+ types.event = types.event_v1
+ types.execution = types.execution_v1
+ types.explanation = types.explanation_v1
+ types.explanation_metadata = types.explanation_metadata_v1
+ types.feature = types.feature_v1
+ types.feature_monitoring_stats = types.feature_monitoring_stats_v1
+ types.feature_selector = types.feature_selector_v1
+ types.featurestore = types.featurestore_v1
+ types.featurestore_online_service = types.featurestore_online_service_v1
+ types.featurestore_service = types.featurestore_service_v1
types.hyperparameter_tuning_job = types.hyperparameter_tuning_job_v1
+ types.index = types.index_v1
+ types.index_endpoint = types.index_endpoint_v1
types.io = types.io_v1
types.job_service = types.job_service_v1
types.job_state = types.job_state_v1
+ types.lineage_subgraph = types.lineage_subgraph_v1
types.machine_resources = types.machine_resources_v1
types.manual_batch_tuning_parameters = types.manual_batch_tuning_parameters_v1
+ types.matching_engine_deployed_index_ref = (
+ types.matching_engine_deployed_index_ref_v1
+ )
+ types.matching_engine_index = types.index_v1
+ types.matching_engine_index_endpoint = types.index_endpoint_v1
+ types.metadata_service = types.metadata_service_v1
+ types.metadata_schema = types.metadata_schema_v1
+ types.metadata_store = types.metadata_store_v1
types.model = types.model_v1
types.model_evaluation = types.model_evaluation_v1
types.model_evaluation_slice = types.model_evaluation_slice_v1
types.model_service = types.model_service_v1
types.operation = types.operation_v1
+ types.pipeline_failure_policy = types.pipeline_failure_policy_v1
+ types.pipeline_job = types.pipeline_job_v1
types.pipeline_service = types.pipeline_service_v1
types.pipeline_state = types.pipeline_state_v1
types.prediction_service = types.prediction_service_v1
types.specialist_pool = types.specialist_pool_v1
types.specialist_pool_service = types.specialist_pool_service_v1
types.study = types.study_v1
+ types.tensorboard = types.tensorboard_v1
+ types.tensorboard_service = types.tensorboard_service_v1
+ types.tensorboard_data = types.tensorboard_data_v1
+ types.tensorboard_experiment = types.tensorboard_experiment_v1
+ types.tensorboard_run = types.tensorboard_run_v1
+ types.tensorboard_service = types.tensorboard_service_v1
+ types.tensorboard_time_series = types.tensorboard_time_series_v1
types.training_pipeline = types.training_pipeline_v1
+ types.types = types.types_v1
__all__ = (
DEFAULT_VERSION,
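Editorial note: the net effect of the aliasing above is that SDK code imports version-neutral module names and lets `DEFAULT_VERSION` pick the backing GAPIC modules. A short sketch of what a consumer sees — the attribute names come from this diff; which concrete version they resolve to depends on `DEFAULT_VERSION`:

    from google.cloud.aiplatform.compat import services, types

    # Both names resolve to the v1 modules when DEFAULT_VERSION == V1,
    # and to the v1beta1 modules otherwise.
    featurestore_client_module = services.featurestore_service_client
    pipeline_job_module = types.pipeline_job

    # Calling code never spells out the API version.
    client_cls = featurestore_client_module.FeaturestoreServiceClient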
diff --git a/google/cloud/aiplatform/compat/services/__init__.py b/google/cloud/aiplatform/compat/services/__init__.py
index 5c104ab41f..68440de4c5 100644
--- a/google/cloud/aiplatform/compat/services/__init__.py
+++ b/google/cloud/aiplatform/compat/services/__init__.py
@@ -21,9 +21,24 @@
from google.cloud.aiplatform_v1beta1.services.endpoint_service import (
client as endpoint_service_client_v1beta1,
)
+from google.cloud.aiplatform_v1beta1.services.featurestore_online_serving_service import (
+ client as featurestore_online_serving_service_client_v1beta1,
+)
+from google.cloud.aiplatform_v1beta1.services.featurestore_service import (
+ client as featurestore_service_client_v1beta1,
+)
+from google.cloud.aiplatform_v1beta1.services.index_service import (
+ client as index_service_client_v1beta1,
+)
+from google.cloud.aiplatform_v1beta1.services.index_endpoint_service import (
+ client as index_endpoint_service_client_v1beta1,
+)
from google.cloud.aiplatform_v1beta1.services.job_service import (
client as job_service_client_v1beta1,
)
+from google.cloud.aiplatform_v1beta1.services.metadata_service import (
+ client as metadata_service_client_v1beta1,
+)
from google.cloud.aiplatform_v1beta1.services.model_service import (
client as model_service_client_v1beta1,
)
@@ -36,9 +51,6 @@
from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import (
client as specialist_pool_service_client_v1beta1,
)
-from google.cloud.aiplatform_v1beta1.services.metadata_service import (
- client as metadata_service_client_v1beta1,
-)
from google.cloud.aiplatform_v1beta1.services.tensorboard_service import (
client as tensorboard_service_client_v1beta1,
)
@@ -49,9 +61,24 @@
from google.cloud.aiplatform_v1.services.endpoint_service import (
client as endpoint_service_client_v1,
)
+from google.cloud.aiplatform_v1.services.featurestore_online_serving_service import (
+ client as featurestore_online_serving_service_client_v1,
+)
+from google.cloud.aiplatform_v1.services.featurestore_service import (
+ client as featurestore_service_client_v1,
+)
+from google.cloud.aiplatform_v1.services.index_service import (
+ client as index_service_client_v1,
+)
+from google.cloud.aiplatform_v1.services.index_endpoint_service import (
+ client as index_endpoint_service_client_v1,
+)
from google.cloud.aiplatform_v1.services.job_service import (
client as job_service_client_v1,
)
+from google.cloud.aiplatform_v1.services.metadata_service import (
+ client as metadata_service_client_v1,
+)
from google.cloud.aiplatform_v1.services.model_service import (
client as model_service_client_v1,
)
@@ -64,19 +91,32 @@
from google.cloud.aiplatform_v1.services.specialist_pool_service import (
client as specialist_pool_service_client_v1,
)
+from google.cloud.aiplatform_v1.services.tensorboard_service import (
+ client as tensorboard_service_client_v1,
+)
__all__ = (
# v1
dataset_service_client_v1,
endpoint_service_client_v1,
+ featurestore_online_serving_service_client_v1,
+ featurestore_service_client_v1,
+ index_service_client_v1,
+ index_endpoint_service_client_v1,
job_service_client_v1,
+ metadata_service_client_v1,
model_service_client_v1,
pipeline_service_client_v1,
prediction_service_client_v1,
specialist_pool_service_client_v1,
+ tensorboard_service_client_v1,
# v1beta1
dataset_service_client_v1beta1,
endpoint_service_client_v1beta1,
+ featurestore_online_serving_service_client_v1beta1,
+ featurestore_service_client_v1beta1,
+ index_service_client_v1beta1,
+ index_endpoint_service_client_v1beta1,
job_service_client_v1beta1,
model_service_client_v1beta1,
pipeline_service_client_v1beta1,
diff --git a/google/cloud/aiplatform/compat/types/__init__.py b/google/cloud/aiplatform/compat/types/__init__.py
index 7bd512e7e8..25c7515877 100644
--- a/google/cloud/aiplatform/compat/types/__init__.py
+++ b/google/cloud/aiplatform/compat/types/__init__.py
@@ -19,80 +19,130 @@
accelerator_type as accelerator_type_v1beta1,
annotation as annotation_v1beta1,
annotation_spec as annotation_spec_v1beta1,
+ artifact as artifact_v1beta1,
batch_prediction_job as batch_prediction_job_v1beta1,
completion_stats as completion_stats_v1beta1,
+ context as context_v1beta1,
custom_job as custom_job_v1beta1,
data_item as data_item_v1beta1,
data_labeling_job as data_labeling_job_v1beta1,
dataset as dataset_v1beta1,
dataset_service as dataset_service_v1beta1,
+ deployed_index_ref as matching_engine_deployed_index_ref_v1beta1,
deployed_model_ref as deployed_model_ref_v1beta1,
encryption_spec as encryption_spec_v1beta1,
endpoint as endpoint_v1beta1,
endpoint_service as endpoint_service_v1beta1,
+ entity_type as entity_type_v1beta1,
env_var as env_var_v1beta1,
+ event as event_v1beta1,
+ execution as execution_v1beta1,
explanation as explanation_v1beta1,
explanation_metadata as explanation_metadata_v1beta1,
+ feature as feature_v1beta1,
+ feature_monitoring_stats as feature_monitoring_stats_v1beta1,
+ feature_selector as feature_selector_v1beta1,
+ featurestore as featurestore_v1beta1,
+ featurestore_monitoring as featurestore_monitoring_v1beta1,
+ featurestore_online_service as featurestore_online_service_v1beta1,
+ featurestore_service as featurestore_service_v1beta1,
+ index as index_v1beta1,
+ index_endpoint as index_endpoint_v1beta1,
hyperparameter_tuning_job as hyperparameter_tuning_job_v1beta1,
io as io_v1beta1,
job_service as job_service_v1beta1,
job_state as job_state_v1beta1,
+ lineage_subgraph as lineage_subgraph_v1beta1,
machine_resources as machine_resources_v1beta1,
manual_batch_tuning_parameters as manual_batch_tuning_parameters_v1beta1,
+ metadata_schema as metadata_schema_v1beta1,
+ metadata_service as metadata_service_v1beta1,
+ metadata_store as metadata_store_v1beta1,
model as model_v1beta1,
model_evaluation as model_evaluation_v1beta1,
model_evaluation_slice as model_evaluation_slice_v1beta1,
model_service as model_service_v1beta1,
operation as operation_v1beta1,
+ pipeline_failure_policy as pipeline_failure_policy_v1beta1,
+ pipeline_job as pipeline_job_v1beta1,
pipeline_service as pipeline_service_v1beta1,
pipeline_state as pipeline_state_v1beta1,
prediction_service as prediction_service_v1beta1,
specialist_pool as specialist_pool_v1beta1,
specialist_pool_service as specialist_pool_service_v1beta1,
study as study_v1beta1,
- training_pipeline as training_pipeline_v1beta1,
- metadata_service as metadata_service_v1beta1,
- tensorboard_service as tensorboard_service_v1beta1,
+ tensorboard as tensorboard_v1beta1,
tensorboard_data as tensorboard_data_v1beta1,
tensorboard_experiment as tensorboard_experiment_v1beta1,
tensorboard_run as tensorboard_run_v1beta1,
tensorboard_service as tensorboard_service_v1beta1,
tensorboard_time_series as tensorboard_time_series_v1beta1,
+ training_pipeline as training_pipeline_v1beta1,
+ types as types_v1beta1,
)
from google.cloud.aiplatform_v1.types import (
accelerator_type as accelerator_type_v1,
annotation as annotation_v1,
annotation_spec as annotation_spec_v1,
+ artifact as artifact_v1,
batch_prediction_job as batch_prediction_job_v1,
completion_stats as completion_stats_v1,
+ context as context_v1,
custom_job as custom_job_v1,
data_item as data_item_v1,
data_labeling_job as data_labeling_job_v1,
dataset as dataset_v1,
dataset_service as dataset_service_v1,
+ deployed_index_ref as matching_engine_deployed_index_ref_v1,
deployed_model_ref as deployed_model_ref_v1,
encryption_spec as encryption_spec_v1,
endpoint as endpoint_v1,
endpoint_service as endpoint_service_v1,
+ entity_type as entity_type_v1,
env_var as env_var_v1,
+ event as event_v1,
+ execution as execution_v1,
+ explanation as explanation_v1,
+ explanation_metadata as explanation_metadata_v1,
+ feature as feature_v1,
+ feature_monitoring_stats as feature_monitoring_stats_v1,
+ feature_selector as feature_selector_v1,
+ featurestore as featurestore_v1,
+ featurestore_online_service as featurestore_online_service_v1,
+ featurestore_service as featurestore_service_v1,
hyperparameter_tuning_job as hyperparameter_tuning_job_v1,
+ index as index_v1,
+ index_endpoint as index_endpoint_v1,
io as io_v1,
job_service as job_service_v1,
job_state as job_state_v1,
+ lineage_subgraph as lineage_subgraph_v1,
machine_resources as machine_resources_v1,
manual_batch_tuning_parameters as manual_batch_tuning_parameters_v1,
+ metadata_service as metadata_service_v1,
+ metadata_schema as metadata_schema_v1,
+ metadata_store as metadata_store_v1,
model as model_v1,
model_evaluation as model_evaluation_v1,
model_evaluation_slice as model_evaluation_slice_v1,
model_service as model_service_v1,
operation as operation_v1,
+ pipeline_failure_policy as pipeline_failure_policy_v1,
+ pipeline_job as pipeline_job_v1,
pipeline_service as pipeline_service_v1,
pipeline_state as pipeline_state_v1,
prediction_service as prediction_service_v1,
specialist_pool as specialist_pool_v1,
specialist_pool_service as specialist_pool_service_v1,
study as study_v1,
+ tensorboard as tensorboard_v1,
+ tensorboard_data as tensorboard_data_v1,
+ tensorboard_experiment as tensorboard_experiment_v1,
+ tensorboard_run as tensorboard_run_v1,
+ tensorboard_service as tensorboard_service_v1,
+ tensorboard_time_series as tensorboard_time_series_v1,
training_pipeline as training_pipeline_v1,
+ types as types_v1,
)
__all__ = (
@@ -100,8 +150,10 @@
accelerator_type_v1,
annotation_v1,
annotation_spec_v1,
+ artifact_v1,
batch_prediction_job_v1,
completion_stats_v1,
+ context_v1,
custom_job_v1,
data_item_v1,
data_labeling_job_v1,
@@ -111,30 +163,59 @@
encryption_spec_v1,
endpoint_v1,
endpoint_service_v1,
+ entity_type_v1,
env_var_v1,
+ event_v1,
+ execution_v1,
+ explanation_v1,
+ explanation_metadata_v1,
+ feature_v1,
+ feature_monitoring_stats_v1,
+ feature_selector_v1,
+ featurestore_v1,
+ featurestore_online_service_v1,
+ featurestore_service_v1,
hyperparameter_tuning_job_v1,
io_v1,
job_service_v1,
job_state_v1,
+ lineage_subgraph_v1,
machine_resources_v1,
manual_batch_tuning_parameters_v1,
+ matching_engine_deployed_index_ref_v1,
+ index_v1,
+ index_endpoint_v1,
+ metadata_service_v1,
+ metadata_schema_v1,
+ metadata_store_v1,
model_v1,
model_evaluation_v1,
model_evaluation_slice_v1,
model_service_v1,
operation_v1,
+ pipeline_failure_policy_v1beta1,
+ pipeline_job_v1,
pipeline_service_v1,
pipeline_state_v1,
prediction_service_v1,
specialist_pool_v1,
specialist_pool_service_v1,
+ tensorboard_v1,
+ tensorboard_data_v1,
+ tensorboard_experiment_v1,
+ tensorboard_run_v1,
+ tensorboard_service_v1,
+ tensorboard_time_series_v1,
training_pipeline_v1,
+ types_v1,
# v1beta1
accelerator_type_v1beta1,
annotation_v1beta1,
annotation_spec_v1beta1,
+ artifact_v1beta1,
batch_prediction_job_v1beta1,
completion_stats_v1beta1,
+ context_v1beta1,
custom_job_v1beta1,
data_item_v1beta1,
data_labeling_job_v1beta1,
@@ -144,31 +225,50 @@
encryption_spec_v1beta1,
endpoint_v1beta1,
endpoint_service_v1beta1,
+ entity_type_v1beta1,
env_var_v1beta1,
+ event_v1beta1,
+ execution_v1beta1,
explanation_v1beta1,
explanation_metadata_v1beta1,
+ feature_v1beta1,
+ feature_monitoring_stats_v1beta1,
+ feature_selector_v1beta1,
+ featurestore_v1beta1,
+ featurestore_monitoring_v1beta1,
+ featurestore_online_service_v1beta1,
+ featurestore_service_v1beta1,
hyperparameter_tuning_job_v1beta1,
io_v1beta1,
job_service_v1beta1,
job_state_v1beta1,
+ lineage_subgraph_v1beta1,
machine_resources_v1beta1,
manual_batch_tuning_parameters_v1beta1,
+ matching_engine_deployed_index_ref_v1beta1,
+ index_v1beta1,
+ index_endpoint_v1beta1,
+ metadata_service_v1beta1,
+ metadata_schema_v1beta1,
+ metadata_store_v1beta1,
model_v1beta1,
model_evaluation_v1beta1,
model_evaluation_slice_v1beta1,
model_service_v1beta1,
operation_v1beta1,
+ pipeline_failure_policy_v1beta1,
+ pipeline_job_v1beta1,
pipeline_service_v1beta1,
pipeline_state_v1beta1,
prediction_service_v1beta1,
specialist_pool_v1beta1,
specialist_pool_service_v1beta1,
- training_pipeline_v1beta1,
- metadata_service_v1beta1,
- tensorboard_service_v1beta1,
+ tensorboard_v1beta1,
tensorboard_data_v1beta1,
tensorboard_experiment_v1beta1,
tensorboard_run_v1beta1,
tensorboard_service_v1beta1,
tensorboard_time_series_v1beta1,
+ training_pipeline_v1beta1,
+ types_v1beta1,
)
diff --git a/google/cloud/aiplatform/constants/__init__.py b/google/cloud/aiplatform/constants/__init__.py
new file mode 100644
index 0000000000..95f437a335
--- /dev/null
+++ b/google/cloud/aiplatform/constants/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.cloud.aiplatform.constants import base
+from google.cloud.aiplatform.constants import prediction
+
+__all__ = ("base", "prediction")
diff --git a/google/cloud/aiplatform/constants.py b/google/cloud/aiplatform/constants/base.py
similarity index 92%
rename from google/cloud/aiplatform/constants.py
rename to google/cloud/aiplatform/constants/base.py
index a7d81084cd..230918d564 100644
--- a/google/cloud/aiplatform/constants.py
+++ b/google/cloud/aiplatform/constants/base.py
@@ -18,21 +18,30 @@
DEFAULT_REGION = "us-central1"
SUPPORTED_REGIONS = {
"asia-east1",
+ "asia-east2",
"asia-northeast1",
"asia-northeast3",
+ "asia-south1",
"asia-southeast1",
"australia-southeast1",
"europe-west1",
"europe-west2",
+ "europe-west3",
"europe-west4",
+ "europe-west6",
"northamerica-northeast1",
+ "northamerica-northeast2",
"us-central1",
"us-east1",
"us-east4",
"us-west1",
+ "us-west2",
+ "us-west4",
+ "southamerica-east1",
}
API_BASE_PATH = "aiplatform.googleapis.com"
+PREDICTION_API_BASE_PATH = API_BASE_PATH
# Batch Prediction
BATCH_PREDICTION_INPUT_STORAGE_FORMATS = (
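Editorial note: a natural consumer of the expanded SUPPORTED_REGIONS set is a guard that rejects unsupported regions before any API call is made. A minimal sketch, using a hypothetical helper that is not part of this diff:

    from google.cloud.aiplatform.constants import base

    def _validate_region(region: str) -> None:
        # Hypothetical guard: fail fast on regions Vertex AI does not serve.
        if region not in base.SUPPORTED_REGIONS:
            raise ValueError(
                f"Unsupported region {region!r}; expected one of "
                f"{sorted(base.SUPPORTED_REGIONS)}."
            )

    _validate_region("europe-west6")  # newly added above; passes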
diff --git a/google/cloud/aiplatform/constants/prediction.py b/google/cloud/aiplatform/constants/prediction.py
new file mode 100644
index 0000000000..8bfdf12655
--- /dev/null
+++ b/google/cloud/aiplatform/constants/prediction.py
@@ -0,0 +1,147 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+from collections import defaultdict
+
+# [region]-docker.pkg.dev/vertex-ai/prediction/[framework]-[accelerator].[version]:latest
+CONTAINER_URI_PATTERN = re.compile(
+ r"(?P[\w]+)\-docker\.pkg\.dev\/vertex\-ai\/prediction\/"
+ r"(?P[\w]+)\-(?P[\w]+)\.(?P[\d-]+):latest"
+)
+
+SKLEARN = "sklearn"
+TF = "tf"
+TF2 = "tf2"
+XGBOOST = "xgboost"
+
+XGBOOST_CONTAINER_URIS = [
+ "us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-5:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-5:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-5:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-4:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-4:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-4:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-3:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-3:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-3:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-2:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-2:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-2:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-1:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-1:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.1-1:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.0-90:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.0-90:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.0-90:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.0-82:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.0-82:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/xgboost-cpu.0-82:latest",
+]
+
+SKLEARN_CONTAINER_URIS = [
+ "us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-0:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-0:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-0:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-24:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-24:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-24:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-23:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-23:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-23:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-22:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-22:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-22:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-20:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-20:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.0-20:latest",
+]
+
+TF_CONTAINER_URIS = [
+ "us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-8:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-8:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-8:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-8:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-8:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-8:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-7:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-7:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-7:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-7:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-7:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-7:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-6:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-6:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-6:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-6:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-6:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-6:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-5:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-5:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-5:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-5:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-5:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-5:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-4:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-4:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-4:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-4:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-4:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-4:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-3:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-3:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-3:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-3:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-3:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-3:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-2:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-2:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-2:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-2:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-2:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-2:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-1:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-1:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-1:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/tf-cpu.1-15:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/tf-cpu.1-15:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/tf-cpu.1-15:latest",
+ "us-docker.pkg.dev/vertex-ai/prediction/tf-gpu.1-15:latest",
+ "europe-docker.pkg.dev/vertex-ai/prediction/tf-gpu.1-15:latest",
+ "asia-docker.pkg.dev/vertex-ai/prediction/tf-gpu.1-15:latest",
+]
+
+SERVING_CONTAINER_URIS = (
+ SKLEARN_CONTAINER_URIS + TF_CONTAINER_URIS + XGBOOST_CONTAINER_URIS
+)
+
+# Map of all first-party prediction containers
+d = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(str))))
+
+for container_uri in SERVING_CONTAINER_URIS:
+ m = CONTAINER_URI_PATTERN.match(container_uri)
+ region, framework, accelerator, version = m[1], m[2], m[3], m[4]
+ version = version.replace("-", ".")
+
+ if framework in (TF2, TF): # Store both `tf`, `tf2` as `tensorflow`
+ framework = "tensorflow"
+
+ d[region][framework][accelerator][version] = container_uri
+
+_SERVING_CONTAINER_URI_MAP = d
+
+_SERVING_CONTAINER_DOCUMENTATION_URL = (
+ "https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers"
+)
diff --git a/google/cloud/aiplatform/datasets/__init__.py b/google/cloud/aiplatform/datasets/__init__.py
index b297530955..0f6b7f42fa 100644
--- a/google/cloud/aiplatform/datasets/__init__.py
+++ b/google/cloud/aiplatform/datasets/__init__.py
@@ -16,6 +16,7 @@
#
from google.cloud.aiplatform.datasets.dataset import _Dataset
+from google.cloud.aiplatform.datasets.column_names_dataset import _ColumnNamesDataset
from google.cloud.aiplatform.datasets.tabular_dataset import TabularDataset
from google.cloud.aiplatform.datasets.time_series_dataset import TimeSeriesDataset
from google.cloud.aiplatform.datasets.image_dataset import ImageDataset
@@ -25,6 +26,7 @@
__all__ = (
"_Dataset",
+ "_ColumnNamesDataset",
"TabularDataset",
"TimeSeriesDataset",
"ImageDataset",
diff --git a/google/cloud/aiplatform/datasets/_datasources.py b/google/cloud/aiplatform/datasets/_datasources.py
index 9323f40382..5fc51c03f6 100644
--- a/google/cloud/aiplatform/datasets/_datasources.py
+++ b/google/cloud/aiplatform/datasets/_datasources.py
@@ -71,7 +71,7 @@ def __init__(
"bq://project.dataset.table_name"
Raises:
- ValueError if source configuration is not valid.
+ ValueError: If source configuration is not valid.
"""
dataset_metadata = None
@@ -121,10 +121,9 @@ def __init__(
Args:
gcs_source (Union[str, Sequence[str]]):
Required. The Google Cloud Storage location for the input content.
- Google Cloud Storage URI(-s) to the input file(s). May contain
- wildcards. For more information on wildcards, see
- https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
- examples:
+ Google Cloud Storage URI(-s) to the input file(s).
+
+ Examples:
str: "gs://bucket/file.csv"
Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
import_schema_uri (str):
@@ -185,10 +184,9 @@ def create_datasource(
`OpenAPI 3.0.2 Schema
gcs_source (Union[str, Sequence[str]]):
The Google Cloud Storage location for the input content.
- Google Cloud Storage URI(-s) to the input file(s). May contain
- wildcards. For more information on wildcards, see
- https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
- examples:
+ Google Cloud Storage URI(-s) to the input file(s).
+
+ Examples:
str: "gs://bucket/file.csv"
Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
bq_source (str):
@@ -215,7 +213,7 @@ def create_datasource(
datasource (Datasource)
Raises:
- ValueError when below scenarios happen
+ ValueError: When any of the scenarios below happen:
- import_schema_uri is identified for creating TabularDatasource
- either import_schema_uri or gcs_source is missing for creating NonTabularDatasourceImportable
"""
diff --git a/google/cloud/aiplatform/datasets/column_names_dataset.py b/google/cloud/aiplatform/datasets/column_names_dataset.py
new file mode 100644
index 0000000000..27783d6c80
--- /dev/null
+++ b/google/cloud/aiplatform/datasets/column_names_dataset.py
@@ -0,0 +1,256 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import csv
+import logging
+from typing import List, Optional, Set
+from google.auth import credentials as auth_credentials
+
+from google.cloud import bigquery
+from google.cloud import storage
+
+from google.cloud.aiplatform import utils
+from google.cloud.aiplatform import datasets
+
+
+class _ColumnNamesDataset(datasets._Dataset):
+ @property
+ def column_names(self) -> List[str]:
+ """Retrieve the columns for the dataset by extracting it from the Google Cloud Storage or
+ Google BigQuery source.
+
+ Returns:
+ List[str]
+ A list of column names.
+
+ Raises:
+ RuntimeError: When no valid source is found.
+ """
+
+ self._assert_gca_resource_is_available()
+
+ metadata = self._gca_resource.metadata
+
+ if metadata is None:
+ raise RuntimeError("No metadata found for dataset")
+
+ input_config = metadata.get("inputConfig")
+
+ if input_config is None:
+ raise RuntimeError("No inputConfig found for dataset")
+
+ gcs_source = input_config.get("gcsSource")
+ bq_source = input_config.get("bigquerySource")
+
+ if gcs_source:
+ gcs_source_uris = gcs_source.get("uri")
+
+ if gcs_source_uris and len(gcs_source_uris) > 0:
+ # Lexicographically sort the files
+ gcs_source_uris.sort()
+
+ # Get the first file in sorted list
+ # TODO(b/193044977): Return as Set instead of List
+ return list(
+ self._retrieve_gcs_source_columns(
+ project=self.project,
+ gcs_csv_file_path=gcs_source_uris[0],
+ credentials=self.credentials,
+ )
+ )
+ elif bq_source:
+ bq_table_uri = bq_source.get("uri")
+ if bq_table_uri:
+ # TODO(b/193044977): Return as Set instead of List
+ return list(
+ self._retrieve_bq_source_columns(
+ project=self.project,
+ bq_table_uri=bq_table_uri,
+ credentials=self.credentials,
+ )
+ )
+
+ raise RuntimeError("No valid CSV or BigQuery datasource found.")
+
+ @staticmethod
+ def _retrieve_gcs_source_columns(
+ project: str,
+ gcs_csv_file_path: str,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> Set[str]:
+ """Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage
+
+ Example Usage:
+
+ column_names = _retrieve_gcs_source_columns(
+ "project_id",
+ "gs://example-bucket/path/to/csv_file"
+ )
+
+ # column_names = {"column_1", "column_2"}
+
+ Args:
+ project (str):
+ Required. Project to initiate the Google Cloud Storage client with.
+ gcs_csv_file_path (str):
+ Required. A full path to a CSV file stored on Google Cloud Storage.
+ Must include "gs://" prefix.
+ credentials (auth_credentials.Credentials):
+ Credentials to use with the GCS Client.
+ Returns:
+ Set[str]
+ A set of column names in the CSV file.
+
+ Raises:
+ RuntimeError: When the retrieved CSV file is invalid.
+ """
+
+ gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path(
+ gcs_csv_file_path
+ )
+ client = storage.Client(project=project, credentials=credentials)
+ bucket = client.bucket(gcs_bucket)
+ blob = bucket.blob(gcs_blob)
+
+ # Incrementally download the CSV file until the header is retrieved
+ first_new_line_index = -1
+ start_index = 0
+ increment = 1000
+ line = ""
+
+ try:
+ logger = logging.getLogger("google.resumable_media._helpers")
+ logging_warning_filter = utils.LoggingFilter(logging.INFO)
+ logger.addFilter(logging_warning_filter)
+
+ while first_new_line_index == -1:
+ line += blob.download_as_bytes(
+ start=start_index, end=start_index + increment - 1
+ ).decode("utf-8")
+
+ first_new_line_index = line.find("\n")
+ start_index += increment
+
+ header_line = line[:first_new_line_index]
+
+ # Split to make it an iterable
+ header_line = header_line.split("\n")[:1]
+
+ csv_reader = csv.reader(header_line, delimiter=",")
+ except (ValueError, RuntimeError) as err:
+ raise RuntimeError(
+ "There was a problem extracting the headers from the CSV file at '{}': {}".format(
+ gcs_csv_file_path, err
+ )
+ )
+ finally:
+ logger.removeFilter(logging_warning_filter)
+
+ return set(next(csv_reader))
+
+ @staticmethod
+ def _get_bq_schema_field_names_recursively(
+ schema_field: bigquery.SchemaField,
+ ) -> Set[str]:
+ """Retrieve the name for a schema field along with ancestor fields.
+ Nested schema fields are flattened and concatenated with a ".".
+ Schema fields with child fields are not included, but the children are.
+
+ Args:
+ project (str):
+ Required. Project to initiate the BigQuery client with.
+ bq_table_uri (str):
+ Required. A URI to a BigQuery table.
+ Can include "bq://" prefix but not required.
+ credentials (auth_credentials.Credentials):
+ Credentials to use with BQ Client.
+
+ Returns:
+ Set[str]
+ A set of columns names in the BigQuery table.
+ """
+
+ ancestor_names = {
+ nested_field_name
+ for field in schema_field.fields
+ for nested_field_name in _ColumnNamesDataset._get_bq_schema_field_names_recursively(
+ field
+ )
+ }
+
+ # Only return "leaf nodes", basically any field that doesn't have children
+ if len(ancestor_names) == 0:
+ return {schema_field.name}
+ else:
+ return {f"{schema_field.name}.{name}" for name in ancestor_names}
+
+ @staticmethod
+ def _retrieve_bq_source_columns(
+ project: str,
+ bq_table_uri: str,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> Set[str]:
+ """Retrieve the column names from a table on Google BigQuery
+ Nested schema fields are flattened and concatenated with a ".".
+ Schema fields with child fields are not included, but the children are.
+
+ Example Usage:
+
+ column_names = _retrieve_bq_source_columns(
+ "project_id",
+ "bq://project_id.dataset.table"
+ )
+
+ # column_names = {"column_1", "column_2", "column_3.nested_field"}
+
+ Args:
+ project (str):
+ Required. Project to initiate the BigQuery client with.
+ bq_table_uri (str):
+ Required. A URI to a BigQuery table.
+ Can include "bq://" prefix but not required.
+ credentials (auth_credentials.Credentials):
+ Credentials to use with BQ Client.
+
+ Returns:
+ Set[str]
+ A set of column names in the BigQuery table.
+ """
+
+ # Remove bq:// prefix
+ prefix = "bq://"
+ if bq_table_uri.startswith(prefix):
+ bq_table_uri = bq_table_uri[len(prefix) :]
+
+ # The colon-based "project:dataset.table" format is no longer supported:
+ # Invalid dataset ID "bigquery-public-data:chicago_taxi_trips".
+ # Dataset IDs must be alphanumeric (plus underscores and dashes) and must be at most 1024 characters long.
+ # Using dot-based "project.dataset.table" format instead.
+ bq_table_uri = bq_table_uri.replace(":", ".")
+
+ client = bigquery.Client(project=project, credentials=credentials)
+ table = client.get_table(bq_table_uri)
+ schema = table.schema
+
+ return {
+ field_name
+ for field in schema
+ for field_name in _ColumnNamesDataset._get_bq_schema_field_names_recursively(
+ field
+ )
+ }
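Editorial note: to make the flattening behavior of `_get_bq_schema_field_names_recursively` concrete, here is a small worked example using the public `bigquery.SchemaField` type — the field names are illustrative:

    from google.cloud import bigquery

    address = bigquery.SchemaField(
        "address",
        "RECORD",
        fields=(
            bigquery.SchemaField("city", "STRING"),
            bigquery.SchemaField("zip", "STRING"),
        ),
    )

    # Parent fields with children are dropped; leaf fields keep dotted ancestry.
    names = _ColumnNamesDataset._get_bq_schema_field_names_recursively(address)
    assert names == {"address.city", "address.zip"}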
diff --git a/google/cloud/aiplatform/datasets/dataset.py b/google/cloud/aiplatform/datasets/dataset.py
index 1eb1663b2b..508932779b 100644
--- a/google/cloud/aiplatform/datasets/dataset.py
+++ b/google/cloud/aiplatform/datasets/dataset.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
# limitations under the License.
#
-from typing import Optional, Sequence, Dict, Tuple, Union, List
+from typing import Dict, List, Optional, Sequence, Tuple, Union
from google.api_core import operation
from google.auth import credentials as auth_credentials
@@ -31,6 +31,7 @@
io as gca_io,
)
from google.cloud.aiplatform.datasets import _datasources
+from google.protobuf import field_mask_pb2
_LOGGER = base.Logger(__name__)
@@ -39,11 +40,12 @@ class _Dataset(base.VertexAiResourceNounWithFutureManager):
"""Managed dataset resource for Vertex AI."""
client_class = utils.DatasetClientWithOverride
- _is_client_prediction_client = False
_resource_noun = "datasets"
_getter_method = "get_dataset"
_list_method = "list_datasets"
_delete_method = "delete_dataset"
+ _parse_resource_name_method = "parse_dataset_path"
+ _format_resource_name_method = "dataset_path"
_supported_metadata_schema_uris: Tuple[str] = ()
@@ -68,7 +70,7 @@ def __init__(
Optional location to retrieve dataset from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
- Custom credentials to use to upload this model. Overrides
+ Custom credentials to use to retrieve this Dataset. Overrides
credentials set in aiplatform.init.
"""
@@ -84,13 +86,14 @@ def __init__(
@property
def metadata_schema_uri(self) -> str:
"""The metadata schema uri of this dataset resource."""
+ self._assert_gca_resource_is_available()
return self._gca_resource.metadata_schema_uri
def _validate_metadata_schema_uri(self) -> None:
"""Validate the metadata_schema_uri of retrieved dataset resource.
Raises:
- ValueError if the dataset type of the retrieved dataset resource is
+ ValueError: If the dataset type of the retrieved dataset resource is
not supported by the class.
"""
if self._supported_metadata_schema_uris and (
@@ -104,6 +107,7 @@ def _validate_metadata_schema_uri(self) -> None:
@classmethod
def create(
cls,
+ # TODO(b/223262536): Make the display_name parameter optional in the next major release
display_name: str,
metadata_schema_uri: str,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
@@ -114,8 +118,10 @@ def create(
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ labels: Optional[Dict[str, str]] = None,
encryption_spec_key_name: Optional[str] = None,
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
) -> "_Dataset":
"""Creates a new dataset and optionally imports data into dataset when
source and import_schema_uri are passed.
@@ -161,7 +167,7 @@ def create(
be picked randomly. Two DataItems are considered identical
if their content bytes are identical (e.g. image bytes or
pdf bytes). These labels will be overridden by Annotation
- labels specified inside index file refenced by
+ labels specified inside index file referenced by
``import_schema_uri``,
e.g. jsonl file.
project (str):
@@ -175,6 +181,16 @@ def create(
credentials set in aiplatform.init.
request_metadata (Sequence[Tuple[str, str]]):
Strings which should be sent along with the request as metadata.
+ labels (Dict[str, str]):
+ Optional. Labels with user-defined metadata to organize your datasets.
+ Label keys and values can be no longer than 64 characters
+ (Unicode codepoints), can only contain lowercase letters, numeric
+ characters, underscores and dashes. International characters are allowed.
+ No more than 64 user labels can be associated with one dataset
+ (System labels are excluded).
+ See https://goo.gl/xmQnxf for more information and examples of labels.
+ System reserved label keys are prefixed with "aiplatform.googleapis.com/"
+ and are immutable.
encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the dataset. Has the
@@ -190,13 +206,18 @@ def create(
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
Returns:
dataset (Dataset):
Instantiated representation of the managed dataset resource.
"""
-
+ if not display_name:
+ display_name = cls._generate_display_name()
utils.validate_display_name(display_name)
+ if labels:
+ utils.validate_labels(labels)
api_client = cls._instantiate_client(location=location, credentials=credentials)
@@ -220,10 +241,12 @@ def create(
location=location or initializer.global_config.location,
credentials=credentials or initializer.global_config.credentials,
request_metadata=request_metadata,
+ labels=labels,
encryption_spec=initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name
),
sync=sync,
+ create_request_timeout=create_request_timeout,
)
@classmethod
@@ -239,8 +262,11 @@ def _create_and_import(
location: str,
credentials: Optional[auth_credentials.Credentials],
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ labels: Optional[Dict[str, str]] = None,
encryption_spec: Optional[gca_encryption_spec.EncryptionSpec] = None,
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
+ import_request_timeout: Optional[float] = None,
) -> "_Dataset":
"""Creates a new dataset and optionally imports data into dataset when
source and import_schema_uri are passed.
@@ -276,6 +302,16 @@ def _create_and_import(
credentials set in aiplatform.init.
request_metadata (Sequence[Tuple[str, str]]):
Strings which should be sent along with the request as metadata.
+ labels (Dict[str, str]):
+ Optional. Labels with user-defined metadata to organize your Datasets.
+ Label keys and values can be no longer than 64 characters
+ (Unicode codepoints), can only contain lowercase letters, numeric
+ characters, underscores and dashes. International characters are allowed.
+ No more than 64 user labels can be associated with one Dataset
+ (System labels are excluded).
+ See https://goo.gl/xmQnxf for more information and examples of labels.
+ System reserved label keys are prefixed with "aiplatform.googleapis.com/"
+ and are immutable.
encryption_spec (Optional[gca_encryption_spec.EncryptionSpec]):
Optional. The Cloud KMS customer managed encryption key used to protect the dataset.
The key needs to be in the same region as where the compute
@@ -286,6 +322,10 @@ def _create_and_import(
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+ import_request_timeout (float):
+ Optional. The timeout for the import request in seconds.
Returns:
dataset (Dataset):
@@ -299,7 +339,9 @@ def _create_and_import(
metadata_schema_uri=metadata_schema_uri,
datasource=datasource,
request_metadata=request_metadata,
+ labels=labels,
encryption_spec=encryption_spec,
+ create_request_timeout=create_request_timeout,
)
_LOGGER.log_create_with_lro(cls, create_dataset_lro)
@@ -317,16 +359,26 @@ def _create_and_import(
# Import if import datasource is DatasourceImportable
if isinstance(datasource, _datasources.DatasourceImportable):
- dataset_obj._import_and_wait(datasource)
+ dataset_obj._import_and_wait(
+ datasource, import_request_timeout=import_request_timeout
+ )
return dataset_obj
- def _import_and_wait(self, datasource):
+ def _import_and_wait(
+ self,
+ datasource,
+ import_request_timeout: Optional[float] = None,
+ ):
_LOGGER.log_action_start_against_resource(
- "Importing", "data", self,
+ "Importing",
+ "data",
+ self,
)
- import_lro = self._import(datasource=datasource)
+ import_lro = self._import(
+ datasource=datasource, import_request_timeout=import_request_timeout
+ )
_LOGGER.log_action_started_against_resource_with_lro(
"Import", "data", self.__class__, import_lro
@@ -345,7 +397,9 @@ def _create(
metadata_schema_uri: str,
datasource: _datasources.Datasource,
request_metadata: Sequence[Tuple[str, str]] = (),
+ labels: Optional[Dict[str, str]] = None,
encryption_spec: Optional[gca_encryption_spec.EncryptionSpec] = None,
+ create_request_timeout: Optional[float] = None,
) -> operation.Operation:
"""Creates a new managed dataset by directly calling API client.
@@ -372,12 +426,24 @@ def _create(
request_metadata (Sequence[Tuple[str, str]]):
Strings which should be sent along with the create_dataset
request as metadata. Usually to specify special dataset config.
+ labels (Dict[str, str]):
+ Optional. Labels with user-defined metadata to organize your Datasets.
+ Label keys and values can be no longer than 64 characters
+ (Unicode codepoints), can only contain lowercase letters, numeric
+ characters, underscores and dashes. International characters are allowed.
+ No more than 64 user labels can be associated with one Dataset
+ (System labels are excluded).
+ See https://goo.gl/xmQnxf for more information and examples of labels.
+ System reserved label keys are prefixed with "aiplatform.googleapis.com/"
+ and are immutable.
encryption_spec (Optional[gca_encryption_spec.EncryptionSpec]):
Optional. The Cloud KMS customer managed encryption key used to protect the dataset.
The key needs to be in the same region as where the compute
resource is created.
If set, this Dataset and all sub-resources of this Dataset will be secured by this key.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
Returns:
operation (Operation):
An object representing a long-running operation.
@@ -387,28 +453,38 @@ def _create(
display_name=display_name,
metadata_schema_uri=metadata_schema_uri,
metadata=datasource.dataset_metadata,
+ labels=labels,
encryption_spec=encryption_spec,
)
return api_client.create_dataset(
- parent=parent, dataset=gapic_dataset, metadata=request_metadata
+ parent=parent,
+ dataset=gapic_dataset,
+ metadata=request_metadata,
+ timeout=create_request_timeout,
)
def _import(
- self, datasource: _datasources.DatasourceImportable,
+ self,
+ datasource: _datasources.DatasourceImportable,
+ import_request_timeout: Optional[float] = None,
) -> operation.Operation:
"""Imports data into managed dataset by directly calling API client.
Args:
datasource (_datasources.DatasourceImportable):
Required. Datasource for importing data to an existing dataset for Vertex AI.
+ import_request_timeout (float):
+ Optional. The timeout for the import request in seconds.
Returns:
operation (Operation):
An object representing a long-running operation.
"""
return self.api_client.import_data(
- name=self.resource_name, import_configs=[datasource.import_data_config]
+ name=self.resource_name,
+ import_configs=[datasource.import_data_config],
+ timeout=import_request_timeout,
)
@base.optional_sync(return_input_arg="self")
@@ -418,6 +494,7 @@ def import_data(
import_schema_uri: str,
data_item_labels: Optional[Dict] = None,
sync: bool = True,
+ import_request_timeout: Optional[float] = None,
) -> "_Dataset":
"""Upload data to existing managed dataset.
@@ -448,13 +525,15 @@ def import_data(
be picked randomly. Two DataItems are considered identical
if their content bytes are identical (e.g. image bytes or
pdf bytes). These labels will be overridden by Annotation
- labels specified inside index file refenced by
+ labels specified inside index file referenced by
``import_schema_uri``,
e.g. jsonl file.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ import_request_timeout (float):
+ Optional. The timeout for the import request in seconds.
Returns:
dataset (Dataset):
@@ -467,7 +546,9 @@ def import_data(
data_item_labels=data_item_labels,
)
- self._import_and_wait(datasource=datasource)
+ self._import_and_wait(
+ datasource=datasource, import_request_timeout=import_request_timeout
+ )
return self
# TODO(b/174751568) add optional sync support
@@ -517,8 +598,69 @@ def export_data(self, output_dir: str) -> Sequence[str]:
return export_data_response.exported_files
- def update(self):
- raise NotImplementedError("Update dataset has not been implemented yet")
+ def update(
+ self,
+ *,
+ display_name: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ description: Optional[str] = None,
+ update_request_timeout: Optional[float] = None,
+ ) -> "_Dataset":
+ """Update the dataset.
+ Updatable fields:
+ - ``display_name``
+ - ``description``
+ - ``labels``
+
+ Args:
+ display_name (str):
+ Optional. The user-defined name of the Dataset.
+ The name can be up to 128 characters long and can consist
+ of any UTF-8 characters.
+ labels (Dict[str, str]):
+ Optional. Labels with user-defined metadata to organize your Datasets.
+ Label keys and values can be no longer than 64 characters
+ (Unicode codepoints), can only contain lowercase letters, numeric
+ characters, underscores and dashes. International characters are allowed.
+ No more than 64 user labels can be associated with one Dataset
+ (System labels are excluded).
+ See https://goo.gl/xmQnxf for more information and examples of labels.
+ System reserved label keys are prefixed with "aiplatform.googleapis.com/"
+ and are immutable.
+ description (str):
+ Optional. The description of the Dataset.
+ update_request_timeout (float):
+ Optional. The timeout for the update request in seconds.
+
+ Returns:
+ dataset (Dataset):
+ Updated dataset.
+ """
+
+ update_mask = field_mask_pb2.FieldMask()
+ if display_name:
+ update_mask.paths.append("display_name")
+
+ if labels:
+ update_mask.paths.append("labels")
+
+ if description:
+ update_mask.paths.append("description")
+
+ update_dataset = gca_dataset.Dataset(
+ name=self.resource_name,
+ display_name=display_name,
+ description=description,
+ labels=labels,
+ )
+
+ self._gca_resource = self.api_client.update_dataset(
+ dataset=update_dataset,
+ update_mask=update_mask,
+ timeout=update_request_timeout,
+ )
+
+ return self
@classmethod
def list(
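A minimal usage sketch of the new update() method (the project, location, and dataset resource ID below are placeholders, not part of this change):

    from google.cloud import aiplatform

    aiplatform.init(project="my-project", location="us-central1")

    # "1234567890" stands in for a real Dataset resource ID.
    dataset = aiplatform.TabularDataset("1234567890")
    dataset = dataset.update(
        display_name="renamed-dataset",
        labels={"team": "research"},
        update_request_timeout=60.0,  # seconds
    )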
diff --git a/google/cloud/aiplatform/datasets/image_dataset.py b/google/cloud/aiplatform/datasets/image_dataset.py
index 506338c915..b7f7954278 100644
--- a/google/cloud/aiplatform/datasets/image_dataset.py
+++ b/google/cloud/aiplatform/datasets/image_dataset.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
-from typing import Optional, Sequence, Dict, Tuple, Union
+from typing import Dict, Optional, Sequence, Tuple, Union
from google.auth import credentials as auth_credentials
@@ -36,7 +36,7 @@ class ImageDataset(datasets._Dataset):
@classmethod
def create(
cls,
- display_name: str,
+ display_name: Optional[str] = None,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
import_schema_uri: Optional[str] = None,
data_item_labels: Optional[Dict] = None,
@@ -44,23 +44,24 @@ def create(
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ labels: Optional[Dict[str, str]] = None,
encryption_spec_key_name: Optional[str] = None,
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
) -> "ImageDataset":
"""Creates a new image dataset and optionally imports data into dataset
when source and import_schema_uri are passed.
Args:
display_name (str):
- Required. The user-defined name of the Dataset.
+ Optional. The user-defined name of the Dataset.
The name can be up to 128 characters long and can consist
of any UTF-8 characters.
gcs_source (Union[str, Sequence[str]]):
Google Cloud Storage URI(-s) to the
- input file(s). May contain wildcards. For more
- information on wildcards, see
- https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
- examples:
+ input file(s).
+
+ Examples:
str: "gs://bucket/file.csv"
Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
import_schema_uri (str):
@@ -81,7 +82,7 @@ def create(
be picked randomly. Two DataItems are considered identical
if their content bytes are identical (e.g. image bytes or
pdf bytes). These labels will be overridden by Annotation
- labels specified inside index file refenced by
+ labels specified inside index file referenced by
``import_schema_uri``,
e.g. jsonl file.
project (str):
@@ -95,6 +96,16 @@ def create(
credentials set in aiplatform.init.
request_metadata (Sequence[Tuple[str, str]]):
Strings which should be sent along with the request as metadata.
+ labels (Dict[str, str]):
+ Optional. Labels with user-defined metadata to organize your Datasets.
+ Label keys and values can be no longer than 64 characters
+ (Unicode codepoints), can only contain lowercase letters, numeric
+ characters, underscores and dashes. International characters are allowed.
+ No more than 64 user labels can be associated with one Dataset
+ (System labels are excluded).
+ See https://goo.gl/xmQnxf for more information and examples of labels.
+ System reserved label keys are prefixed with "aiplatform.googleapis.com/"
+ and are immutable.
encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the dataset. Has the
@@ -110,13 +121,19 @@ def create(
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
Returns:
image_dataset (ImageDataset):
Instantiated representation of the managed image dataset resource.
"""
+ if not display_name:
+ display_name = cls._generate_display_name()
utils.validate_display_name(display_name)
+ if labels:
+ utils.validate_labels(labels)
api_client = cls._instantiate_client(location=location, credentials=credentials)
@@ -141,8 +158,10 @@ def create(
location=location or initializer.global_config.location,
credentials=credentials or initializer.global_config.credentials,
request_metadata=request_metadata,
+ labels=labels,
encryption_spec=initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name
),
sync=sync,
+ create_request_timeout=create_request_timeout,
)
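A sketch of ImageDataset.create() exercising the new labels and create_request_timeout parameters (the bucket URI and schema choice are assumptions):

    from google.cloud import aiplatform

    image_dataset = aiplatform.ImageDataset.create(
        display_name="flowers",  # now optional; auto-generated when omitted
        gcs_source="gs://my-bucket/flowers_import.csv",  # placeholder import file
        import_schema_uri=aiplatform.schema.dataset.ioformat.image.single_label_classification,
        labels={"env": "dev"},  # checked by utils.validate_labels
        create_request_timeout=300.0,  # seconds
    )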
diff --git a/google/cloud/aiplatform/datasets/tabular_dataset.py b/google/cloud/aiplatform/datasets/tabular_dataset.py
index 95f1b16f98..f4366e4a24 100644
--- a/google/cloud/aiplatform/datasets/tabular_dataset.py
+++ b/google/cloud/aiplatform/datasets/tabular_dataset.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,207 +15,57 @@
# limitations under the License.
#
-import csv
-import logging
-
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import Dict, Optional, Sequence, Tuple, Union
from google.auth import credentials as auth_credentials
from google.cloud import bigquery
-from google.cloud import storage
-
+from google.cloud.aiplatform import base
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform.datasets import _datasources
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import schema
from google.cloud.aiplatform import utils
+_AUTOML_TRAINING_MIN_ROWS = 1000
+
+_LOGGER = base.Logger(__name__)
+
-class TabularDataset(datasets._Dataset):
+class TabularDataset(datasets._ColumnNamesDataset):
"""Managed tabular dataset resource for Vertex AI."""
_supported_metadata_schema_uris: Optional[Tuple[str]] = (
schema.dataset.metadata.tabular,
)
- @property
- def column_names(self) -> List[str]:
- """Retrieve the columns for the dataset by extracting it from the Google Cloud Storage or
- Google BigQuery source.
-
- Returns:
- List[str]
- A list of columns names
-
- Raises:
- RuntimeError: When no valid source is found.
- """
-
- metadata = self._gca_resource.metadata
-
- if metadata is None:
- raise RuntimeError("No metadata found for dataset")
-
- input_config = metadata.get("inputConfig")
-
- if input_config is None:
- raise RuntimeError("No inputConfig found for dataset")
-
- gcs_source = input_config.get("gcsSource")
- bq_source = input_config.get("bigquerySource")
-
- if gcs_source:
- gcs_source_uris = gcs_source.get("uri")
-
- if gcs_source_uris and len(gcs_source_uris) > 0:
- # Lexicographically sort the files
- gcs_source_uris.sort()
-
- # Get the first file in sorted list
- return TabularDataset._retrieve_gcs_source_columns(
- self.project, gcs_source_uris[0]
- )
- elif bq_source:
- bq_table_uri = bq_source.get("uri")
- if bq_table_uri:
- return TabularDataset._retrieve_bq_source_columns(
- self.project, bq_table_uri
- )
-
- raise RuntimeError("No valid CSV or BigQuery datasource found.")
-
- @staticmethod
- def _retrieve_gcs_source_columns(project: str, gcs_csv_file_path: str) -> List[str]:
- """Retrieve the columns from a comma-delimited CSV file stored on Google Cloud Storage
-
- Example Usage:
-
- column_names = _retrieve_gcs_source_columns(
- "project_id",
- "gs://example-bucket/path/to/csv_file"
- )
-
- # column_names = ["column_1", "column_2"]
-
- Args:
- project (str):
- Required. Project to initiate the Google Cloud Storage client with.
- gcs_csv_file_path (str):
- Required. A full path to a CSV files stored on Google Cloud Storage.
- Must include "gs://" prefix.
-
- Returns:
- List[str]
- A list of columns names in the CSV file.
-
- Raises:
- RuntimeError: When the retrieved CSV file is invalid.
- """
-
- gcs_bucket, gcs_blob = utils.extract_bucket_and_prefix_from_gcs_path(
- gcs_csv_file_path
- )
- client = storage.Client(project=project)
- bucket = client.bucket(gcs_bucket)
- blob = bucket.blob(gcs_blob)
-
- # Incrementally download the CSV file until the header is retrieved
- first_new_line_index = -1
- start_index = 0
- increment = 1000
- line = ""
-
- try:
- logger = logging.getLogger("google.resumable_media._helpers")
- logging_warning_filter = utils.LoggingFilter(logging.INFO)
- logger.addFilter(logging_warning_filter)
-
- while first_new_line_index == -1:
- line += blob.download_as_bytes(
- start=start_index, end=start_index + increment
- ).decode("utf-8")
- first_new_line_index = line.find("\n")
- start_index += increment
-
- header_line = line[:first_new_line_index]
-
- # Split to make it an iterable
- header_line = header_line.split("\n")[:1]
-
- csv_reader = csv.reader(header_line, delimiter=",")
- except (ValueError, RuntimeError) as err:
- raise RuntimeError(
- "There was a problem extracting the headers from the CSV file at '{}': {}".format(
- gcs_csv_file_path, err
- )
- )
- finally:
- logger.removeFilter(logging_warning_filter)
-
- return next(csv_reader)
-
- @staticmethod
- def _retrieve_bq_source_columns(project: str, bq_table_uri: str) -> List[str]:
- """Retrieve the columns from a table on Google BigQuery
-
- Example Usage:
-
- column_names = _retrieve_bq_source_columns(
- "project_id",
- "bq://project_id.dataset.table"
- )
-
- # column_names = ["column_1", "column_2"]
-
- Args:
- project (str):
- Required. Project to initiate the BigQuery client with.
- bq_table_uri (str):
- Required. A URI to a BigQuery table.
- Can include "bq://" prefix but not required.
-
- Returns:
- List[str]
- A list of columns names in the BigQuery table.
- """
-
- # Remove bq:// prefix
- prefix = "bq://"
- if bq_table_uri.startswith(prefix):
- bq_table_uri = bq_table_uri[len(prefix) :]
-
- client = bigquery.Client(project=project)
- table = client.get_table(bq_table_uri)
- schema = table.schema
- return [schema.name for schema in schema]
-
@classmethod
def create(
cls,
- display_name: str,
+ display_name: Optional[str] = None,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
bq_source: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ labels: Optional[Dict[str, str]] = None,
encryption_spec_key_name: Optional[str] = None,
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
) -> "TabularDataset":
"""Creates a new tabular dataset.
Args:
display_name (str):
- Required. The user-defined name of the Dataset.
+ Optional. The user-defined name of the Dataset.
The name can be up to 128 characters long and can consist
of any UTF-8 characters.
gcs_source (Union[str, Sequence[str]]):
Google Cloud Storage URI(-s) to the
- input file(s). May contain wildcards. For more
- information on wildcards, see
- https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
- examples:
+ input file(s).
+
+ Examples:
str: "gs://bucket/file.csv"
Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
bq_source (str):
@@ -233,6 +83,16 @@ def create(
credentials set in aiplatform.init.
request_metadata (Sequence[Tuple[str, str]]):
Strings which should be sent along with the request as metadata.
+ labels (Dict[str, str]):
+ Optional. Labels with user-defined metadata to organize your Datasets.
+ Label keys and values can be no longer than 64 characters
+ (Unicode codepoints), can only contain lowercase letters, numeric
+ characters, underscores and dashes. International characters are allowed.
+ No more than 64 user labels can be associated with one Dataset
+ (System labels are excluded).
+ See https://goo.gl/xmQnxf for more information and examples of labels.
+ System reserved label keys are prefixed with "aiplatform.googleapis.com/"
+ and are immutable.
encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the dataset. Has the
@@ -248,13 +108,18 @@ def create(
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
Returns:
tabular_dataset (TabularDataset):
Instantiated representation of the managed tabular dataset resource.
"""
-
+ if not display_name:
+ display_name = cls._generate_display_name()
utils.validate_display_name(display_name)
+ if labels:
+ utils.validate_labels(labels)
api_client = cls._instantiate_client(location=location, credentials=credentials)
@@ -278,12 +143,120 @@ def create(
location=location or initializer.global_config.location,
credentials=credentials or initializer.global_config.credentials,
request_metadata=request_metadata,
+ labels=labels,
encryption_spec=initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name
),
sync=sync,
+ create_request_timeout=create_request_timeout,
+ )
+
+ @classmethod
+ def create_from_dataframe(
+ cls,
+ df_source: "pd.DataFrame", # noqa: F821 - skip check for undefined name 'pd'
+ staging_path: str,
+ bq_schema: Optional[Union[str, bigquery.SchemaField]] = None,
+ display_name: Optional[str] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> "TabularDataset":
+ """Creates a new tabular dataset from a Pandas DataFrame.
+
+ Args:
+ df_source (pd.DataFrame):
+ Required. Pandas DataFrame containing the source data for
+ ingestion as a TabularDataset. This method will use the data
+ types from the provided DataFrame when creating the dataset.
+ staging_path (str):
+ Required. The BigQuery table to stage the data
+ for Vertex. Because Vertex maintains a reference to this source
+ to create the Vertex Dataset, this BigQuery table should
+ not be deleted. Example: `bq://my-project.my-dataset.my-table`.
+ If the provided BigQuery table doesn't exist, this method will
+ create the table. If the provided BigQuery table already exists,
+ and the schemas of the BigQuery table and your DataFrame match,
+ this method will append the data in your local DataFrame to the table.
+ The location of the provided BigQuery table should conform to the location requirements
+ specified here: https://cloud.google.com/vertex-ai/docs/general/locations#bq-locations.
+ bq_schema (Optional[Union[str, bigquery.SchemaField]]):
+ Optional. If not set, BigQuery will autodetect the schema using your DataFrame's column types.
+ If set, BigQuery will use the schema you provide when creating the staging table. For more details,
+ see: https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.job.LoadJobConfig#google_cloud_bigquery_job_LoadJobConfig_schema
+ display_name (str):
+ Optional. The user-defined name of the Dataset.
+ The name can be up to 128 characters long and can consist
+ of any UTF-8 characters.
+ project (str):
+ Optional. Project to upload this dataset to. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location to upload this dataset to. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to upload this dataset. Overrides
+ credentials set in aiplatform.init.
+ Returns:
+ tabular_dataset (TabularDataset):
+ Instantiated representation of the managed tabular dataset resource.
+ """
+
+ if staging_path.startswith("bq://"):
+ bq_staging_path = staging_path[len("bq://") :]
+ else:
+ raise ValueError(
+ "Only BigQuery staging paths are supported. Provide a staging path in the format `bq://your-project.your-dataset.your-table`."
+ )
+
+ try:
+ import pyarrow # noqa: F401 - skip check for 'pyarrow' which is required when using 'google.cloud.bigquery'
+ except ImportError:
+ raise ImportError(
+ "Pyarrow is not installed, and is required to use the BigQuery client."
+ 'Please install the SDK using "pip install google-cloud-aiplatform[datasets]"'
+ )
+
+ if len(df_source) < _AUTOML_TRAINING_MIN_ROWS:
+ _LOGGER.info(
+ "Your DataFrame has %s rows and AutoML requires %s rows to train on tabular data. You can still train a custom model once your dataset has been uploaded to Vertex, but you will not be able to use AutoML for training."
+ % (len(df_source), _AUTOML_TRAINING_MIN_ROWS),
+ )
+
+ bigquery_client = bigquery.Client(
+ project=project or initializer.global_config.project,
+ credentials=credentials or initializer.global_config.credentials,
)
+ try:
+ parquet_options = bigquery.format_options.ParquetOptions()
+ parquet_options.enable_list_inference = True
+
+ job_config = bigquery.LoadJobConfig(
+ source_format=bigquery.SourceFormat.PARQUET,
+ parquet_options=parquet_options,
+ )
+
+ if bq_schema:
+ job_config.schema = bq_schema
+
+ job = bigquery_client.load_table_from_dataframe(
+ dataframe=df_source, destination=bq_staging_path, job_config=job_config
+ )
+
+ job.result()
+
+ finally:
+ dataset_from_dataframe = cls.create(
+ display_name=display_name,
+ bq_source=staging_path,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+ return dataset_from_dataframe
+
def import_data(self):
raise NotImplementedError(
f"{self.__class__.__name__} class does not support 'import_data'"
diff --git a/google/cloud/aiplatform/datasets/text_dataset.py b/google/cloud/aiplatform/datasets/text_dataset.py
index 85676ed2ed..f74fb76bb7 100644
--- a/google/cloud/aiplatform/datasets/text_dataset.py
+++ b/google/cloud/aiplatform/datasets/text_dataset.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
-from typing import Optional, Sequence, Dict, Tuple, Union
+from typing import Dict, Optional, Sequence, Tuple, Union
from google.auth import credentials as auth_credentials
@@ -36,7 +36,7 @@ class TextDataset(datasets._Dataset):
@classmethod
def create(
cls,
- display_name: str,
+ display_name: Optional[str] = None,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
import_schema_uri: Optional[str] = None,
data_item_labels: Optional[Dict] = None,
@@ -44,8 +44,10 @@ def create(
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ labels: Optional[Dict[str, str]] = None,
encryption_spec_key_name: Optional[str] = None,
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
) -> "TextDataset":
"""Creates a new text dataset and optionally imports data into dataset
when source and import_schema_uri are passed.
@@ -59,15 +61,14 @@ def create(
Args:
display_name (str):
- Required. The user-defined name of the Dataset.
+ Optional. The user-defined name of the Dataset.
The name can be up to 128 characters long and can consist
of any UTF-8 characters.
gcs_source (Union[str, Sequence[str]]):
Google Cloud Storage URI(-s) to the
- input file(s). May contain wildcards. For more
- information on wildcards, see
- https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
- examples:
+ input file(s).
+
+ Examples:
str: "gs://bucket/file.csv"
Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
import_schema_uri (str):
@@ -88,7 +89,7 @@ def create(
be picked randomly. Two DataItems are considered identical
if their content bytes are identical (e.g. image bytes or
pdf bytes). These labels will be overridden by Annotation
- labels specified inside index file refenced by
+ labels specified inside index file referenced by
``import_schema_uri``,
e.g. jsonl file.
project (str):
@@ -102,6 +103,16 @@ def create(
credentials set in aiplatform.init.
request_metadata (Sequence[Tuple[str, str]]):
Strings which should be sent along with the request as metadata.
+ labels (Dict[str, str]):
+ Optional. Labels with user-defined metadata to organize your Datasets.
+ Label keys and values can be no longer than 64 characters
+ (Unicode codepoints), can only contain lowercase letters, numeric
+ characters, underscores and dashes. International characters are allowed.
+ No more than 64 user labels can be associated with one Dataset
+ (System labels are excluded).
+ See https://goo.gl/xmQnxf for more information and examples of labels.
+ System reserved label keys are prefixed with "aiplatform.googleapis.com/"
+ and are immutable.
encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the dataset. Has the
@@ -113,6 +124,8 @@ def create(
If set, this Dataset and all sub-resources of this Dataset will be secured by this key.
Overrides encryption_spec_key_name set in aiplatform.init.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
@@ -122,8 +135,11 @@ def create(
text_dataset (TextDataset):
Instantiated representation of the managed text dataset resource.
"""
-
+ if not display_name:
+ display_name = cls._generate_display_name()
utils.validate_display_name(display_name)
+ if labels:
+ utils.validate_labels(labels)
api_client = cls._instantiate_client(location=location, credentials=credentials)
@@ -148,8 +164,10 @@ def create(
location=location or initializer.global_config.location,
credentials=credentials or initializer.global_config.credentials,
request_metadata=request_metadata,
+ labels=labels,
encryption_spec=initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name
),
sync=sync,
+ create_request_timeout=create_request_timeout,
)
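Because display_name is now optional, a TextDataset can be created without one; a sketch (the bucket URIs and schema choice are assumptions):

    from google.cloud import aiplatform

    # display_name is omitted, so _generate_display_name() supplies one.
    text_dataset = aiplatform.TextDataset.create(
        gcs_source=["gs://my-bucket/part1.jsonl", "gs://my-bucket/part2.jsonl"],
        import_schema_uri=aiplatform.schema.dataset.ioformat.text.multi_label_classification,
        create_request_timeout=300.0,  # seconds
    )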
diff --git a/google/cloud/aiplatform/datasets/time_series_dataset.py b/google/cloud/aiplatform/datasets/time_series_dataset.py
index d5aa3dcbf2..6cc48e2558 100644
--- a/google/cloud/aiplatform/datasets/time_series_dataset.py
+++ b/google/cloud/aiplatform/datasets/time_series_dataset.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
-from typing import Optional, Sequence, Tuple, Union
+from typing import Dict, Optional, Sequence, Tuple, Union
from google.auth import credentials as auth_credentials
@@ -26,7 +26,7 @@
from google.cloud.aiplatform import utils
-class TimeSeriesDataset(datasets._Dataset):
+class TimeSeriesDataset(datasets._ColumnNamesDataset):
"""Managed time series dataset resource for Vertex AI"""
_supported_metadata_schema_uris: Optional[Tuple[str]] = (
@@ -36,29 +36,30 @@ class TimeSeriesDataset(datasets._Dataset):
@classmethod
def create(
cls,
- display_name: str,
+ display_name: Optional[str] = None,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
bq_source: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ labels: Optional[Dict[str, str]] = None,
encryption_spec_key_name: Optional[str] = None,
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
) -> "TimeSeriesDataset":
- """Creates a new tabular dataset.
+ """Creates a new time series dataset.
Args:
display_name (str):
- Required. The user-defined name of the Dataset.
+ Optional. The user-defined name of the Dataset.
The name can be up to 128 characters long and can consist
of any UTF-8 characters.
gcs_source (Union[str, Sequence[str]]):
Google Cloud Storage URI(-s) to the
- input file(s). May contain wildcards. For more
- information on wildcards, see
- https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
- examples:
+ input file(s).
+
+ Examples:
str: "gs://bucket/file.csv"
Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
bq_source (str):
@@ -76,6 +77,16 @@ def create(
credentials set in aiplatform.init.
request_metadata (Sequence[Tuple[str, str]]):
Strings which should be sent along with the request as metadata.
+ labels (Dict[str, str]):
+ Optional. Labels with user-defined metadata to organize your Datasets.
+ Label keys and values can be no longer than 64 characters
+ (Unicode codepoints), can only contain lowercase letters, numeric
+ characters, underscores and dashes. International characters are allowed.
+ No more than 64 user labels can be associated with one Dataset
+ (System labels are excluded).
+ See https://goo.gl/xmQnxf for more information and examples of labels.
+ System reserved label keys are prefixed with "aiplatform.googleapis.com/"
+ and are immutable.
encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the dataset. Has the
@@ -91,14 +102,19 @@ def create(
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
Returns:
time_series_dataset (TimeSeriesDataset):
Instantiated representation of the managed time series dataset resource.
"""
-
+ if not display_name:
+ display_name = cls._generate_display_name()
utils.validate_display_name(display_name)
+ if labels:
+ utils.validate_labels(labels)
api_client = cls._instantiate_client(location=location, credentials=credentials)
@@ -122,10 +138,12 @@ def create(
location=location or initializer.global_config.location,
credentials=credentials or initializer.global_config.credentials,
request_metadata=request_metadata,
+ labels=labels,
encryption_spec=initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name
),
sync=sync,
+ create_request_timeout=create_request_timeout,
)
def import_data(self):
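A sketch of TimeSeriesDataset.create() from a BigQuery source (the table URI and labels are placeholders):

    from google.cloud import aiplatform

    ts_dataset = aiplatform.TimeSeriesDataset.create(
        bq_source="bq://my-project.sales.daily_totals",  # placeholder table
        labels={"domain": "forecasting"},
        create_request_timeout=300.0,  # seconds
    )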
diff --git a/google/cloud/aiplatform/datasets/video_dataset.py b/google/cloud/aiplatform/datasets/video_dataset.py
index 594a4ac407..bef719b17b 100644
--- a/google/cloud/aiplatform/datasets/video_dataset.py
+++ b/google/cloud/aiplatform/datasets/video_dataset.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
-from typing import Optional, Sequence, Dict, Tuple, Union
+from typing import Dict, Optional, Sequence, Tuple, Union
from google.auth import credentials as auth_credentials
@@ -36,7 +36,7 @@ class VideoDataset(datasets._Dataset):
@classmethod
def create(
cls,
- display_name: str,
+ display_name: Optional[str] = None,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
import_schema_uri: Optional[str] = None,
data_item_labels: Optional[Dict] = None,
@@ -44,23 +44,24 @@ def create(
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ labels: Optional[Dict[str, str]] = None,
encryption_spec_key_name: Optional[str] = None,
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
) -> "VideoDataset":
"""Creates a new video dataset and optionally imports data into dataset
when source and import_schema_uri are passed.
Args:
display_name (str):
- Required. The user-defined name of the Dataset.
+ Optional. The user-defined name of the Dataset.
The name can be up to 128 characters long and can consist
of any UTF-8 characters.
gcs_source (Union[str, Sequence[str]]):
Google Cloud Storage URI(-s) to the
- input file(s). May contain wildcards. For more
- information on wildcards, see
- https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
- examples:
+ input file(s).
+
+ Examples:
str: "gs://bucket/file.csv"
Sequence[str]: ["gs://bucket/file1.csv", "gs://bucket/file2.csv"]
import_schema_uri (str):
@@ -81,7 +82,7 @@ def create(
be picked randomly. Two DataItems are considered identical
if their content bytes are identical (e.g. image bytes or
pdf bytes). These labels will be overridden by Annotation
- labels specified inside index file refenced by
+ labels specified inside index file referenced by
``import_schema_uri``,
e.g. jsonl file.
project (str):
@@ -95,6 +96,16 @@ def create(
credentials set in aiplatform.init.
request_metadata (Sequence[Tuple[str, str]]):
Strings which should be sent along with the request as metadata.
+ labels (Dict[str, str]):
+ Optional. Labels with user-defined metadata to organize your Datasets.
+ Label keys and values can be no longer than 64 characters
+ (Unicode codepoints), can only contain lowercase letters, numeric
+ characters, underscores and dashes. International characters are allowed.
+ No more than 64 user labels can be associated with one Dataset
+ (System labels are excluded).
+ See https://goo.gl/xmQnxf for more information and examples of labels.
+ System reserved label keys are prefixed with "aiplatform.googleapis.com/"
+ and are immutable.
encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the dataset. Has the
@@ -106,6 +117,8 @@ def create(
If set, this Dataset and all sub-resources of this Dataset will be secured by this key.
Overrides encryption_spec_key_name set in aiplatform.init.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
@@ -115,8 +128,11 @@ def create(
video_dataset (VideoDataset):
Instantiated representation of the managed video dataset resource.
"""
-
+ if not display_name:
+ display_name = cls._generate_display_name()
utils.validate_display_name(display_name)
+ if labels:
+ utils.validate_labels(labels)
api_client = cls._instantiate_client(location=location, credentials=credentials)
@@ -141,8 +157,10 @@ def create(
location=location or initializer.global_config.location,
credentials=credentials or initializer.global_config.credentials,
request_metadata=request_metadata,
+ labels=labels,
encryption_spec=initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name
),
sync=sync,
+ create_request_timeout=create_request_timeout,
)
diff --git a/google/cloud/aiplatform/explain/__init__.py b/google/cloud/aiplatform/explain/__init__.py
index 61b9181834..4701d709b5 100644
--- a/google/cloud/aiplatform/explain/__init__.py
+++ b/google/cloud/aiplatform/explain/__init__.py
@@ -16,11 +16,11 @@
#
from google.cloud.aiplatform.compat.types import (
- explanation_metadata_v1beta1 as explanation_metadata,
- explanation_v1beta1 as explanation,
+ explanation as explanation_compat,
+ explanation_metadata as explanation_metadata_compat,
)
-ExplanationMetadata = explanation_metadata.ExplanationMetadata
+ExplanationMetadata = explanation_metadata_compat.ExplanationMetadata
# ExplanationMetadata subclasses
InputMetadata = ExplanationMetadata.InputMetadata
@@ -32,15 +32,14 @@
Visualization = InputMetadata.Visualization
-ExplanationParameters = explanation.ExplanationParameters
-FeatureNoiseSigma = explanation.FeatureNoiseSigma
+ExplanationParameters = explanation_compat.ExplanationParameters
+FeatureNoiseSigma = explanation_compat.FeatureNoiseSigma
# Classes used by ExplanationParameters
-IntegratedGradientsAttribution = explanation.IntegratedGradientsAttribution
-
-SampledShapleyAttribution = explanation.SampledShapleyAttribution
-SmoothGradConfig = explanation.SmoothGradConfig
-XraiAttribution = explanation.XraiAttribution
+IntegratedGradientsAttribution = explanation_compat.IntegratedGradientsAttribution
+SampledShapleyAttribution = explanation_compat.SampledShapleyAttribution
+SmoothGradConfig = explanation_compat.SmoothGradConfig
+XraiAttribution = explanation_compat.XraiAttribution
__all__ = (
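The renamed compat aliases keep the public surface usable as before; a hedged sketch of building explanation parameters (the path_count value is illustrative):

    from google.cloud.aiplatform import explain

    # Sampled Shapley attribution with 10 feature-permutation paths.
    params = explain.ExplanationParameters(
        sampled_shapley_attribution=explain.SampledShapleyAttribution(path_count=10)
    )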
diff --git a/google/cloud/aiplatform/explain/lit.py b/google/cloud/aiplatform/explain/lit.py
new file mode 100644
index 0000000000..6d388f4559
--- /dev/null
+++ b/google/cloud/aiplatform/explain/lit.py
@@ -0,0 +1,482 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+
+from google.cloud import aiplatform
+from typing import Dict, List, Mapping, Optional, Tuple, Union
+
+try:
+ from lit_nlp.api import dataset as lit_dataset
+ from lit_nlp.api import dtypes as lit_dtypes
+ from lit_nlp.api import model as lit_model
+ from lit_nlp.api import types as lit_types
+ from lit_nlp import notebook
+except ImportError:
+ raise ImportError(
+ "LIT is not installed and is required to get Dataset as the return format. "
+ 'Please install the SDK using "pip install google-cloud-aiplatform[lit]"'
+ )
+
+try:
+ import tensorflow as tf
+except ImportError:
+ raise ImportError(
+ "Tensorflow is not installed and is required to load saved model. "
+ 'Please install the SDK using "pip install google-cloud-aiplatform[lit]"'
+ )
+
+try:
+ import pandas as pd
+except ImportError:
+ raise ImportError(
+ "Pandas is not installed and is required to read the dataset. "
+ 'Please install Pandas using "pip install google-cloud-aiplatform[lit]"'
+ )
+
+
+class _VertexLitDataset(lit_dataset.Dataset):
+ """LIT dataset class for the Vertex LIT integration.
+
+ This is used in the create_lit_dataset function.
+ """
+
+ def __init__(
+ self,
+ dataset: pd.DataFrame,
+ column_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
+ ):
+ """Construct a VertexLitDataset.
+ Args:
+ dataset:
+ Required. A Pandas DataFrame that includes feature column names and data.
+ column_types:
+ Required. An OrderedDict of string names matching the columns of the dataset
+ as the key, and the associated LitType of the column.
+ """
+ self._examples = dataset.to_dict(orient="records")
+ self._column_types = column_types
+
+ def spec(self):
+ """Return a spec describing dataset elements."""
+ return dict(self._column_types)
+
+
+class _EndpointLitModel(lit_model.Model):
+ """LIT model class for the Vertex LIT integration with a model deployed to an endpoint.
+
+ This is used in the create_lit_model function.
+ """
+
+ def __init__(
+ self,
+ endpoint: Union[str, aiplatform.Endpoint],
+ input_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
+ output_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
+ model_id: Optional[str] = None,
+ ):
+ """Construct a VertexLitModel.
+ Args:
+ endpoint:
+ Required. The name of the Endpoint resource or an Endpoint instance. Format:
+ ``projects/{project}/locations/{location}/endpoints/{endpoint}``
+ input_types:
+ Required. An OrderedDict of string names matching the features of the model
+ as the key, and the associated LitType of the feature.
+ output_types:
+ Required. An OrderedDict of string names matching the labels of the model
+ as the key, and the associated LitType of the label.
+ model_id:
+ Optional. A string of the specific model in the endpoint to create the
+ LIT model from. If this is not set, any usable model in the endpoint is
+ used to create the LIT model.
+ Raises:
+ ValueError: If the model_id was not found in the endpoint.
+ """
+ if isinstance(endpoint, str):
+ self._endpoint = aiplatform.Endpoint(endpoint)
+ else:
+ self._endpoint = endpoint
+ self._model_id = model_id
+ self._input_types = input_types
+ self._output_types = output_types
+ # Check if the model with the model ID has explanation enabled
+ if model_id:
+ deployed_model = next(
+ filter(
+ lambda model: model.id == model_id, self._endpoint.list_models()
+ ),
+ None,
+ )
+ if not deployed_model:
+ raise ValueError(
+ "A model with id {model_id} was not found in the endpoint {endpoint}.".format(
+ model_id=model_id, endpoint=endpoint
+ )
+ )
+ self._explanation_enabled = bool(deployed_model.explanation_spec)
+ # Check if all models in the endpoint have explanation enabled
+ else:
+ self._explanation_enabled = all(
+ model.explanation_spec for model in self._endpoint.list_models()
+ )
+
+ def predict_minibatch(
+ self, inputs: List[lit_types.JsonDict]
+ ) -> List[lit_types.JsonDict]:
+ """Retun predictions based on a batch of inputs.
+ Args:
+ inputs: Requred. a List of instances to predict on based on the input spec.
+ Returns:
+ A list of predictions based on the output spec.
+ """
+ instances = []
+ for input in inputs:
+ instance = [input[feature] for feature in self._input_types]
+ instances.append(instance)
+ if self._explanation_enabled:
+ prediction_object = self._endpoint.explain(instances)
+ else:
+ prediction_object = self._endpoint.predict(instances)
+ outputs = []
+ for prediction in prediction_object.predictions:
+ if isinstance(prediction, Mapping):
+ outputs.append({key: prediction[key] for key in self._output_types})
+ else:
+ outputs.append(
+ {key: prediction[i] for i, key in enumerate(self._output_types)}
+ )
+ if self._explanation_enabled:
+ for i, explanation in enumerate(prediction_object.explanations):
+ attributions = explanation.attributions
+ outputs[i]["feature_attribution"] = lit_dtypes.FeatureSalience(
+ attributions
+ )
+ return outputs
+
+ def input_spec(self) -> lit_types.Spec:
+ """Return a spec describing model inputs."""
+ return dict(self._input_types)
+
+ def output_spec(self) -> lit_types.Spec:
+ """Return a spec describing model outputs."""
+ output_spec_dict = dict(self._output_types)
+ if self._explanation_enabled:
+ output_spec_dict["feature_attribution"] = lit_types.FeatureSalience(
+ signed=True
+ )
+ return output_spec_dict
+
+
+class _TensorFlowLitModel(lit_model.Model):
+ """LIT model class for the Vertex LIT integration with a TensorFlow saved model.
+
+ This is used in the create_lit_model function.
+ """
+
+ def __init__(
+ self,
+ model: str,
+ input_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
+ output_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
+ attribution_method: str = "sampled_shapley",
+ ):
+ """Construct a VertexLitModel.
+ Args:
+ model:
+ Required. A string reference to a local TensorFlow saved model directory.
+ The model must have exactly one input and one output tensor.
+ input_types:
+ Required. An OrderedDict of string names matching the features of the model
+ as the key, and the associated LitType of the feature.
+ output_types:
+ Required. An OrderedDict of string names matching the labels of the model
+ as the key, and the associated LitType of the label.
+ attribution_method:
+ Optional. A string to choose what attribution configuration to
+ set up the explainer with. Valid options are 'sampled_shapley'
+ or 'integrated_gradients'.
+ """
+ self._load_model(model)
+ self._input_types = input_types
+ self._output_types = output_types
+ self._input_tensor_name = next(iter(self._kwargs_signature))
+ self._attribution_explainer = None
+ if os.environ.get("LIT_PROXY_URL"):
+ self._set_up_attribution_explainer(model, attribution_method)
+
+ @property
+ def attribution_explainer(
+ self,
+ ) -> Optional["AttributionExplainer"]: # noqa: F821
+ """Gets the attribution explainer property if set."""
+ return self._attribution_explainer
+
+ def predict_minibatch(
+ self, inputs: List[lit_types.JsonDict]
+ ) -> List[lit_types.JsonDict]:
+ """Retun predictions based on a batch of inputs.
+ Args:
+ inputs: Requred. a List of instances to predict on based on the input spec.
+ Returns:
+ A list of predictions based on the output spec.
+ """
+ instances = []
+ for input in inputs:
+ instance = [input[feature] for feature in self._input_types]
+ instances.append(instance)
+ prediction_input_dict = {
+ self._input_tensor_name: tf.convert_to_tensor(instances)
+ }
+ prediction_dict = self._loaded_model.signatures[
+ tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+ ](**prediction_input_dict)
+ predictions = prediction_dict[next(iter(self._output_signature))].numpy()
+ outputs = []
+ for prediction in predictions:
+ outputs.append(
+ {
+ label: value
+ for label, value in zip(self._output_types.keys(), prediction)
+ }
+ )
+ # Get feature attributions
+ if self.attribution_explainer:
+ attributions = self.attribution_explainer.explain(
+ [{self._input_tensor_name: i} for i in instances]
+ )
+ for i, attribution in enumerate(attributions):
+ outputs[i]["feature_attribution"] = lit_dtypes.FeatureSalience(
+ attribution.feature_importance()
+ )
+ return outputs
+
+ def input_spec(self) -> lit_types.Spec:
+ """Return a spec describing model inputs."""
+ return dict(self._input_types)
+
+ def output_spec(self) -> lit_types.Spec:
+ """Return a spec describing model outputs."""
+ output_spec_dict = dict(self._output_types)
+ if self.attribution_explainer:
+ output_spec_dict["feature_attribution"] = lit_types.FeatureSalience(
+ signed=True
+ )
+ return output_spec_dict
+
+ def _load_model(self, model: str):
+ """Loads a TensorFlow saved model and populates the input and output signature attributes of the class.
+ Args:
+ model: Required. A string reference to a TensorFlow saved model directory.
+ Raises:
+ ValueError: If the model has more than one input tensor or more than one output tensor.
+ """
+ self._loaded_model = tf.saved_model.load(model)
+ serving_default = self._loaded_model.signatures[
+ tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+ ]
+ _, self._kwargs_signature = serving_default.structured_input_signature
+ self._output_signature = serving_default.structured_outputs
+
+ if len(self._kwargs_signature) != 1:
+ raise ValueError("Please use a model with only one input tensor.")
+
+ if len(self._output_signature) != 1:
+ raise ValueError("Please use a model with only one output tensor.")
+
+ def _set_up_attribution_explainer(
+ self, model: str, attribution_method: str = "integrated_gradients"
+ ):
+ """Populates the attribution explainer attribute of the class.
+ Args:
+ model: Required. A string reference to a TensorFlow saved model directory.
+ attribution_method:
+ Optional. A string to choose what attribution configuration to
+ set up the explainer with. Valid options are 'sampled_shapley'
+ or 'integrated_gradients'.
+ """
+ try:
+ import explainable_ai_sdk
+ from explainable_ai_sdk.metadata.tf.v2 import SavedModelMetadataBuilder
+ except ImportError:
+ logging.info(
+ "Skipping explanations because the Explainable AI SDK is not installed."
+ 'Please install the SDK using "pip install explainable-ai-sdk"'
+ )
+ return
+
+ builder = SavedModelMetadataBuilder(model)
+ builder.get_metadata()
+ builder.set_numeric_metadata(
+ self._input_tensor_name,
+ index_feature_mapping=list(self._input_types.keys()),
+ )
+ builder.save_metadata(model)
+ if attribution_method == "integrated_gradients":
+ explainer_config = explainable_ai_sdk.IntegratedGradientsConfig()
+ else:
+ explainer_config = explainable_ai_sdk.SampledShapleyConfig()
+
+ self._attribution_explainer = explainable_ai_sdk.load_model_from_local_path(
+ model, explainer_config
+ )
+ self._load_model(model)
+
+
+def create_lit_dataset(
+ dataset: pd.DataFrame,
+ column_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
+) -> lit_dataset.Dataset:
+ """Creates a LIT Dataset object.
+ Args:
+ dataset:
+ Required. A Pandas DataFrame that includes feature column names and data.
+ column_types:
+ Required. An OrderedDict of string names matching the columns of the dataset
+ as the key, and the associated LitType of the column.
+ Returns:
+ A LIT Dataset object that has the data from the dataset provided.
+ """
+ return _VertexLitDataset(dataset, column_types)
+
+
+def create_lit_model_from_endpoint(
+ endpoint: Union[str, aiplatform.Endpoint],
+ input_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
+ output_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
+ model_id: Optional[str] = None,
+) -> lit_model.Model:
+ """Creates a LIT Model object.
+ Args:
+ endpoint:
+ Required. The name of the Endpoint resource or an Endpoint instance.
+ Endpoint name format: ``projects/{project}/locations/{location}/endpoints/{endpoint}``
+ input_types:
+ Required. An OrderedDict of string names matching the features of the model
+ as the key, and the associated LitType of the feature.
+ output_types:
+ Required. An OrderedDict of string names matching the labels of the model
+ as the key, and the associated LitType of the label.
+ model_id:
+ Optional. A string of the specific model in the endpoint to create the
+ LIT model from. If this is not set, any usable model in the endpoint is
+ used to create the LIT model.
+ Returns:
+ A LIT Model object that has the same functionality as the model provided.
+ """
+ return _EndpointLitModel(endpoint, input_types, output_types, model_id)
+
+
+def create_lit_model(
+ model: str,
+ input_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
+ output_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
+ attribution_method: str = "sampled_shapley",
+) -> lit_model.Model:
+ """Creates a LIT Model object.
+ Args:
+ model:
+ Required. A string reference to a local TensorFlow saved model directory.
+ The model must have exactly one input and one output tensor.
+ input_types:
+ Required. An OrderedDict of string names matching the features of the model
+ as the key, and the associated LitType of the feature.
+ output_types:
+ Required. An OrderedDict of string names matching the labels of the model
+ as the key, and the associated LitType of the label.
+ attribution_method:
+ Optional. A string to choose what attribution configuration to
+ set up the explainer with. Valid options are 'sampled_shapley'
+ or 'integrated_gradients'.
+ Returns:
+ A LIT Model object that has the same functionality as the model provided.
+ """
+ return _TensorFlowLitModel(model, input_types, output_types, attribution_method)
+
+
+def open_lit(
+ models: Dict[str, lit_model.Model],
+ datasets: Dict[str, lit_dataset.Dataset],
+ open_in_new_tab: bool = True,
+):
+ """Open LIT from the provided models and datasets.
+ Args:
+ models:
+ Required. A dict mapping names to LIT models to open LIT with.
+ datasets:
+ Required. A dict mapping names to LIT datasets to open LIT with.
+ open_in_new_tab:
+ Optional. Whether LIT opens in a new tab.
+ Raises:
+ ImportError: If LIT is not installed.
+ """
+ widget = notebook.LitWidget(models, datasets)
+ widget.render(open_in_new_tab=open_in_new_tab)
+
+
+def set_up_and_open_lit(
+ dataset: Union[pd.DataFrame, lit_dataset.Dataset],
+ column_types: "OrderedDict[str, lit_types.LitType]", # noqa: F821
+ model: Union[str, lit_model.Model],
+ input_types: Union[List[str], Dict[str, lit_types.LitType]],
+ output_types: Union[str, List[str], Dict[str, lit_types.LitType]],
+ attribution_method: str = "sampled_shapley",
+ open_in_new_tab: bool = True,
+) -> Tuple[lit_dataset.Dataset, lit_model.Model]:
+ """Creates a LIT dataset and model and opens LIT.
+ Args:
+ dataset:
+ Required. A Pandas DataFrame that includes feature column names and data.
+ column_types:
+            Required. An OrderedDict mapping each column name of the dataset to
+            the LitType of that column.
+ model:
+ Required. A string reference to a TensorFlow saved model directory.
+ The model must have at most one input and one output tensor.
+        input_types:
+            Required. An OrderedDict mapping each feature name of the model to
+            the LitType of that feature.
+        output_types:
+            Required. An OrderedDict mapping each label name of the model to
+            the LitType of that label.
+ attribution_method:
+            Optional. The attribution configuration to set up the explainer
+            with. Valid options are 'sampled_shapley' or 'integrated_gradients'.
+ open_in_new_tab:
+            Optional. Whether to open LIT in a new tab.
+ Returns:
+ A Tuple of the LIT dataset and model created.
+ Raises:
+ ImportError if LIT or TensorFlow is not installed.
+        ValueError if the model has more than one input or output tensor.
+ """
+ if not isinstance(dataset, lit_dataset.Dataset):
+ dataset = create_lit_dataset(dataset, column_types)
+
+ if not isinstance(model, lit_model.Model):
+ model = create_lit_model(
+ model, input_types, output_types, attribution_method=attribution_method
+ )
+
+ open_lit(
+ {"model": model},
+ {"dataset": dataset},
+ open_in_new_tab=open_in_new_tab,
+ )
+
+ return dataset, model
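+
+
+# An end-to-end sketch of the helpers above (paths, column names, and
+# LitTypes are hypothetical; assumes LIT, TensorFlow, and pandas are
+# installed):
+#
+#     import collections
+#     import pandas as pd
+#     from lit_nlp.api import types as lit_types
+#
+#     df = pd.DataFrame({"feature_1": [1.0, 2.0], "label": [0.0, 1.0]})
+#     lit_dataset_obj, lit_model_obj = set_up_and_open_lit(
+#         dataset=df,
+#         column_types=collections.OrderedDict(
+#             [("feature_1", lit_types.Scalar()), ("label", lit_types.RegressionScore())]
+#         ),
+#         model="my_saved_model_dir",
+#         input_types=collections.OrderedDict([("feature_1", lit_types.Scalar())]),
+#         output_types=collections.OrderedDict([("label", lit_types.RegressionScore())]),
+#     )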
diff --git a/google/cloud/aiplatform/explain/metadata/__init__.py b/google/cloud/aiplatform/explain/metadata/__init__.py
new file mode 100644
index 0000000000..0e973c9a40
--- /dev/null
+++ b/google/cloud/aiplatform/explain/metadata/__init__.py
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/google/cloud/aiplatform/explain/metadata/metadata_builder.py b/google/cloud/aiplatform/explain/metadata/metadata_builder.py
new file mode 100644
index 0000000000..002317d508
--- /dev/null
+++ b/google/cloud/aiplatform/explain/metadata/metadata_builder.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Base abstract class for metadata builders."""
+
+import abc
+
+_ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()})
+
+
+class MetadataBuilder(_ABC):
+ """Abstract base class for metadata builders."""
+
+ @abc.abstractmethod
+ def get_metadata(self):
+ """Returns the current metadata as a dictionary."""
+
+ @abc.abstractmethod
+ def get_metadata_protobuf(self):
+ """Returns the current metadata as ExplanationMetadata protobuf"""
diff --git a/google/cloud/aiplatform/explain/metadata/tf/__init__.py b/google/cloud/aiplatform/explain/metadata/tf/__init__.py
new file mode 100644
index 0000000000..0e973c9a40
--- /dev/null
+++ b/google/cloud/aiplatform/explain/metadata/tf/__init__.py
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/google/cloud/aiplatform/explain/metadata/tf/v1/__init__.py b/google/cloud/aiplatform/explain/metadata/tf/v1/__init__.py
new file mode 100644
index 0000000000..0e973c9a40
--- /dev/null
+++ b/google/cloud/aiplatform/explain/metadata/tf/v1/__init__.py
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/google/cloud/aiplatform/explain/metadata/tf/v1/saved_model_metadata_builder.py b/google/cloud/aiplatform/explain/metadata/tf/v1/saved_model_metadata_builder.py
new file mode 100644
index 0000000000..c9fc2d0e22
--- /dev/null
+++ b/google/cloud/aiplatform/explain/metadata/tf/v1/saved_model_metadata_builder.py
@@ -0,0 +1,171 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.protobuf import json_format
+from typing import Any, Dict, List, Optional
+
+from google.cloud.aiplatform.compat.types import explanation_metadata
+from google.cloud.aiplatform.explain.metadata import metadata_builder
+
+
+class SavedModelMetadataBuilder(metadata_builder.MetadataBuilder):
+ """Metadata builder class that accepts a TF1 saved model."""
+
+ def __init__(
+ self,
+ model_path: str,
+ tags: Optional[List[str]] = None,
+ signature_name: Optional[str] = None,
+ outputs_to_explain: Optional[List[str]] = None,
+ ) -> None:
+ """Initializes a SavedModelMetadataBuilder object.
+
+ Args:
+ model_path:
+ Required. Local or GCS path to load the saved model from.
+ tags:
+ Optional. Tags to identify the model graph. If None or empty,
+ TensorFlow's default serving tag will be used.
+ signature_name:
+ Optional. Name of the signature to be explained. Inputs and
+ outputs of this signature will be written in the metadata. If not
+ provided, the default signature will be used.
+ outputs_to_explain:
+                Optional. List of output names to explain. Only a single output
+                is supported for now, so the list should contain one element.
+ This parameter is required if the model signature (provided via
+ signature_name) specifies multiple outputs.
+
+ Raises:
+ ValueError: If outputs_to_explain contains more than 1 element or
+ signature contains multiple outputs.
+ """
+ if outputs_to_explain:
+ if len(outputs_to_explain) > 1:
+ raise ValueError(
+ "Only one output is supported at the moment. "
+ f"Received: {outputs_to_explain}."
+ )
+ self._output_to_explain = next(iter(outputs_to_explain))
+
+ try:
+ import tensorflow.compat.v1 as tf
+ except ImportError:
+ raise ImportError(
+ "Tensorflow is not installed and is required to load saved model. "
+ 'Please install the SDK using "pip install "tensorflow>=1.15,<2.0""'
+ )
+
+ if not signature_name:
+ signature_name = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+ self._tags = tags or [tf.saved_model.tag_constants.SERVING]
+ self._graph = tf.Graph()
+
+ with self.graph.as_default():
+ self._session = tf.Session(graph=self.graph)
+ self._metagraph_def = tf.saved_model.loader.load(
+ sess=self.session, tags=self._tags, export_dir=model_path
+ )
+ if signature_name not in self._metagraph_def.signature_def:
+ raise ValueError(
+ f"Serving sigdef key {signature_name} not in the signature def."
+ )
+ serving_sigdef = self._metagraph_def.signature_def[signature_name]
+ if not outputs_to_explain:
+ if len(serving_sigdef.outputs) > 1:
+ raise ValueError(
+ "The signature contains multiple outputs. Specify "
+ 'an output via "outputs_to_explain" parameter.'
+ )
+ self._output_to_explain = next(iter(serving_sigdef.outputs.keys()))
+
+ self._inputs = _create_input_metadata_from_signature(serving_sigdef.inputs)
+ self._outputs = _create_output_metadata_from_signature(
+ serving_sigdef.outputs, self._output_to_explain
+ )
+
+ @property
+ def graph(self) -> "tf.Graph": # noqa: F821
+ return self._graph
+
+ @property
+ def session(self) -> "tf.Session": # noqa: F821
+ return self._session
+
+ def get_metadata(self) -> Dict[str, Any]:
+ """Returns the current metadata as a dictionary.
+
+ Returns:
+ Json format of the explanation metadata.
+ """
+ return json_format.MessageToDict(self.get_metadata_protobuf()._pb)
+
+ def get_metadata_protobuf(self) -> explanation_metadata.ExplanationMetadata:
+ """Returns the current metadata as a Protobuf object.
+
+ Returns:
+ ExplanationMetadata object format of the explanation metadata.
+ """
+ return explanation_metadata.ExplanationMetadata(
+ inputs=self._inputs,
+ outputs=self._outputs,
+ )
+
+
+def _create_input_metadata_from_signature(
+ signature_inputs: Dict[str, "tf.Tensor"] # noqa: F821
+) -> Dict[str, explanation_metadata.ExplanationMetadata.InputMetadata]:
+ """Creates InputMetadata from signature inputs.
+
+ Args:
+ signature_inputs:
+            Required. Inputs of the signature to be explained.
+
+ Returns:
+ Inferred input metadata from the model.
+ """
+ input_mds = {}
+ for key, tensor in signature_inputs.items():
+ input_mds[key] = explanation_metadata.ExplanationMetadata.InputMetadata(
+ input_tensor_name=tensor.name
+ )
+ return input_mds
+
+
+def _create_output_metadata_from_signature(
+ signature_outputs: Dict[str, "tf.Tensor"], # noqa: F821
+ output_to_explain: Optional[str] = None,
+) -> Dict[str, explanation_metadata.ExplanationMetadata.OutputMetadata]:
+ """Creates OutputMetadata from signature inputs.
+
+ Args:
+ signature_outputs:
+ Required. Inputs of the signature to be explained. If not provided,
+ the default signature will be used.
+ output_to_explain:
+ Optional. Output name to explain.
+
+ Returns:
+ Inferred output metadata from the model.
+ """
+ output_mds = {}
+ for key, tensor in signature_outputs.items():
+ if not output_to_explain or output_to_explain == key:
+ output_mds[key] = explanation_metadata.ExplanationMetadata.OutputMetadata(
+ output_tensor_name=tensor.name
+ )
+ return output_mds
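+
+
+# A minimal usage sketch for the TF1 builder (the model path is a
+# hypothetical placeholder; requires tensorflow>=1.15,<2.0):
+#
+#     builder = SavedModelMetadataBuilder(
+#         model_path="my_tf1_saved_model_dir",
+#         tags=["serve"],
+#     )
+#     metadata_dict = builder.get_metadata()
+#     metadata_proto = builder.get_metadata_protobuf()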
diff --git a/google/cloud/aiplatform/explain/metadata/tf/v2/__init__.py b/google/cloud/aiplatform/explain/metadata/tf/v2/__init__.py
new file mode 100644
index 0000000000..0e973c9a40
--- /dev/null
+++ b/google/cloud/aiplatform/explain/metadata/tf/v2/__init__.py
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/google/cloud/aiplatform/explain/metadata/tf/v2/saved_model_metadata_builder.py b/google/cloud/aiplatform/explain/metadata/tf/v2/saved_model_metadata_builder.py
new file mode 100644
index 0000000000..7d19e5680d
--- /dev/null
+++ b/google/cloud/aiplatform/explain/metadata/tf/v2/saved_model_metadata_builder.py
@@ -0,0 +1,139 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.protobuf import json_format
+from typing import Optional, List, Dict, Any, Tuple
+
+from google.cloud.aiplatform.explain.metadata import metadata_builder
+from google.cloud.aiplatform.compat.types import explanation_metadata
+
+
+class SavedModelMetadataBuilder(metadata_builder.MetadataBuilder):
+ """Class for generating metadata for a model built with TF 2.X Keras API."""
+
+ def __init__(
+ self,
+ model_path: str,
+ signature_name: Optional[str] = None,
+ outputs_to_explain: Optional[List[str]] = None,
+ **kwargs
+ ) -> None:
+ """Initializes a SavedModelMetadataBuilder object.
+
+ Args:
+ model_path:
+ Required. Local or GCS path to load the saved model from.
+ signature_name:
+ Optional. Name of the signature to be explained. Inputs and
+ outputs of this signature will be written in the metadata. If not
+ provided, the default signature will be used.
+ outputs_to_explain:
+                Optional. List of output names to explain. Only a single output
+                is supported for now, so the list should contain one element.
+ This parameter is required if the model signature (provided via
+ signature_name) specifies multiple outputs.
+ **kwargs:
+ Any keyword arguments to be passed to tf.saved_model.save() function.
+
+ Raises:
+ ValueError: If outputs_to_explain contains more than 1 element.
+            ImportError: If TensorFlow is not installed.
+ """
+ if outputs_to_explain and len(outputs_to_explain) > 1:
+ raise ValueError(
+ '"outputs_to_explain" can only contain 1 element.\n'
+ "Got: %s" % len(outputs_to_explain)
+ )
+ self._explain_output = outputs_to_explain
+ self._saved_model_args = kwargs
+
+ try:
+ import tensorflow as tf
+ except ImportError:
+ raise ImportError(
+ "Tensorflow is not installed and is required to load saved model. "
+ 'Please install the SDK using "pip install google-cloud-aiplatform[full]"'
+ )
+
+ if not signature_name:
+ signature_name = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+ self._loaded_model = tf.saved_model.load(model_path)
+ self._inputs, self._outputs = self._infer_metadata_entries_from_model(
+ signature_name
+ )
+
+ def _infer_metadata_entries_from_model(
+ self, signature_name: str
+ ) -> Tuple[
+ Dict[str, explanation_metadata.ExplanationMetadata.InputMetadata],
+ Dict[str, explanation_metadata.ExplanationMetadata.OutputMetadata],
+ ]:
+ """Infers metadata inputs and outputs.
+
+ Args:
+ signature_name:
+                Required. Name of the signature to be explained. Inputs and
+                outputs of this signature will be written in the metadata.
+
+ Returns:
+ Inferred input metadata and output metadata from the model.
+
+ Raises:
+ ValueError: If specified name is not found in signature outputs.
+ """
+
+ loaded_sig = self._loaded_model.signatures[signature_name]
+ _, input_sig = loaded_sig.structured_input_signature
+ output_sig = loaded_sig.structured_outputs
+ input_mds = {}
+ for name, tensor_spec in input_sig.items():
+ input_mds[name] = explanation_metadata.ExplanationMetadata.InputMetadata(
+ input_tensor_name=name,
+ modality=None if tensor_spec.dtype.is_floating else "categorical",
+ )
+
+ output_mds = {}
+ for name in output_sig:
+ if not self._explain_output or self._explain_output[0] == name:
+ output_mds[
+ name
+ ] = explanation_metadata.ExplanationMetadata.OutputMetadata(
+ output_tensor_name=name,
+ )
+ break
+ else:
+ raise ValueError(
+ "Specified output name cannot be found in given signature outputs."
+ )
+ return input_mds, output_mds
+
+ def get_metadata(self) -> Dict[str, Any]:
+ """Returns the current metadata as a dictionary.
+
+ Returns:
+ Json format of the explanation metadata.
+ """
+ return json_format.MessageToDict(self.get_metadata_protobuf()._pb)
+
+ def get_metadata_protobuf(self) -> explanation_metadata.ExplanationMetadata:
+ """Returns the current metadata as a Protobuf object.
+
+ Returns:
+ ExplanationMetadata object format of the explanation metadata.
+ """
+ return explanation_metadata.ExplanationMetadata(
+ inputs=self._inputs,
+ outputs=self._outputs,
+ )
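+
+
+# A minimal usage sketch for the TF2 builder (the model path and output name
+# are hypothetical placeholders; requires TensorFlow 2.x):
+#
+#     builder = SavedModelMetadataBuilder(
+#         model_path="my_tf2_saved_model_dir",
+#         outputs_to_explain=["scores"],
+#     )
+#     explanation_md = builder.get_metadata_protobuf()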
diff --git a/google/cloud/aiplatform/featurestore/__init__.py b/google/cloud/aiplatform/featurestore/__init__.py
new file mode 100644
index 0000000000..883f72dd26
--- /dev/null
+++ b/google/cloud/aiplatform/featurestore/__init__.py
@@ -0,0 +1,26 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from google.cloud.aiplatform.featurestore.entity_type import EntityType
+from google.cloud.aiplatform.featurestore.feature import Feature
+from google.cloud.aiplatform.featurestore.featurestore import Featurestore
+
+__all__ = (
+ "EntityType",
+ "Feature",
+ "Featurestore",
+)
diff --git a/google/cloud/aiplatform/featurestore/entity_type.py b/google/cloud/aiplatform/featurestore/entity_type.py
new file mode 100644
index 0000000000..edd0c7433b
--- /dev/null
+++ b/google/cloud/aiplatform/featurestore/entity_type.py
@@ -0,0 +1,1540 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import datetime
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+import uuid
+
+from google.auth import credentials as auth_credentials
+from google.protobuf import field_mask_pb2
+
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform.compat.types import (
+ entity_type as gca_entity_type,
+ feature_selector as gca_feature_selector,
+ featurestore_service as gca_featurestore_service,
+ featurestore_online_service as gca_featurestore_online_service,
+ io as gca_io,
+)
+from google.cloud.aiplatform import featurestore
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform import utils
+from google.cloud.aiplatform.utils import featurestore_utils, resource_manager_utils
+
+from google.cloud import bigquery
+
+_LOGGER = base.Logger(__name__)
+_ALL_FEATURE_IDS = "*"
+
+
+class EntityType(base.VertexAiResourceNounWithFutureManager):
+ """Managed entityType resource for Vertex AI."""
+
+ client_class = utils.FeaturestoreClientWithOverride
+
+ _resource_noun = "entityTypes"
+ _getter_method = "get_entity_type"
+ _list_method = "list_entity_types"
+ _delete_method = "delete_entity_type"
+ _parse_resource_name_method = "parse_entity_type_path"
+ _format_resource_name_method = "entity_type_path"
+
+ @staticmethod
+ def _resource_id_validator(resource_id: str):
+ """Validates resource ID.
+
+ Args:
+            resource_id (str):
+ The resource id to validate.
+ """
+ featurestore_utils.validate_id(resource_id)
+
+ def __init__(
+ self,
+ entity_type_name: str,
+ featurestore_id: Optional[str] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ):
+ """Retrieves an existing managed entityType given an entityType resource name or an entity_type ID.
+
+ Example Usage:
+
+ my_entity_type = aiplatform.EntityType(
+ entity_type_name='projects/123/locations/us-central1/featurestores/my_featurestore_id/\
+ entityTypes/my_entity_type_id'
+ )
+ or
+ my_entity_type = aiplatform.EntityType(
+ entity_type_name='my_entity_type_id',
+ featurestore_id='my_featurestore_id',
+ )
+
+ Args:
+ entity_type_name (str):
+ Required. A fully-qualified entityType resource name or an entity_type ID.
+ Example: "projects/123/locations/us-central1/featurestores/my_featurestore_id/entityTypes/my_entity_type_id"
+ or "my_entity_type_id" when project and location are initialized or passed, with featurestore_id passed.
+ featurestore_id (str):
+ Optional. Featurestore ID of an existing featurestore to retrieve entityType from,
+ when entity_type_name is passed as entity_type ID.
+ project (str):
+ Optional. Project to retrieve entityType from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve entityType from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve this EntityType. Overrides
+ credentials set in aiplatform.init.
+ """
+
+ super().__init__(
+ project=project,
+ location=location,
+ credentials=credentials,
+ resource_name=entity_type_name,
+ )
+ self._gca_resource = self._get_gca_resource(
+ resource_name=entity_type_name,
+ parent_resource_name_fields={
+ featurestore.Featurestore._resource_noun: featurestore_id
+ }
+            if featurestore_id
+            else None,
+ )
+
+ self._featurestore_online_client = self._instantiate_featurestore_online_client(
+ location=self.location,
+ credentials=credentials,
+ )
+
+ def _get_featurestore_name(self) -> str:
+ """Gets full qualified resource name of the managed featurestore in which this EntityType is."""
+ entity_type_name_components = self._parse_resource_name(self.resource_name)
+ return featurestore.Featurestore._format_resource_name(
+ project=entity_type_name_components["project"],
+ location=entity_type_name_components["location"],
+ featurestore=entity_type_name_components["featurestore"],
+ )
+
+ @property
+ def featurestore_name(self) -> str:
+ """Full qualified resource name of the managed featurestore in which this EntityType is."""
+ self.wait()
+ return self._get_featurestore_name()
+
+ def get_featurestore(self) -> "featurestore.Featurestore":
+ """Retrieves the managed featurestore in which this EntityType is.
+
+ Returns:
+ featurestore.Featurestore - The managed featurestore in which this EntityType is.
+ """
+ return featurestore.Featurestore(self.featurestore_name)
+
+ def _get_feature(self, feature_id: str) -> "featurestore.Feature":
+ """Retrieves an existing managed feature in this EntityType.
+
+ Args:
+ feature_id (str):
+ Required. The managed feature resource ID in this EntityType.
+ Returns:
+ featurestore.Feature - The managed feature resource object.
+ """
+ entity_type_name_components = self._parse_resource_name(self.resource_name)
+ return featurestore.Feature(
+ feature_name=featurestore.Feature._format_resource_name(
+ project=entity_type_name_components["project"],
+ location=entity_type_name_components["location"],
+ featurestore=entity_type_name_components["featurestore"],
+ entity_type=entity_type_name_components["entity_type"],
+ feature=feature_id,
+ )
+ )
+
+ def get_feature(self, feature_id: str) -> "featurestore.Feature":
+ """Retrieves an existing managed feature in this EntityType.
+
+ Args:
+ feature_id (str):
+ Required. The managed feature resource ID in this EntityType.
+ Returns:
+ featurestore.Feature - The managed feature resource object.
+ """
+ self.wait()
+ return self._get_feature(feature_id=feature_id)
+
+ def update(
+ self,
+ description: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ request_metadata: Sequence[Tuple[str, str]] = (),
+ update_request_timeout: Optional[float] = None,
+ ) -> "EntityType":
+ """Updates an existing managed entityType resource.
+
+ Example Usage:
+
+ my_entity_type = aiplatform.EntityType(
+ entity_type_name='my_entity_type_id',
+ featurestore_id='my_featurestore_id',
+ )
+ my_entity_type.update(
+ description='update my description',
+ )
+
+ Args:
+ description (str):
+ Optional. Description of the EntityType.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined
+ metadata to organize your EntityTypes.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+                on and examples of labels. No more than 64 user
+                labels can be associated with one EntityType
+                (System labels are excluded).
+ System reserved label keys are prefixed with
+ "aiplatform.googleapis.com/" and are immutable.
+ request_metadata (Sequence[Tuple[str, str]]):
+                Optional. Strings which should be sent along with the request as metadata.
+ update_request_timeout (float):
+ Optional. The timeout for the update request in seconds.
+ Returns:
+ EntityType - The updated entityType resource object.
+ """
+ self.wait()
+ update_mask = list()
+
+ if description:
+ update_mask.append("description")
+
+ if labels:
+ utils.validate_labels(labels)
+ update_mask.append("labels")
+
+ update_mask = field_mask_pb2.FieldMask(paths=update_mask)
+
+ gapic_entity_type = gca_entity_type.EntityType(
+ name=self.resource_name,
+ description=description,
+ labels=labels,
+ )
+
+ _LOGGER.log_action_start_against_resource(
+ "Updating",
+ "entityType",
+ self,
+ )
+
+ update_entity_type_lro = self.api_client.update_entity_type(
+ entity_type=gapic_entity_type,
+ update_mask=update_mask,
+ metadata=request_metadata,
+ timeout=update_request_timeout,
+ )
+
+ _LOGGER.log_action_started_against_resource_with_lro(
+ "Update", "entityType", self.__class__, update_entity_type_lro
+ )
+
+ update_entity_type_lro.result()
+
+ _LOGGER.log_action_completed_against_resource("entityType", "updated", self)
+
+ return self
+
+ @classmethod
+ def list(
+ cls,
+ featurestore_name: str,
+ filter: Optional[str] = None,
+ order_by: Optional[str] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> List["EntityType"]:
+ """Lists existing managed entityType resources in a featurestore, given a featurestore resource name or a featurestore ID.
+
+ Example Usage:
+
+ my_entityTypes = aiplatform.EntityType.list(
+ featurestore_name='projects/123/locations/us-central1/featurestores/my_featurestore_id'
+ )
+ or
+ my_entityTypes = aiplatform.EntityType.list(
+ featurestore_name='my_featurestore_id'
+ )
+
+ Args:
+ featurestore_name (str):
+ Required. A fully-qualified featurestore resource name or a featurestore ID
+ of an existing featurestore to list entityTypes in.
+ Example: "projects/123/locations/us-central1/featurestores/my_featurestore_id"
+ or "my_featurestore_id" when project and location are initialized or passed.
+ filter (str):
+ Optional. Lists the EntityTypes that match the filter expression. The
+ following filters are supported:
+
+ - ``create_time``: Supports ``=``, ``!=``, ``<``, ``>``,
+ ``>=``, and ``<=`` comparisons. Values must be in RFC
+ 3339 format.
+ - ``update_time``: Supports ``=``, ``!=``, ``<``, ``>``,
+ ``>=``, and ``<=`` comparisons. Values must be in RFC
+ 3339 format.
+ - ``labels``: Supports key-value equality as well as key
+ presence.
+
+ Examples:
+
+ - ``create_time > \"2020-01-31T15:30:00.000000Z\" OR update_time > \"2020-01-31T15:30:00.000000Z\"``
+ --> EntityTypes created or updated after
+ 2020-01-31T15:30:00.000000Z.
+ - ``labels.active = yes AND labels.env = prod`` -->
+ EntityTypes having both (active: yes) and (env: prod)
+ labels.
+ - ``labels.env: *`` --> Any EntityType which has a label
+ with 'env' as the key.
+ order_by (str):
+ Optional. A comma-separated list of fields to order by, sorted in
+ ascending order. Use "desc" after a field name for
+ descending.
+
+ Supported fields:
+
+ - ``entity_type_id``
+ - ``create_time``
+ - ``update_time``
+ project (str):
+ Optional. Project to list entityTypes in. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to list entityTypes in. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to list entityTypes. Overrides
+ credentials set in aiplatform.init.
+
+ Returns:
+ List[EntityType] - A list of managed entityType resource objects
+ """
+
+ return cls._list(
+ filter=filter,
+ order_by=order_by,
+ project=project,
+ location=location,
+ credentials=credentials,
+ parent=utils.full_resource_name(
+ resource_name=featurestore_name,
+ resource_noun=featurestore.Featurestore._resource_noun,
+ parse_resource_name_method=featurestore.Featurestore._parse_resource_name,
+ format_resource_name_method=featurestore.Featurestore._format_resource_name,
+ project=project,
+ location=location,
+ resource_id_validator=featurestore.Featurestore._resource_id_validator,
+ ),
+ )
+
+ def list_features(
+ self,
+ filter: Optional[str] = None,
+ order_by: Optional[str] = None,
+ ) -> List["featurestore.Feature"]:
+ """Lists existing managed feature resources in this EntityType.
+
+ Example Usage:
+
+ my_entity_type = aiplatform.EntityType(
+ entity_type_name='my_entity_type_id',
+ featurestore_id='my_featurestore_id',
+ )
+            my_entity_type.list_features()
+
+ Args:
+ filter (str):
+ Optional. Lists the Features that match the filter expression. The
+ following filters are supported:
+
+ - ``value_type``: Supports = and != comparisons.
+ - ``create_time``: Supports =, !=, <, >, >=, and <=
+ comparisons. Values must be in RFC 3339 format.
+ - ``update_time``: Supports =, !=, <, >, >=, and <=
+ comparisons. Values must be in RFC 3339 format.
+ - ``labels``: Supports key-value equality as well as key
+ presence.
+
+ Examples:
+
+ - ``value_type = DOUBLE`` --> Features whose type is
+ DOUBLE.
+ - ``create_time > \"2020-01-31T15:30:00.000000Z\" OR update_time > \"2020-01-31T15:30:00.000000Z\"``
+ --> EntityTypes created or updated after
+ 2020-01-31T15:30:00.000000Z.
+ - ``labels.active = yes AND labels.env = prod`` -->
+ Features having both (active: yes) and (env: prod)
+ labels.
+ - ``labels.env: *`` --> Any Feature which has a label with
+ 'env' as the key.
+ order_by (str):
+ Optional. A comma-separated list of fields to order by, sorted in
+ ascending order. Use "desc" after a field name for
+ descending. Supported fields:
+
+ - ``feature_id``
+ - ``value_type``
+ - ``create_time``
+ - ``update_time``
+
+ Returns:
+ List[featurestore.Feature] - A list of managed feature resource objects.
+ """
+ self.wait()
+ return featurestore.Feature.list(
+ entity_type_name=self.resource_name,
+ filter=filter,
+ order_by=order_by,
+ )
+
+ @base.optional_sync()
+ def delete_features(
+ self,
+ feature_ids: List[str],
+ sync: bool = True,
+ ) -> None:
+ """Deletes feature resources in this EntityType given their feature IDs.
+ WARNING: This deletion is permanent.
+
+ Args:
+ feature_ids (List[str]):
+ Required. The list of feature IDs to be deleted.
+ sync (bool):
+ Optional. Whether to execute this deletion synchronously. If False, this method
+ will be executed in concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
+ """
+ features = []
+ for feature_id in feature_ids:
+ feature = self._get_feature(feature_id=feature_id)
+ feature.delete(sync=False)
+ features.append(feature)
+
+ for feature in features:
+ feature.wait()
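+
+    # A minimal sketch of batch feature deletion (feature IDs are
+    # hypothetical):
+    #
+    #     my_entity_type.delete_features(
+    #         feature_ids=["my_feature_id_1", "my_feature_id_2"],
+    #     )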
+
+ @base.optional_sync()
+ def delete(self, sync: bool = True, force: bool = False) -> None:
+ """Deletes this EntityType resource. If force is set to True,
+ all features in this EntityType will be deleted prior to entityType deletion.
+
+ WARNING: This deletion is permanent.
+
+ Args:
+ force (bool):
+ If set to true, any Features for this
+ EntityType will also be deleted.
+ (Otherwise, the request will only work
+ if the EntityType has no Features.)
+ sync (bool):
+ Whether to execute this deletion synchronously. If False, this method
+ will be executed in concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
+ Raises:
+ FailedPrecondition: If features are created in this EntityType and force = False.
+ """
+ _LOGGER.log_action_start_against_resource("Deleting", "", self)
+ lro = getattr(self.api_client, self._delete_method)(
+ name=self.resource_name, force=force
+ )
+ _LOGGER.log_action_started_against_resource_with_lro(
+ "Delete", "", self.__class__, lro
+ )
+ lro.result()
+ _LOGGER.log_action_completed_against_resource("deleted.", "", self)
+
+ @classmethod
+ @base.optional_sync()
+ def create(
+ cls,
+ entity_type_id: str,
+ featurestore_name: str,
+ description: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ sync: bool = True,
+ create_request_timeout: Optional[float] = None,
+ ) -> "EntityType":
+ """Creates an EntityType resource in a Featurestore.
+
+ Example Usage:
+
+ my_entity_type = aiplatform.EntityType.create(
+ entity_type_id='my_entity_type_id',
+ featurestore_name='projects/123/locations/us-central1/featurestores/my_featurestore_id'
+ )
+ or
+ my_entity_type = aiplatform.EntityType.create(
+ entity_type_id='my_entity_type_id',
+ featurestore_name='my_featurestore_id',
+ )
+
+ Args:
+ entity_type_id (str):
+ Required. The ID to use for the EntityType, which will
+ become the final component of the EntityType's resource
+ name.
+
+ This value may be up to 60 characters, and valid characters
+ are ``[a-z0-9_]``. The first character cannot be a number.
+
+ The value must be unique within a featurestore.
+ featurestore_name (str):
+ Required. A fully-qualified featurestore resource name or a featurestore ID
+ of an existing featurestore to create EntityType in.
+ Example: "projects/123/locations/us-central1/featurestores/my_featurestore_id"
+ or "my_featurestore_id" when project and location are initialized or passed.
+ description (str):
+ Optional. Description of the EntityType.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined
+ metadata to organize your EntityTypes.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ on and examples of labels. No more than 64 user
+ labels can be associated with one EntityType
+ (System labels are excluded)."
+ System reserved label keys are prefixed with
+ "aiplatform.googleapis.com/" and are immutable.
+ project (str):
+                Optional. Project to create EntityType in if `featurestore_name` is passed as a featurestore ID.
+ If not set, project set in aiplatform.init will be used.
+ location (str):
+                Optional. Location to create EntityType in if `featurestore_name` is passed as a featurestore ID.
+ If not set, location set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to create EntityTypes. Overrides
+ credentials set in aiplatform.init.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ sync (bool):
+ Optional. Whether to execute this creation synchronously. If False, this method
+ will be executed in concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+ Returns:
+ EntityType - entity_type resource object
+
+ """
+
+ featurestore_name = utils.full_resource_name(
+ resource_name=featurestore_name,
+ resource_noun=featurestore.Featurestore._resource_noun,
+ parse_resource_name_method=featurestore.Featurestore._parse_resource_name,
+ format_resource_name_method=featurestore.Featurestore._format_resource_name,
+ project=project,
+ location=location,
+ resource_id_validator=featurestore.Featurestore._resource_id_validator,
+ )
+
+ featurestore_name_components = featurestore.Featurestore._parse_resource_name(
+ featurestore_name
+ )
+
+ gapic_entity_type = gca_entity_type.EntityType()
+
+ if labels:
+ utils.validate_labels(labels)
+ gapic_entity_type.labels = labels
+
+ if description:
+ gapic_entity_type.description = description
+
+ api_client = cls._instantiate_client(
+ location=featurestore_name_components["location"],
+ credentials=credentials,
+ )
+
+ created_entity_type_lro = api_client.create_entity_type(
+ parent=featurestore_name,
+ entity_type=gapic_entity_type,
+ entity_type_id=entity_type_id,
+ metadata=request_metadata,
+ timeout=create_request_timeout,
+ )
+
+ _LOGGER.log_create_with_lro(cls, created_entity_type_lro)
+
+ created_entity_type = created_entity_type_lro.result()
+
+ _LOGGER.log_create_complete(cls, created_entity_type, "entity_type")
+
+ entity_type_obj = cls(
+ entity_type_name=created_entity_type.name,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+ return entity_type_obj
+
+ def create_feature(
+ self,
+ feature_id: str,
+ value_type: str,
+ description: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ sync: bool = True,
+ create_request_timeout: Optional[float] = None,
+ ) -> "featurestore.Feature":
+ """Creates a Feature resource in this EntityType.
+
+ Example Usage:
+
+ my_entity_type = aiplatform.EntityType(
+ entity_type_name='my_entity_type_id',
+ featurestore_id='my_featurestore_id',
+ )
+ my_feature = my_entity_type.create_feature(
+ feature_id='my_feature_id',
+ value_type='INT64',
+ )
+
+ Args:
+ feature_id (str):
+ Required. The ID to use for the Feature, which will become
+ the final component of the Feature's resource name, which is immutable.
+
+ This value may be up to 60 characters, and valid characters
+ are ``[a-z0-9_]``. The first character cannot be a number.
+
+ The value must be unique within an EntityType.
+ value_type (str):
+ Required. Immutable. Type of Feature value.
+ One of BOOL, BOOL_ARRAY, DOUBLE, DOUBLE_ARRAY, INT64, INT64_ARRAY, STRING, STRING_ARRAY, BYTES.
+ description (str):
+ Optional. Description of the Feature.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined
+ metadata to organize your Features.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ on and examples of labels. No more than 64 user
+ labels can be associated with one Feature
+ (System labels are excluded)."
+ System reserved label keys are prefixed with
+ "aiplatform.googleapis.com/" and are immutable.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+ sync (bool):
+ Optional. Whether to execute this creation synchronously. If False, this method
+ will be executed in concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
+
+ Returns:
+ featurestore.Feature - feature resource object
+
+ """
+ self.wait()
+ return featurestore.Feature.create(
+ feature_id=feature_id,
+ value_type=value_type,
+ entity_type_name=self.resource_name,
+ description=description,
+ labels=labels,
+ request_metadata=request_metadata,
+ sync=sync,
+ create_request_timeout=create_request_timeout,
+ )
+
+ def _validate_and_get_create_feature_requests(
+ self,
+ feature_configs: Dict[str, Dict[str, Union[bool, int, Dict[str, str], str]]],
+ ) -> List[gca_featurestore_service.CreateFeatureRequest]:
+ """Validates feature_configs and get requests for batch feature creation
+
+ Args:
+ feature_configs (Dict[str, Dict[str, Union[bool, int, Dict[str, str], str]]]):
+ Required. A user defined Dict containing configurations for feature creation.
+
+ Returns:
+ List[gca_featurestore_service.CreateFeatureRequest] - requests for batch feature creation
+ """
+
+ requests = []
+ for feature_id, feature_config in feature_configs.items():
+            feature_config_obj = featurestore_utils._FeatureConfig(
+ feature_id=feature_id,
+ value_type=feature_config.get(
+ "value_type", featurestore_utils._FEATURE_VALUE_TYPE_UNSPECIFIED
+ ),
+ description=feature_config.get("description", None),
+ labels=feature_config.get("labels", {}),
+ )
+            create_feature_request = feature_config_obj.get_create_feature_request()
+ requests.append(create_feature_request)
+
+ return requests
+
+ @base.optional_sync(return_input_arg="self")
+ def batch_create_features(
+ self,
+ feature_configs: Dict[str, Dict[str, Union[bool, int, Dict[str, str], str]]],
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ sync: bool = True,
+ ) -> "EntityType":
+ """Batch creates Feature resources in this EntityType.
+
+ Example Usage:
+
+ my_entity_type = aiplatform.EntityType(
+ entity_type_name='my_entity_type_id',
+ featurestore_id='my_featurestore_id',
+ )
+ my_entity_type.batch_create_features(
+ feature_configs={
+ "my_feature_id1": {
+ "value_type": "INT64",
+ },
+ "my_feature_id2": {
+ "value_type": "BOOL",
+ },
+ "my_feature_id3": {
+ "value_type": "STRING",
+ },
+ }
+ )
+
+ Args:
+ feature_configs (Dict[str, Dict[str, Union[bool, int, Dict[str, str], str]]]):
+ Required. A user defined Dict containing configurations for feature creation.
+
+                The feature_configs Dict[str, Dict], i.e. {feature_id: feature_config}, contains the configuration for each feature to be created:
+ Example:
+ feature_configs = {
+ "my_feature_id_1": feature_config_1,
+ "my_feature_id_2": feature_config_2,
+ "my_feature_id_3": feature_config_3,
+ }
+
+ Each feature_config requires "value_type", and optional "description", "labels":
+ Example:
+ feature_config_1 = {
+ "value_type": "INT64",
+ }
+ feature_config_2 = {
+ "value_type": "BOOL",
+ "description": "my feature id 2 description"
+ }
+ feature_config_3 = {
+ "value_type": "STRING",
+ "labels": {
+ "my key": "my value",
+ }
+ }
+
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ sync (bool):
+ Optional. Whether to execute this creation synchronously. If False, this method
+ will be executed in concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
+
+ Returns:
+ EntityType - entity_type resource object
+ """
+ create_feature_requests = self._validate_and_get_create_feature_requests(
+ feature_configs=feature_configs
+ )
+
+ _LOGGER.log_action_start_against_resource(
+ "Batch creating features",
+ "entityType",
+ self,
+ )
+
+ batch_created_features_lro = self.api_client.batch_create_features(
+ parent=self.resource_name,
+ requests=create_feature_requests,
+ metadata=request_metadata,
+ )
+
+ _LOGGER.log_action_started_against_resource_with_lro(
+ "Batch create Features",
+ "entityType",
+ self.__class__,
+ batch_created_features_lro,
+ )
+
+ batch_created_features_lro.result()
+
+ _LOGGER.log_action_completed_against_resource(
+ "entityType", "Batch created features", self
+ )
+
+ return self
+
+ @staticmethod
+ def _validate_and_get_import_feature_values_request(
+ entity_type_name: str,
+ feature_ids: List[str],
+ feature_time: Union[str, datetime.datetime],
+ data_source: Union[gca_io.AvroSource, gca_io.BigQuerySource, gca_io.CsvSource],
+ feature_source_fields: Optional[Dict[str, str]] = None,
+ entity_id_field: Optional[str] = None,
+ disable_online_serving: Optional[bool] = None,
+ worker_count: Optional[int] = None,
+ ) -> gca_featurestore_service.ImportFeatureValuesRequest:
+ """Validates and get import feature values request.
+ Args:
+ entity_type_name (str):
+ Required. A fully-qualified entityType resource name.
+ feature_ids (List[str]):
+ Required. IDs of the Feature to import values
+ of. The Features must exist in the target
+ EntityType, or the request will fail.
+ feature_time (Union[str, datetime.datetime]):
+ Required. The feature_time can be one of:
+ - The source column that holds the Feature
+ timestamp for all Feature values in each entity.
+ - A single Feature timestamp for all entities
+ being imported. The timestamp must not have
+ higher than millisecond precision.
+            data_source (Union[gca_io.AvroSource, gca_io.BigQuerySource, gca_io.CsvSource]):
+                Required. The data_source can be one of:
+                    - AvroSource
+                    - BigQuerySource
+                    - CsvSource
+ feature_source_fields (Dict[str, str]):
+ Optional. User defined dictionary to map ID of the Feature for importing values
+ of to the source column for getting the Feature values from.
+
+ Specify the features whose ID and source column are not the same.
+                If not provided, the source column needs to be the same as the Feature ID.
+
+ Example:
+ feature_ids = ['my_feature_id_1', 'my_feature_id_2', 'my_feature_id_3']
+
+ feature_source_fields = {
+ 'my_feature_id_1': 'my_feature_id_1_source_field',
+ }
+
+ Note:
+                    The source column of 'my_feature_id_1' is 'my_feature_id_1_source_field'.
+                    The source column of 'my_feature_id_2' is the ID of the feature; same for 'my_feature_id_3'.
+
+ entity_id_field (str):
+ Optional. Source column that holds entity IDs. If not provided, entity
+ IDs are extracted from the column named ``entity_id``.
+ disable_online_serving (bool):
+ Optional. If set, data will not be imported for online
+ serving. This is typically used for backfilling,
+ where Feature generation timestamps are not in
+ the timestamp range needed for online serving.
+ worker_count (int):
+ Optional. Specifies the number of workers that are used
+ to write data to the Featurestore. Consider the
+ online serving capacity that you require to
+ achieve the desired import throughput without
+ interfering with online serving. The value must
+ be positive, and less than or equal to 100. If
+ not set, defaults to using 1 worker. The low
+ count ensures minimal impact on online serving
+ performance.
+ Returns:
+ gca_featurestore_service.ImportFeatureValuesRequest - request message for importing feature values
+ Raises:
+            ValueError if data_source type is not supported.
+            ValueError if feature_time type is not supported.
+ """
+ feature_source_fields = feature_source_fields or {}
+ feature_specs = [
+ gca_featurestore_service.ImportFeatureValuesRequest.FeatureSpec(
+ id=feature_id, source_field=feature_source_fields.get(feature_id)
+ )
+ for feature_id in set(feature_ids)
+ ]
+
+ import_feature_values_request = (
+ gca_featurestore_service.ImportFeatureValuesRequest(
+ entity_type=entity_type_name,
+ feature_specs=feature_specs,
+ entity_id_field=entity_id_field,
+ disable_online_serving=disable_online_serving,
+ worker_count=worker_count,
+ )
+ )
+
+ if isinstance(data_source, gca_io.AvroSource):
+ import_feature_values_request.avro_source = data_source
+ elif isinstance(data_source, gca_io.BigQuerySource):
+ import_feature_values_request.bigquery_source = data_source
+ elif isinstance(data_source, gca_io.CsvSource):
+ import_feature_values_request.csv_source = data_source
+ else:
+ raise ValueError(
+ f"The type of `data_source` field should be: "
+ f"`gca_io.AvroSource`, `gca_io.BigQuerySource`, or `gca_io.CsvSource`, "
+ f"get {type(data_source)} instead. "
+ )
+
+ if isinstance(feature_time, str):
+ import_feature_values_request.feature_time_field = feature_time
+ elif isinstance(feature_time, datetime.datetime):
+ import_feature_values_request.feature_time = utils.get_timestamp_proto(
+ time=feature_time
+ )
+ else:
+ raise ValueError(
+ f"The type of `feature_time` field should be: `str` or `datetime.datetime`, "
+ f"get {type(feature_time)} instead. "
+ )
+
+ return import_feature_values_request
+
+ def _import_feature_values(
+ self,
+ import_feature_values_request: gca_featurestore_service.ImportFeatureValuesRequest,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ ingest_request_timeout: Optional[float] = None,
+ ) -> "EntityType":
+ """Imports Feature values into the Featurestore from a source storage.
+
+ Args:
+ import_feature_values_request (gca_featurestore_service.ImportFeatureValuesRequest):
+ Required. Request message for importing feature values.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ ingest_request_timeout (float):
+ Optional. The timeout for the ingest request in seconds.
+ Returns:
+ EntityType - The entityType resource object with imported feature values.
+ """
+ _LOGGER.log_action_start_against_resource(
+ "Importing",
+ "feature values",
+ self,
+ )
+
+ import_lro = self.api_client.import_feature_values(
+ request=import_feature_values_request,
+ metadata=request_metadata,
+ timeout=ingest_request_timeout,
+ )
+
+ _LOGGER.log_action_started_against_resource_with_lro(
+ "Import", "feature values", self.__class__, import_lro
+ )
+
+ import_lro.result()
+
+ _LOGGER.log_action_completed_against_resource(
+ "feature values", "imported", self
+ )
+
+ return self
+
+ @base.optional_sync(return_input_arg="self")
+ def ingest_from_bq(
+ self,
+ feature_ids: List[str],
+ feature_time: Union[str, datetime.datetime],
+ bq_source_uri: str,
+ feature_source_fields: Optional[Dict[str, str]] = None,
+ entity_id_field: Optional[str] = None,
+ disable_online_serving: Optional[bool] = None,
+ worker_count: Optional[int] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ sync: bool = True,
+ ingest_request_timeout: Optional[float] = None,
+ ) -> "EntityType":
+ """Ingest feature values from BigQuery.
+
+ Args:
+ feature_ids (List[str]):
+ Required. IDs of the Feature to import values
+ of. The Features must exist in the target
+ EntityType, or the request will fail.
+ feature_time (Union[str, datetime.datetime]):
+ Required. The feature_time can be one of:
+ - The source column that holds the Feature
+ timestamp for all Feature values in each entity.
+ - A single Feature timestamp for all entities
+ being imported. The timestamp must not have
+ higher than millisecond precision.
+ bq_source_uri (str):
+ Required. BigQuery URI to the input table.
+ Example:
+ 'bq://project.dataset.table_name'
+ feature_source_fields (Dict[str, str]):
+ Optional. User defined dictionary to map ID of the Feature for importing values
+ of to the source column for getting the Feature values from.
+
+ Specify the features whose ID and source column are not the same.
+                If not provided, the source column needs to be the same as the Feature ID.
+
+ Example:
+ feature_ids = ['my_feature_id_1', 'my_feature_id_2', 'my_feature_id_3']
+
+ feature_source_fields = {
+ 'my_feature_id_1': 'my_feature_id_1_source_field',
+ }
+
+ Note:
+                    The source column of 'my_feature_id_1' is 'my_feature_id_1_source_field'.
+                    The source column of 'my_feature_id_2' is the ID of the feature; same for 'my_feature_id_3'.
+
+ entity_id_field (str):
+ Optional. Source column that holds entity IDs. If not provided, entity
+ IDs are extracted from the column named ``entity_id``.
+ disable_online_serving (bool):
+ Optional. If set, data will not be imported for online
+ serving. This is typically used for backfilling,
+ where Feature generation timestamps are not in
+ the timestamp range needed for online serving.
+ worker_count (int):
+ Optional. Specifies the number of workers that are used
+ to write data to the Featurestore. Consider the
+ online serving capacity that you require to
+ achieve the desired import throughput without
+ interfering with online serving. The value must
+ be positive, and less than or equal to 100. If
+ not set, defaults to using 1 worker. The low
+ count ensures minimal impact on online serving
+ performance.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ sync (bool):
+ Optional. Whether to execute this import synchronously. If False, this method
+ will be executed in concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
+ ingest_request_timeout (float):
+ Optional. The timeout for the ingest request in seconds.
+
+ Returns:
+ EntityType - The entityType resource object with feature values imported.
+
+ """
+
+ bigquery_source = gca_io.BigQuerySource(input_uri=bq_source_uri)
+
+ import_feature_values_request = (
+ self._validate_and_get_import_feature_values_request(
+ entity_type_name=self.resource_name,
+ feature_ids=feature_ids,
+ feature_time=feature_time,
+ data_source=bigquery_source,
+ feature_source_fields=feature_source_fields,
+ entity_id_field=entity_id_field,
+ disable_online_serving=disable_online_serving,
+ worker_count=worker_count,
+ )
+ )
+
+ return self._import_feature_values(
+ import_feature_values_request=import_feature_values_request,
+ request_metadata=request_metadata,
+ ingest_request_timeout=ingest_request_timeout,
+ )
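+
+    # A minimal ingestion sketch from BigQuery (the table URI, feature IDs,
+    # and timestamp column are hypothetical placeholders):
+    #
+    #     my_entity_type.ingest_from_bq(
+    #         feature_ids=["my_feature_id_1", "my_feature_id_2"],
+    #         feature_time="feature_timestamp_column",
+    #         bq_source_uri="bq://project.dataset.table_name",
+    #     )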
+
+ @base.optional_sync(return_input_arg="self")
+ def ingest_from_gcs(
+ self,
+ feature_ids: List[str],
+ feature_time: Union[str, datetime.datetime],
+ gcs_source_uris: Union[str, List[str]],
+ gcs_source_type: str,
+ feature_source_fields: Optional[Dict[str, str]] = None,
+ entity_id_field: Optional[str] = None,
+ disable_online_serving: Optional[bool] = None,
+ worker_count: Optional[int] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ sync: bool = True,
+ ingest_request_timeout: Optional[float] = None,
+ ) -> "EntityType":
+ """Ingest feature values from GCS.
+
+ Args:
+ feature_ids (List[str]):
+ Required. IDs of the Feature to import values
+ of. The Features must exist in the target
+ EntityType, or the request will fail.
+ feature_time (Union[str, datetime.datetime]):
+ Required. The feature_time can be one of:
+ - The source column that holds the Feature
+ timestamp for all Feature values in each entity.
+ - A single Feature timestamp for all entities
+ being imported. The timestamp must not have
+ higher than millisecond precision.
+ gcs_source_uris (Union[str, List[str]]):
+ Required. Google Cloud Storage URI(-s) to the
+ input file(s). May contain wildcards. For more
+ information on wildcards, see
+ https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
+ Example:
+ ["gs://my_bucket/my_file_1.csv", "gs://my_bucket/my_file_2.csv"]
+ or
+ "gs://my_bucket/my_file.avro"
+ gcs_source_type (str):
+ Required. The type of the input file(s) provided by `gcs_source_uris`,
+ the value of gcs_source_type can only be either `csv`, or `avro`.
+ feature_source_fields (Dict[str, str]):
+ Optional. User defined dictionary to map ID of the Feature for importing values
+ of to the source column for getting the Feature values from.
+
+ Specify the features whose ID and source column are not the same.
+                If not provided, the source column needs to be the same as the Feature ID.
+
+ Example:
+ feature_ids = ['my_feature_id_1', 'my_feature_id_2', 'my_feature_id_3']
+
+ feature_source_fields = {
+ 'my_feature_id_1': 'my_feature_id_1_source_field',
+ }
+
+ Note:
+                    The source column of 'my_feature_id_1' is 'my_feature_id_1_source_field'.
+                    The source column of 'my_feature_id_2' is the ID of the feature; same for 'my_feature_id_3'.
+
+ entity_id_field (str):
+ Optional. Source column that holds entity IDs. If not provided, entity
+ IDs are extracted from the column named ``entity_id``.
+ disable_online_serving (bool):
+ Optional. If set, data will not be imported for online
+ serving. This is typically used for backfilling,
+ where Feature generation timestamps are not in
+ the timestamp range needed for online serving.
+ worker_count (int):
+ Optional. Specifies the number of workers that are used
+ to write data to the Featurestore. Consider the
+ online serving capacity that you require to
+ achieve the desired import throughput without
+ interfering with online serving. The value must
+ be positive, and less than or equal to 100. If
+ not set, defaults to using 1 worker. The low
+ count ensures minimal impact on online serving
+ performance.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ sync (bool):
+ Optional. Whether to execute this import synchronously. If False, this method
+                will be executed in a concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
+ ingest_request_timeout (float):
+ Optional. The timeout for the ingest request in seconds.
+
+ Returns:
+ EntityType - The entityType resource object with feature values imported.
+
+ Raises:
+            ValueError: If gcs_source_type is not supported.
+ """
+ if gcs_source_type not in featurestore_utils.GCS_SOURCE_TYPE:
+ raise ValueError(
+ "Only %s are supported gcs_source_type, not `%s`. "
+ % (
+ "`" + "`, `".join(featurestore_utils.GCS_SOURCE_TYPE) + "`",
+ gcs_source_type,
+ )
+ )
+
+ if isinstance(gcs_source_uris, str):
+ gcs_source_uris = [gcs_source_uris]
+ gcs_source = gca_io.GcsSource(uris=gcs_source_uris)
+
+ if gcs_source_type == "csv":
+ data_source = gca_io.CsvSource(gcs_source=gcs_source)
+ if gcs_source_type == "avro":
+ data_source = gca_io.AvroSource(gcs_source=gcs_source)
+
+ import_feature_values_request = (
+ self._validate_and_get_import_feature_values_request(
+ entity_type_name=self.resource_name,
+ feature_ids=feature_ids,
+ feature_time=feature_time,
+ data_source=data_source,
+ feature_source_fields=feature_source_fields,
+ entity_id_field=entity_id_field,
+ disable_online_serving=disable_online_serving,
+ worker_count=worker_count,
+ )
+ )
+
+ return self._import_feature_values(
+ import_feature_values_request=import_feature_values_request,
+ request_metadata=request_metadata,
+ ingest_request_timeout=ingest_request_timeout,
+ )
+
+ def ingest_from_df(
+ self,
+ feature_ids: List[str],
+ feature_time: Union[str, datetime.datetime],
+ df_source: "pd.DataFrame", # noqa: F821 - skip check for undefined name 'pd'
+ feature_source_fields: Optional[Dict[str, str]] = None,
+ entity_id_field: Optional[str] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ ingest_request_timeout: Optional[float] = None,
+ ) -> "EntityType":
+ """Ingest feature values from DataFrame.
+
+        Note:
+            Calling this method will automatically create and delete a temporary
+            BigQuery dataset in the same GCP project, which will be used
+            as the intermediary storage for ingesting feature values
+            from the dataframe to the featurestore.
+
+            The call returns once ingestion completes, at which point the
+            feature values have been ingested into the entity_type.
+
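+        Example Usage (a minimal sketch; the DataFrame and IDs below are
+        illustrative):
+
+            import pandas as pd
+
+            df = pd.DataFrame(
+                data=[
+                    {'entity_id': 'entity_1', 'my_feature_id_1': 1.0},
+                    {'entity_id': 'entity_2', 'my_feature_id_1': 2.0},
+                ],
+            )
+            my_entity_type.ingest_from_df(
+                feature_ids=['my_feature_id_1'],
+                feature_time=datetime.datetime(year=2022, month=1, day=1),
+                df_source=df,
+            )
+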
+ Args:
+ feature_ids (List[str]):
+ Required. IDs of the Feature to import values
+ of. The Features must exist in the target
+ EntityType, or the request will fail.
+ feature_time (Union[str, datetime.datetime]):
+ Required. The feature_time can be one of:
+ - The source column that holds the Feature
+ timestamp for all Feature values in each entity.
+
+ Note:
+ The dtype of the source column should be `datetime64`.
+
+ - A single Feature timestamp for all entities
+ being imported. The timestamp must not have
+ higher than millisecond precision.
+
+ Example:
+ feature_time = datetime.datetime(year=2022, month=1, day=1, hour=11, minute=59, second=59)
+ or
+ feature_time_str = datetime.datetime.now().isoformat(sep=" ", timespec="milliseconds")
+ feature_time = datetime.datetime.strptime(feature_time_str, "%Y-%m-%d %H:%M:%S.%f")
+
+ df_source (pd.DataFrame):
+ Required. Pandas DataFrame containing the source data for ingestion.
+ feature_source_fields (Dict[str, str]):
+                Optional. User defined dictionary that maps the ID of each Feature
+                being imported to the source column from which its values are read.
+
+                Specify only the features whose ID and source column are not the same.
+                If not provided, the source column needs to be the same as the Feature ID.
+
+ Example:
+ feature_ids = ['my_feature_id_1', 'my_feature_id_2', 'my_feature_id_3']
+
+ feature_source_fields = {
+ 'my_feature_id_1': 'my_feature_id_1_source_field',
+ }
+
+ Note:
+                    The source column of 'my_feature_id_1' is 'my_feature_id_1_source_field'.
+                    The source column of 'my_feature_id_2' is the ID of the feature; the same
+                    applies to 'my_feature_id_3'.
+
+ entity_id_field (str):
+ Optional. Source column that holds entity IDs. If not provided, entity
+ IDs are extracted from the column named ``entity_id``.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ ingest_request_timeout (float):
+ Optional. The timeout for the ingest request in seconds.
+
+ Returns:
+ EntityType - The entityType resource object with feature values imported.
+
+ """
+ try:
+ import pyarrow # noqa: F401 - skip check for 'pyarrow' which is required when using 'google.cloud.bigquery'
+ except ImportError:
+ raise ImportError(
+ f"Pyarrow is not installed. Please install pyarrow to use "
+ f"{self.ingest_from_df.__name__}"
+ )
+
+ bigquery_client = bigquery.Client(
+ project=self.project, credentials=self.credentials
+ )
+
+ self.wait()
+
+ feature_source_fields = feature_source_fields or {}
+ bq_schema = []
+ for feature_id in feature_ids:
+ feature_field_name = feature_source_fields.get(feature_id, feature_id)
+ feature_value_type = self.get_feature(feature_id).to_dict()["valueType"]
+ bq_schema_field = self._get_bq_schema_field(
+ feature_field_name, feature_value_type
+ )
+ bq_schema.append(bq_schema_field)
+
+ entity_type_name_components = self._parse_resource_name(self.resource_name)
+ featurestore_id, entity_type_id = (
+ entity_type_name_components["featurestore"],
+ entity_type_name_components["entity_type"],
+ )
+
+ temp_bq_dataset_name = f"temp_{featurestore_id}_{uuid.uuid4()}".replace(
+ "-", "_"
+ )
+
+ project_id = resource_manager_utils.get_project_id(
+ project_number=entity_type_name_components["project"],
+ credentials=self.credentials,
+ )
+ temp_bq_dataset_id = f"{project_id}.{temp_bq_dataset_name}"[:1024]
+ temp_bq_table_id = f"{temp_bq_dataset_id}.{entity_type_id}"
+
+ temp_bq_dataset = bigquery.Dataset(dataset_ref=temp_bq_dataset_id)
+ temp_bq_dataset.location = self.location
+
+ temp_bq_dataset = bigquery_client.create_dataset(temp_bq_dataset)
+
+ try:
+
+ parquet_options = bigquery.format_options.ParquetOptions()
+ parquet_options.enable_list_inference = True
+
+ job_config = bigquery.LoadJobConfig(
+ schema=bq_schema,
+ source_format=bigquery.SourceFormat.PARQUET,
+ parquet_options=parquet_options,
+ )
+
+ job = bigquery_client.load_table_from_dataframe(
+ dataframe=df_source,
+ destination=temp_bq_table_id,
+ job_config=job_config,
+ )
+ job.result()
+
+ entity_type_obj = self.ingest_from_bq(
+ feature_ids=feature_ids,
+ feature_time=feature_time,
+ bq_source_uri=f"bq://{temp_bq_table_id}",
+ feature_source_fields=feature_source_fields,
+ entity_id_field=entity_id_field,
+ request_metadata=request_metadata,
+ ingest_request_timeout=ingest_request_timeout,
+ )
+
+ finally:
+ bigquery_client.delete_dataset(
+ dataset=temp_bq_dataset.dataset_id,
+ delete_contents=True,
+ )
+
+ return entity_type_obj
+
+ @staticmethod
+ def _get_bq_schema_field(
+ name: str, feature_value_type: str
+ ) -> bigquery.SchemaField:
+ """Helper method to get BigQuery Schema Field.
+
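+        Example (illustrative; the mapped BigQuery type comes from
+        FEATURE_STORE_VALUE_TYPE_TO_BQ_DATA_TYPE_MAP):
+
+            field = EntityType._get_bq_schema_field('my_column', 'DOUBLE')
+            # field.name == 'my_column'; field.field_type comes from the map,
+            # and field.mode defaults to 'NULLABLE' when the map sets no mode.
+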
+ Args:
+ name (str):
+ Required. The name of the schema field, which can be either the feature_id,
+ or the field_name in BigQuery for the feature if different than the feature_id.
+ feature_value_type (str):
+ Required. The feature value_type.
+
+ Returns:
+ bigquery.SchemaField: bigquery.SchemaField
+ """
+ bq_data_type = (
+ utils.featurestore_utils.FEATURE_STORE_VALUE_TYPE_TO_BQ_DATA_TYPE_MAP[
+ feature_value_type
+ ]
+ )
+ bq_schema_field = bigquery.SchemaField(
+ name=name,
+ field_type=bq_data_type["field_type"],
+ mode=bq_data_type.get("mode") or "NULLABLE",
+ )
+ return bq_schema_field
+
+ @staticmethod
+ def _instantiate_featurestore_online_client(
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> utils.FeaturestoreOnlineServingClientWithOverride:
+ """Helper method to instantiates featurestore online client.
+
+ Args:
+ location (str): The location of this featurestore.
+ credentials (google.auth.credentials.Credentials):
+ Optional custom credentials to use when interacting with
+ the featurestore online client.
+ Returns:
+ utils.FeaturestoreOnlineServingClientWithOverride:
+ Initialized featurestore online client with optional overrides.
+ """
+ return initializer.global_config.create_client(
+ client_class=utils.FeaturestoreOnlineServingClientWithOverride,
+ credentials=credentials,
+ location_override=location,
+ )
+
+ def read(
+ self,
+ entity_ids: Union[str, List[str]],
+ feature_ids: Union[str, List[str]] = "*",
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ read_request_timeout: Optional[float] = None,
+ ) -> "pd.DataFrame": # noqa: F821 - skip check for undefined name 'pd'
+ """Reads feature values for given feature IDs of given entity IDs in this EntityType.
+
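+        Example Usage (entity and feature IDs are illustrative):
+
+            my_df = my_entity_type.read(
+                entity_ids=['entity_1', 'entity_2'],
+                feature_ids=['my_feature_id_1', 'my_feature_id_2'],
+            )
+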
+ Args:
+ entity_ids (Union[str, List[str]]):
+ Required. ID for a specific entity, or a list of IDs of entities
+                to read Feature values of. If a list, the maximum number of IDs is 100.
+ feature_ids (Union[str, List[str]]):
+ Required. ID for a specific feature, or a list of IDs of Features in the EntityType
+                for reading feature values. Defaults to "*", in which case the values of all features will be read.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ read_request_timeout (float):
+ Optional. The timeout for the read request in seconds.
+
+ Returns:
+ pd.DataFrame: entities' feature values in DataFrame
+ """
+ self.wait()
+ if isinstance(feature_ids, str):
+ feature_ids = [feature_ids]
+
+ feature_selector = gca_feature_selector.FeatureSelector(
+ id_matcher=gca_feature_selector.IdMatcher(ids=feature_ids)
+ )
+
+ if isinstance(entity_ids, str):
+ read_feature_values_request = (
+ gca_featurestore_online_service.ReadFeatureValuesRequest(
+ entity_type=self.resource_name,
+ entity_id=entity_ids,
+ feature_selector=feature_selector,
+ )
+ )
+ read_feature_values_response = (
+ self._featurestore_online_client.read_feature_values(
+ request=read_feature_values_request,
+ metadata=request_metadata,
+ timeout=read_request_timeout,
+ )
+ )
+ header = read_feature_values_response.header
+ entity_views = [read_feature_values_response.entity_view]
+ elif isinstance(entity_ids, list):
+ streaming_read_feature_values_request = (
+ gca_featurestore_online_service.StreamingReadFeatureValuesRequest(
+ entity_type=self.resource_name,
+ entity_ids=entity_ids,
+ feature_selector=feature_selector,
+ )
+ )
+ streaming_read_feature_values_responses = [
+ response
+ for response in self._featurestore_online_client.streaming_read_feature_values(
+ request=streaming_read_feature_values_request,
+ metadata=request_metadata,
+ timeout=read_request_timeout,
+ )
+ ]
+ header = streaming_read_feature_values_responses[0].header
+ entity_views = [
+ response.entity_view
+ for response in streaming_read_feature_values_responses[1:]
+ ]
+
+ feature_ids = [
+ feature_descriptor.id for feature_descriptor in header.feature_descriptors
+ ]
+
+ return self._construct_dataframe(
+ feature_ids=feature_ids,
+ entity_views=entity_views,
+ )
+
+ @staticmethod
+ def _construct_dataframe(
+ feature_ids: List[str],
+ entity_views: List[
+ gca_featurestore_online_service.ReadFeatureValuesResponse.EntityView
+ ],
+ ) -> "pd.DataFrame": # noqa: F821 - skip check for undefined name 'pd'
+ """Constructs a dataframe using the header and entity_views
+
+ Args:
+ feature_ids (List[str]):
+ Required. A list of feature ids corresponding to the feature values for each entity in entity_views.
+ entity_views (List[gca_featurestore_online_service.ReadFeatureValuesResponse.EntityView]):
+ Required. A list of Entity views with Feature values.
+ For each Entity view, it may be
+ the entity in the Featurestore if values for all
+ Features were requested, or a projection of the
+ entity in the Featurestore if values for only
+ some Features were requested.
+
+ Raises:
+ ImportError: If pandas is not installed when using this method.
+
+ Returns:
+            pd.DataFrame - entities' feature values in DataFrame
+ """
+
+ try:
+ import pandas as pd
+ except ImportError:
+ raise ImportError(
+ f"Pandas is not installed. Please install pandas to use "
+ f"{EntityType._construct_dataframe.__name__}"
+ )
+
+ data = []
+ for entity_view in entity_views:
+ entity_data = {"entity_id": entity_view.entity_id}
+ for feature_id, feature_data in zip(feature_ids, entity_view.data):
+ if feature_data._pb.HasField("value"):
+ value_type = feature_data.value._pb.WhichOneof("value")
+ feature_value = getattr(feature_data.value, value_type)
+ if hasattr(feature_value, "values"):
+ feature_value = feature_value.values
+ entity_data[feature_id] = feature_value
+ else:
+ entity_data[feature_id] = None
+ data.append(entity_data)
+
+ return pd.DataFrame(data=data, columns=["entity_id"] + feature_ids)
diff --git a/google/cloud/aiplatform/featurestore/feature.py b/google/cloud/aiplatform/featurestore/feature.py
new file mode 100644
index 0000000000..7a7fc0f29a
--- /dev/null
+++ b/google/cloud/aiplatform/featurestore/feature.py
@@ -0,0 +1,644 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Dict, List, Optional, Sequence, Tuple
+
+from google.auth import credentials as auth_credentials
+from google.protobuf import field_mask_pb2
+
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform.compat.types import feature as gca_feature
+from google.cloud.aiplatform import featurestore
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform import utils
+from google.cloud.aiplatform.utils import featurestore_utils
+
+_LOGGER = base.Logger(__name__)
+
+
+class Feature(base.VertexAiResourceNounWithFutureManager):
+ """Managed feature resource for Vertex AI."""
+
+ client_class = utils.FeaturestoreClientWithOverride
+
+ _resource_noun = "features"
+ _getter_method = "get_feature"
+ _list_method = "list_features"
+ _delete_method = "delete_feature"
+ _parse_resource_name_method = "parse_feature_path"
+ _format_resource_name_method = "feature_path"
+
+ @staticmethod
+ def _resource_id_validator(resource_id: str):
+ """Validates resource ID.
+
+ Args:
+ resource_id(str):
+ The resource id to validate.
+ """
+ featurestore_utils.validate_feature_id(resource_id)
+
+ def __init__(
+ self,
+ feature_name: str,
+ featurestore_id: Optional[str] = None,
+ entity_type_id: Optional[str] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ):
+ """Retrieves an existing managed feature given a feature resource name or a feature ID.
+
+ Example Usage:
+
+ my_feature = aiplatform.Feature(
+ feature_name='projects/123/locations/us-central1/featurestores/my_featurestore_id/\
+ entityTypes/my_entity_type_id/features/my_feature_id'
+ )
+ or
+ my_feature = aiplatform.Feature(
+ feature_name='my_feature_id',
+ featurestore_id='my_featurestore_id',
+ entity_type_id='my_entity_type_id',
+ )
+
+ Args:
+ feature_name (str):
+ Required. A fully-qualified feature resource name or a feature ID.
+ Example: "projects/123/locations/us-central1/featurestores/my_featurestore_id/entityTypes/my_entity_type_id/features/my_feature_id"
+ or "my_feature_id" when project and location are initialized or passed, with featurestore_id and entity_type_id passed.
+ featurestore_id (str):
+ Optional. Featurestore ID of an existing featurestore to retrieve feature from,
+ when feature_name is passed as Feature ID.
+ entity_type_id (str):
+ Optional. EntityType ID of an existing entityType to retrieve feature from,
+ when feature_name is passed as Feature ID.
+ The EntityType must exist in the Featurestore if provided by the featurestore_id.
+ project (str):
+ Optional. Project to retrieve feature from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve feature from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve this Feature. Overrides
+ credentials set in aiplatform.init.
+ Raises:
+ ValueError: If only one of featurestore_id or entity_type_id is provided.
+ """
+
+ if bool(featurestore_id) != bool(entity_type_id):
+ raise ValueError(
+ "featurestore_id and entity_type_id must both be provided or ommitted."
+ )
+
+ super().__init__(
+ project=project,
+ location=location,
+ credentials=credentials,
+ resource_name=feature_name,
+ )
+ self._gca_resource = self._get_gca_resource(
+ resource_name=feature_name,
+ parent_resource_name_fields={
+ featurestore.Featurestore._resource_noun: featurestore_id,
+ featurestore.EntityType._resource_noun: entity_type_id,
+ }
+ if featurestore_id
+ else featurestore_id,
+ )
+
+ def _get_featurestore_name(self) -> str:
+ """Gets full qualified resource name of the managed featurestore in which this Feature is."""
+ feature_path_components = self._parse_resource_name(self.resource_name)
+ return featurestore.Featurestore._format_resource_name(
+ project=feature_path_components["project"],
+ location=feature_path_components["location"],
+ featurestore=feature_path_components["featurestore"],
+ )
+
+ @property
+ def featurestore_name(self) -> str:
+ """Full qualified resource name of the managed featurestore in which this Feature is."""
+ self.wait()
+ return self._get_featurestore_name()
+
+ def get_featurestore(self) -> "featurestore.Featurestore":
+ """Retrieves the managed featurestore in which this Feature is.
+
+ Returns:
+ featurestore.Featurestore - The managed featurestore in which this Feature is.
+ """
+ return featurestore.Featurestore(featurestore_name=self.featurestore_name)
+
+ def _get_entity_type_name(self) -> str:
+ """Gets full qualified resource name of the managed entityType in which this Feature is."""
+ feature_path_components = self._parse_resource_name(self.resource_name)
+ return featurestore.EntityType._format_resource_name(
+ project=feature_path_components["project"],
+ location=feature_path_components["location"],
+ featurestore=feature_path_components["featurestore"],
+ entity_type=feature_path_components["entity_type"],
+ )
+
+ @property
+ def entity_type_name(self) -> str:
+ """Full qualified resource name of the managed entityType in which this Feature is."""
+ self.wait()
+ return self._get_entity_type_name()
+
+ def get_entity_type(self) -> "featurestore.EntityType":
+ """Retrieves the managed entityType in which this Feature is.
+
+ Returns:
+ featurestore.EntityType - The managed entityType in which this Feature is.
+ """
+ return featurestore.EntityType(entity_type_name=self.entity_type_name)
+
+ def update(
+ self,
+ description: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ update_request_timeout: Optional[float] = None,
+ ) -> "Feature":
+ """Updates an existing managed feature resource.
+
+ Example Usage:
+
+ my_feature = aiplatform.Feature(
+ feature_name='my_feature_id',
+ featurestore_id='my_featurestore_id',
+ entity_type_id='my_entity_type_id',
+ )
+ my_feature.update(
+ description='update my description',
+ )
+
+ Args:
+ description (str):
+ Optional. Description of the Feature.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined
+ metadata to organize your Features.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ on and examples of labels. No more than 64 user
+ labels can be associated with one Feature
+ (System labels are excluded)."
+ System reserved label keys are prefixed with
+ "aiplatform.googleapis.com/" and are immutable.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ update_request_timeout (float):
+ Optional. The timeout for the update request in seconds.
+
+ Returns:
+ Feature - The updated feature resource object.
+ """
+ self.wait()
+ update_mask = list()
+
+ if description:
+ update_mask.append("description")
+
+ if labels:
+ utils.validate_labels(labels)
+ update_mask.append("labels")
+
+ update_mask = field_mask_pb2.FieldMask(paths=update_mask)
+
+ gapic_feature = gca_feature.Feature(
+ name=self.resource_name,
+ description=description,
+ labels=labels,
+ )
+
+ _LOGGER.log_action_start_against_resource(
+ "Updating",
+ "feature",
+ self,
+ )
+
+ update_feature_lro = self.api_client.update_feature(
+ feature=gapic_feature,
+ update_mask=update_mask,
+ metadata=request_metadata,
+ timeout=update_request_timeout,
+ )
+
+ _LOGGER.log_action_started_against_resource_with_lro(
+ "Update", "feature", self.__class__, update_feature_lro
+ )
+
+ update_feature_lro.result()
+
+ _LOGGER.log_action_completed_against_resource("feature", "updated", self)
+
+ return self
+
+ @classmethod
+ def list(
+ cls,
+ entity_type_name: str,
+ featurestore_id: Optional[str] = None,
+ filter: Optional[str] = None,
+ order_by: Optional[str] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> List["Feature"]:
+ """Lists existing managed feature resources in an entityType, given an entityType resource name or an entity_type ID.
+
+ Example Usage:
+
+ my_features = aiplatform.Feature.list(
+ entity_type_name='projects/123/locations/us-central1/featurestores/my_featurestore_id/\
+ entityTypes/my_entity_type_id'
+ )
+ or
+ my_features = aiplatform.Feature.list(
+ entity_type_name='my_entity_type_id',
+ featurestore_id='my_featurestore_id',
+ )
+
+ Args:
+ entity_type_name (str):
+ Required. A fully-qualified entityType resource name or an entity_type ID of an existing entityType
+ to list features in. The EntityType must exist in the Featurestore if provided by the featurestore_id.
+ Example: "projects/123/locations/us-central1/featurestores/my_featurestore_id/entityTypes/my_entity_type_id"
+ or "my_entity_type_id" when project and location are initialized or passed, with featurestore_id passed.
+ featurestore_id (str):
+ Optional. Featurestore ID of an existing featurestore to list features in,
+ when entity_type_name is passed as entity_type ID.
+ filter (str):
+ Optional. Lists the Features that match the filter expression. The
+ following filters are supported:
+
+ - ``value_type``: Supports = and != comparisons.
+ - ``create_time``: Supports =, !=, <, >, >=, and <=
+ comparisons. Values must be in RFC 3339 format.
+ - ``update_time``: Supports =, !=, <, >, >=, and <=
+ comparisons. Values must be in RFC 3339 format.
+ - ``labels``: Supports key-value equality as well as key
+ presence.
+
+ Examples:
+
+ - ``value_type = DOUBLE`` --> Features whose type is
+ DOUBLE.
+ - ``create_time > \"2020-01-31T15:30:00.000000Z\" OR update_time > \"2020-01-31T15:30:00.000000Z\"``
+ --> EntityTypes created or updated after
+ 2020-01-31T15:30:00.000000Z.
+ - ``labels.active = yes AND labels.env = prod`` -->
+ Features having both (active: yes) and (env: prod)
+ labels.
+ - ``labels.env: *`` --> Any Feature which has a label with
+ 'env' as the key.
+ order_by (str):
+ Optional. A comma-separated list of fields to order by, sorted in
+ ascending order. Use "desc" after a field name for
+ descending. Supported fields:
+
+ - ``feature_id``
+ - ``value_type``
+ - ``create_time``
+ - ``update_time``
+ project (str):
+ Optional. Project to list features in. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to list features in. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to list features. Overrides
+ credentials set in aiplatform.init.
+
+ Returns:
+ List[Feature] - A list of managed feature resource objects
+ """
+
+ return cls._list(
+ filter=filter,
+ order_by=order_by,
+ project=project,
+ location=location,
+ credentials=credentials,
+ parent=utils.full_resource_name(
+ resource_name=entity_type_name,
+ resource_noun=featurestore.EntityType._resource_noun,
+ parse_resource_name_method=featurestore.EntityType._parse_resource_name,
+ format_resource_name_method=featurestore.EntityType._format_resource_name,
+ parent_resource_name_fields={
+ featurestore.Featurestore._resource_noun: featurestore_id
+ }
+ if featurestore_id
+ else featurestore_id,
+ project=project,
+ location=location,
+ resource_id_validator=featurestore.EntityType._resource_id_validator,
+ ),
+ )
+
+ @classmethod
+ def search(
+ cls,
+ query: Optional[str] = None,
+ page_size: Optional[int] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> List["Feature"]:
+ """Searches existing managed Feature resources.
+
+ Example Usage:
+
+ my_features = aiplatform.Feature.search()
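+            or, with a field-restricted query (query value is illustrative):
+            my_features = aiplatform.Feature.search(
+                query='feature_id: my_feature_id',
+            )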
+
+ Args:
+ query (str):
+ Optional. Query string that is a conjunction of field-restricted
+ queries and/or field-restricted filters.
+ Field-restricted queries and filters can be combined
+ using ``AND`` to form a conjunction.
+
+ A field query is in the form FIELD:QUERY. This
+ implicitly checks if QUERY exists as a substring within
+ Feature's FIELD. The QUERY and the FIELD are converted
+ to a sequence of words (i.e. tokens) for comparison.
+ This is done by:
+
+ - Removing leading/trailing whitespace and tokenizing
+ the search value. Characters that are not one of
+ alphanumeric ``[a-zA-Z0-9]``, underscore ``_``, or
+ asterisk ``*`` are treated as delimiters for tokens.
+ ``*`` is treated as a wildcard that matches
+ characters within a token.
+ - Ignoring case.
+ - Prepending an asterisk to the first and appending an
+ asterisk to the last token in QUERY.
+
+ A QUERY must be either a singular token or a phrase. A
+ phrase is one or multiple words enclosed in double
+ quotation marks ("). With phrases, the order of the
+                words is important. Words in the phrase must match
+                in order and consecutively.
+
+ Supported FIELDs for field-restricted queries:
+
+ - ``feature_id``
+ - ``description``
+ - ``entity_type_id``
+
+ Examples:
+
+                - ``feature_id: foo`` --> Matches a Feature with ID
+                  containing the substring ``foo`` (e.g. ``foo``,
+                  ``foofeature``, ``barfoo``).
+                - ``feature_id: foo*feature`` --> Matches a Feature
+                  with ID containing the substring ``foo*feature`` (e.g.
+                  ``foobarfeature``).
+ - ``feature_id: foo AND description: bar`` --> Matches
+ a Feature with ID containing the substring ``foo``
+ and description containing the substring ``bar``.
+
+ Besides field queries, the following exact-match filters
+ are supported. The exact-match filters do not support
+ wildcards. Unlike field-restricted queries, exact-match
+ filters are case-sensitive.
+
+ - ``feature_id``: Supports = comparisons.
+ - ``description``: Supports = comparisons. Multi-token
+ filters should be enclosed in quotes.
+ - ``entity_type_id``: Supports = comparisons.
+ - ``value_type``: Supports = and != comparisons.
+ - ``labels``: Supports key-value equality as well as
+ key presence.
+ - ``featurestore_id``: Supports = comparisons.
+
+ Examples:
+
+ - ``description = "foo bar"`` --> Any Feature with
+ description exactly equal to ``foo bar``
+ - ``value_type = DOUBLE`` --> Features whose type is
+ DOUBLE.
+ - ``labels.active = yes AND labels.env = prod`` -->
+ Features having both (active: yes) and (env: prod)
+ labels.
+ - ``labels.env: *`` --> Any Feature which has a label
+ with ``env`` as the key.
+
+ page_size (int):
+ Optional. The maximum number of Features to return. The
+ service may return fewer than this value. If
+ unspecified, at most 100 Features will be
+ returned. The maximum value is 100; any value
+ greater than 100 will be coerced to 100.
+ project (str):
+ Optional. Project to list features in. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to list features in. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to list features. Overrides
+ credentials set in aiplatform.init.
+
+ Returns:
+ List[Feature] - A list of managed feature resource objects
+ """
+ resource = cls._empty_constructor(
+ project=project, location=location, credentials=credentials
+ )
+
+ # Fetch credentials once and re-use for all `_empty_constructor()` calls
+ creds = resource.credentials
+
+ search_features_request = {
+ "location": initializer.global_config.common_location_path(
+ project=project, location=location
+ ),
+ "query": query,
+ }
+
+ if page_size:
+ search_features_request["page_size"] = page_size
+
+ resource_list = (
+ resource.api_client.search_features(request=search_features_request) or []
+ )
+
+ return [
+ cls._construct_sdk_resource_from_gapic(
+ gapic_resource, project=project, location=location, credentials=creds
+ )
+ for gapic_resource in resource_list
+ ]
+
+ @classmethod
+ @base.optional_sync()
+ def create(
+ cls,
+ feature_id: str,
+ value_type: str,
+ entity_type_name: str,
+ featurestore_id: Optional[str] = None,
+ description: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ sync: bool = True,
+ create_request_timeout: Optional[float] = None,
+ ) -> "Feature":
+ """Creates a Feature resource in an EntityType.
+
+ Example Usage:
+
+ my_feature = aiplatform.Feature.create(
+ feature_id='my_feature_id',
+ value_type='INT64',
+ entity_type_name='projects/123/locations/us-central1/featurestores/my_featurestore_id/\
+ entityTypes/my_entity_type_id'
+ )
+ or
+ my_feature = aiplatform.Feature.create(
+ feature_id='my_feature_id',
+ value_type='INT64',
+ entity_type_name='my_entity_type_id',
+ featurestore_id='my_featurestore_id',
+ )
+
+ Args:
+ feature_id (str):
+ Required. The ID to use for the Feature, which will become
+ the final component of the Feature's resource name, which is immutable.
+
+ This value may be up to 60 characters, and valid characters
+ are ``[a-z0-9_]``. The first character cannot be a number.
+
+ The value must be unique within an EntityType.
+ value_type (str):
+ Required. Immutable. Type of Feature value.
+ One of BOOL, BOOL_ARRAY, DOUBLE, DOUBLE_ARRAY, INT64, INT64_ARRAY, STRING, STRING_ARRAY, BYTES.
+ entity_type_name (str):
+ Required. A fully-qualified entityType resource name or an entity_type ID of an existing entityType
+ to create Feature in. The EntityType must exist in the Featurestore if provided by the featurestore_id.
+ Example: "projects/123/locations/us-central1/featurestores/my_featurestore_id/entityTypes/my_entity_type_id"
+ or "my_entity_type_id" when project and location are initialized or passed, with featurestore_id passed.
+ featurestore_id (str):
+ Optional. Featurestore ID of an existing featurestore to create Feature in
+                when `entity_type_name` is passed as an entity_type ID.
+ description (str):
+ Optional. Description of the Feature.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined
+ metadata to organize your Features.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ on and examples of labels. No more than 64 user
+ labels can be associated with one Feature
+ (System labels are excluded)."
+ System reserved label keys are prefixed with
+ "aiplatform.googleapis.com/" and are immutable.
+ project (str):
+                Optional. Project to create Feature in, when `entity_type_name` is passed as an entity_type ID.
+ If not set, project set in aiplatform.init will be used.
+ location (str):
+                Optional. Location to create Feature in, when `entity_type_name` is passed as an entity_type ID.
+ If not set, location set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to create Features. Overrides
+ credentials set in aiplatform.init.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ sync (bool):
+ Optional. Whether to execute this creation synchronously. If False, this method
+                will be executed in a concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+
+ Returns:
+ Feature - feature resource object
+
+ """
+ entity_type_name = utils.full_resource_name(
+ resource_name=entity_type_name,
+ resource_noun=featurestore.EntityType._resource_noun,
+ parse_resource_name_method=featurestore.EntityType._parse_resource_name,
+ format_resource_name_method=featurestore.EntityType._format_resource_name,
+ parent_resource_name_fields={
+ featurestore.Featurestore._resource_noun: featurestore_id
+ }
+ if featurestore_id
+ else featurestore_id,
+ project=project,
+ location=location,
+ resource_id_validator=featurestore.EntityType._resource_id_validator,
+ )
+ entity_type_name_components = featurestore.EntityType._parse_resource_name(
+ entity_type_name
+ )
+
+ feature_config = featurestore_utils._FeatureConfig(
+ feature_id=feature_id,
+ value_type=value_type,
+ description=description,
+ labels=labels,
+ )
+
+ create_feature_request = feature_config.get_create_feature_request()
+ create_feature_request.parent = entity_type_name
+
+ api_client = cls._instantiate_client(
+ location=entity_type_name_components["location"],
+ credentials=credentials,
+ )
+
+ created_feature_lro = api_client.create_feature(
+ request=create_feature_request,
+ metadata=request_metadata,
+ timeout=create_request_timeout,
+ )
+
+ _LOGGER.log_create_with_lro(cls, created_feature_lro)
+
+ created_feature = created_feature_lro.result()
+
+ _LOGGER.log_create_complete(cls, created_feature, "feature")
+
+ feature_obj = cls(
+ feature_name=created_feature.name,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+ return feature_obj
diff --git a/google/cloud/aiplatform/featurestore/featurestore.py b/google/cloud/aiplatform/featurestore/featurestore.py
new file mode 100644
index 0000000000..3bb5d44b80
--- /dev/null
+++ b/google/cloud/aiplatform/featurestore/featurestore.py
@@ -0,0 +1,1281 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+import uuid
+
+from google.auth import credentials as auth_credentials
+from google.protobuf import field_mask_pb2
+
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform.compat.types import (
+ feature_selector as gca_feature_selector,
+ featurestore as gca_featurestore,
+ featurestore_service as gca_featurestore_service,
+ io as gca_io,
+)
+from google.cloud.aiplatform import featurestore
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform import utils
+from google.cloud.aiplatform.utils import featurestore_utils, resource_manager_utils
+
+from google.cloud import bigquery
+
+_LOGGER = base.Logger(__name__)
+
+
+class Featurestore(base.VertexAiResourceNounWithFutureManager):
+ """Managed featurestore resource for Vertex AI."""
+
+ client_class = utils.FeaturestoreClientWithOverride
+
+ _resource_noun = "featurestores"
+ _getter_method = "get_featurestore"
+ _list_method = "list_featurestores"
+ _delete_method = "delete_featurestore"
+ _parse_resource_name_method = "parse_featurestore_path"
+ _format_resource_name_method = "featurestore_path"
+
+ @staticmethod
+ def _resource_id_validator(resource_id: str):
+ """Validates resource ID.
+
+ Args:
+ resource_id(str):
+ The resource id to validate.
+ """
+ featurestore_utils.validate_id(resource_id)
+
+ def __init__(
+ self,
+ featurestore_name: str,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ):
+ """Retrieves an existing managed featurestore given a featurestore resource name or a featurestore ID.
+
+ Example Usage:
+
+ my_featurestore = aiplatform.Featurestore(
+ featurestore_name='projects/123/locations/us-central1/featurestores/my_featurestore_id'
+ )
+ or
+ my_featurestore = aiplatform.Featurestore(
+ featurestore_name='my_featurestore_id'
+ )
+
+ Args:
+ featurestore_name (str):
+ Required. A fully-qualified featurestore resource name or a featurestore ID.
+ Example: "projects/123/locations/us-central1/featurestores/my_featurestore_id"
+ or "my_featurestore_id" when project and location are initialized or passed.
+ project (str):
+ Optional. Project to retrieve featurestore from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve featurestore from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve this Featurestore. Overrides
+ credentials set in aiplatform.init.
+ """
+
+ super().__init__(
+ project=project,
+ location=location,
+ credentials=credentials,
+ resource_name=featurestore_name,
+ )
+ self._gca_resource = self._get_gca_resource(resource_name=featurestore_name)
+
+ def get_entity_type(self, entity_type_id: str) -> "featurestore.EntityType":
+ """Retrieves an existing managed entityType in this Featurestore.
+
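+        Example Usage (entity_type ID is illustrative):
+
+            my_entity_type = my_featurestore.get_entity_type(
+                entity_type_id='my_entity_type_id',
+            )
+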
+ Args:
+ entity_type_id (str):
+ Required. The managed entityType resource ID in this Featurestore.
+ Returns:
+ featurestore.EntityType - The managed entityType resource object.
+ """
+ self.wait()
+ return self._get_entity_type(entity_type_id=entity_type_id)
+
+ def _get_entity_type(self, entity_type_id: str) -> "featurestore.EntityType":
+ """Retrieves an existing managed entityType in this Featurestore.
+
+ Args:
+ entity_type_id (str):
+ Required. The managed entityType resource ID in this Featurestore.
+ Returns:
+ featurestore.EntityType - The managed entityType resource object.
+ """
+ featurestore_name_components = self._parse_resource_name(self.resource_name)
+ return featurestore.EntityType(
+ entity_type_name=featurestore.EntityType._format_resource_name(
+ project=featurestore_name_components["project"],
+ location=featurestore_name_components["location"],
+ featurestore=featurestore_name_components["featurestore"],
+ entity_type=entity_type_id,
+ )
+ )
+
+ def update(
+ self,
+ labels: Optional[Dict[str, str]] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ update_request_timeout: Optional[float] = None,
+ ) -> "Featurestore":
+ """Updates an existing managed featurestore resource.
+
+ Example Usage:
+
+ my_featurestore = aiplatform.Featurestore(
+ featurestore_name='my_featurestore_id',
+ )
+ my_featurestore.update(
+ labels={'update my key': 'update my value'},
+ )
+
+ Args:
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined
+ metadata to organize your Featurestores.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ on and examples of labels. No more than 64 user
+ labels can be associated with one Feature
+ (System labels are excluded)."
+ System reserved label keys are prefixed with
+ "aiplatform.googleapis.com/" and are immutable.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ update_request_timeout (float):
+ Optional. The timeout for the update request in seconds.
+
+ Returns:
+ Featurestore - The updated featurestore resource object.
+ """
+
+ return self._update(
+ labels=labels,
+ request_metadata=request_metadata,
+ update_request_timeout=update_request_timeout,
+ )
+
+ # TODO(b/206818784): Add enable_online_store and disable_online_store methods
+ def update_online_store(
+ self,
+ fixed_node_count: int,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ update_request_timeout: Optional[float] = None,
+ ) -> "Featurestore":
+ """Updates the online store of an existing managed featurestore resource.
+
+ Example Usage:
+
+ my_featurestore = aiplatform.Featurestore(
+ featurestore_name='my_featurestore_id',
+ )
+ my_featurestore.update_online_store(
+ fixed_node_count=2,
+ )
+
+ Args:
+ fixed_node_count (int):
+                Required. Config for online serving resources; the node count can only be updated to a value >= 1.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ update_request_timeout (float):
+ Optional. The timeout for the update request in seconds.
+
+ Returns:
+ Featurestore - The updated featurestore resource object.
+ """
+ return self._update(
+ fixed_node_count=fixed_node_count,
+ request_metadata=request_metadata,
+ update_request_timeout=update_request_timeout,
+ )
+
+ def _update(
+ self,
+ labels: Optional[Dict[str, str]] = None,
+ fixed_node_count: Optional[int] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ update_request_timeout: Optional[float] = None,
+ ) -> "Featurestore":
+ """Updates an existing managed featurestore resource.
+
+ Args:
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined
+ metadata to organize your Featurestores.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ on and examples of labels. No more than 64 user
+ labels can be associated with one Feature
+ (System labels are excluded)."
+ System reserved label keys are prefixed with
+ "aiplatform.googleapis.com/" and are immutable.
+ fixed_node_count (int):
+                Optional. Config for online serving resources; the node count can only be updated to a value >= 1.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ update_request_timeout (float):
+ Optional. The timeout for the update request in seconds.
+
+ Returns:
+ Featurestore - The updated featurestore resource object.
+ """
+ self.wait()
+ update_mask = list()
+
+ if labels:
+ utils.validate_labels(labels)
+ update_mask.append("labels")
+
+ if fixed_node_count is not None:
+ update_mask.append("online_serving_config.fixed_node_count")
+
+ update_mask = field_mask_pb2.FieldMask(paths=update_mask)
+
+ gapic_featurestore = gca_featurestore.Featurestore(
+ name=self.resource_name,
+ labels=labels,
+ online_serving_config=gca_featurestore.Featurestore.OnlineServingConfig(
+ fixed_node_count=fixed_node_count
+ ),
+ )
+
+ _LOGGER.log_action_start_against_resource(
+ "Updating",
+ "featurestore",
+ self,
+ )
+
+ update_featurestore_lro = self.api_client.update_featurestore(
+ featurestore=gapic_featurestore,
+ update_mask=update_mask,
+ metadata=request_metadata,
+ timeout=update_request_timeout,
+ )
+
+ _LOGGER.log_action_started_against_resource_with_lro(
+ "Update", "featurestore", self.__class__, update_featurestore_lro
+ )
+
+ update_featurestore_lro.result()
+
+ _LOGGER.log_action_completed_against_resource("featurestore", "updated", self)
+
+ return self
+
+ def list_entity_types(
+ self,
+ filter: Optional[str] = None,
+ order_by: Optional[str] = None,
+ ) -> List["featurestore.EntityType"]:
+ """Lists existing managed entityType resources in this Featurestore.
+
+ Example Usage:
+
+ my_featurestore = aiplatform.Featurestore(
+ featurestore_name='my_featurestore_id',
+ )
+ my_featurestore.list_entity_types()
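+            or, with a filter (expression is illustrative):
+            my_featurestore.list_entity_types(
+                filter='labels.env = prod',
+            )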
+
+ Args:
+ filter (str):
+ Optional. Lists the EntityTypes that match the filter expression. The
+ following filters are supported:
+
+ - ``create_time``: Supports ``=``, ``!=``, ``<``, ``>``,
+ ``>=``, and ``<=`` comparisons. Values must be in RFC
+ 3339 format.
+ - ``update_time``: Supports ``=``, ``!=``, ``<``, ``>``,
+ ``>=``, and ``<=`` comparisons. Values must be in RFC
+ 3339 format.
+ - ``labels``: Supports key-value equality as well as key
+ presence.
+
+ Examples:
+
+ - ``create_time > \"2020-01-31T15:30:00.000000Z\" OR update_time > \"2020-01-31T15:30:00.000000Z\"``
+ --> EntityTypes created or updated after
+ 2020-01-31T15:30:00.000000Z.
+ - ``labels.active = yes AND labels.env = prod`` -->
+ EntityTypes having both (active: yes) and (env: prod)
+ labels.
+ - ``labels.env: *`` --> Any EntityType which has a label
+ with 'env' as the key.
+ order_by (str):
+ Optional. A comma-separated list of fields to order by, sorted in
+ ascending order. Use "desc" after a field name for
+ descending.
+
+ Supported fields:
+
+ - ``entity_type_id``
+ - ``create_time``
+ - ``update_time``
+
+ Returns:
+ List[featurestore.EntityType] - A list of managed entityType resource objects.
+ """
+ self.wait()
+ return featurestore.EntityType.list(
+ featurestore_name=self.resource_name,
+ filter=filter,
+ order_by=order_by,
+ )
+
+ @base.optional_sync()
+ def delete_entity_types(
+ self,
+ entity_type_ids: List[str],
+ sync: bool = True,
+ force: bool = False,
+ ) -> None:
+ """Deletes entity_type resources in this Featurestore given their entity_type IDs.
+ WARNING: This deletion is permanent.
+
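+        Example Usage (IDs are illustrative):
+
+            my_featurestore.delete_entity_types(
+                entity_type_ids=['my_entity_type_id_1', 'my_entity_type_id_2'],
+                force=True,
+            )
+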
+ Args:
+ entity_type_ids (List[str]):
+ Required. The list of entity_type IDs to be deleted.
+ sync (bool):
+ Optional. Whether to execute this deletion synchronously. If False, this method
+                will be executed in a concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
+ force (bool):
+ Optional. If force is set to True, all features in each entityType
+ will be deleted prior to entityType deletion. Default is False.
+ """
+ entity_types = []
+ for entity_type_id in entity_type_ids:
+ entity_type = self._get_entity_type(entity_type_id=entity_type_id)
+ entity_type.delete(force=force, sync=False)
+ entity_types.append(entity_type)
+
+ for entity_type in entity_types:
+ entity_type.wait()
+
+ @base.optional_sync()
+ def delete(self, sync: bool = True, force: bool = False) -> None:
+ """Deletes this Featurestore resource. If force is set to True,
+ all entityTypes in this Featurestore will be deleted prior to featurestore deletion,
+ and all features in each entityType will be deleted prior to each entityType deletion.
+
+ WARNING: This deletion is permanent.
+
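+        Example Usage (force also deletes child EntityTypes and Features):
+
+            my_featurestore.delete(force=True)
+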
+ Args:
+ force (bool):
+ If set to true, any EntityTypes and
+ Features for this Featurestore will also
+ be deleted. (Otherwise, the request will
+ only work if the Featurestore has no
+ EntityTypes.)
+ sync (bool):
+ Whether to execute this deletion synchronously. If False, this method
+                will be executed in a concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
+ """
+ _LOGGER.log_action_start_against_resource("Deleting", "", self)
+ lro = getattr(self.api_client, self._delete_method)(
+ name=self.resource_name, force=force
+ )
+ _LOGGER.log_action_started_against_resource_with_lro(
+ "Delete", "", self.__class__, lro
+ )
+ lro.result()
+ _LOGGER.log_action_completed_against_resource("deleted.", "", self)
+
+ @classmethod
+ @base.optional_sync()
+ def create(
+ cls,
+ featurestore_id: str,
+ online_store_fixed_node_count: Optional[int] = None,
+ labels: Optional[Dict[str, str]] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ encryption_spec_key_name: Optional[str] = None,
+ sync: bool = True,
+ create_request_timeout: Optional[float] = None,
+ ) -> "Featurestore":
+ """Creates a Featurestore resource.
+
+ Example Usage:
+
+ my_featurestore = aiplatform.Featurestore.create(
+ featurestore_id='my_featurestore_id',
+ )
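+            or, with online serving provisioned (node count is illustrative):
+            my_featurestore = aiplatform.Featurestore.create(
+                featurestore_id='my_featurestore_id',
+                online_store_fixed_node_count=1,
+            )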
+
+ Args:
+ featurestore_id (str):
+ Required. The ID to use for this Featurestore, which will
+ become the final component of the Featurestore's resource
+ name.
+
+ This value may be up to 60 characters, and valid characters
+ are ``[a-z0-9_]``. The first character cannot be a number.
+
+ The value must be unique within the project and location.
+ online_store_fixed_node_count (int):
+ Optional. Config for online serving resources.
+                When not specified, no fixed node count is set for online serving. The
+ number of nodes will not scale automatically but
+ can be scaled manually by providing different
+ values when updating.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined
+ metadata to organize your Featurestore.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ on and examples of labels. No more than 64 user
+ labels can be associated with one
+ Featurestore(System labels are excluded)."
+ System reserved label keys are prefixed with
+ "aiplatform.googleapis.com/" and are immutable.
+            project (str):
+                Optional. Project to create Featurestore in. If not set, project
+                set in aiplatform.init will be used.
+            location (str):
+                Optional. Location to create Featurestore in. If not set, location
+                set in aiplatform.init will be used.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to create this Featurestore. Overrides
+                credentials set in aiplatform.init.
+            request_metadata (Sequence[Tuple[str, str]]):
+                Optional. Strings which should be sent along with the request as metadata.
+            encryption_spec_key_name (str):
+ Optional. Customer-managed encryption key
+ spec for data storage. If set, both of the
+ online and offline data storage will be secured
+ by this key.
+ sync (bool):
+ Optional. Whether to execute this creation synchronously. If False, this method
+                will be executed in a concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+
+ Returns:
+ Featurestore - Featurestore resource object
+
+ """
+ gapic_featurestore = gca_featurestore.Featurestore(
+ online_serving_config=gca_featurestore.Featurestore.OnlineServingConfig(
+ fixed_node_count=online_store_fixed_node_count
+ )
+ )
+
+ if labels:
+ utils.validate_labels(labels)
+ gapic_featurestore.labels = labels
+
+ if encryption_spec_key_name:
+ gapic_featurestore.encryption_spec = (
+ initializer.global_config.get_encryption_spec(
+ encryption_spec_key_name=encryption_spec_key_name
+ )
+ )
+
+ api_client = cls._instantiate_client(location=location, credentials=credentials)
+
+ created_featurestore_lro = api_client.create_featurestore(
+ parent=initializer.global_config.common_location_path(
+ project=project, location=location
+ ),
+ featurestore=gapic_featurestore,
+ featurestore_id=featurestore_id,
+ metadata=request_metadata,
+ timeout=create_request_timeout,
+ )
+
+ _LOGGER.log_create_with_lro(cls, created_featurestore_lro)
+
+ created_featurestore = created_featurestore_lro.result()
+
+ _LOGGER.log_create_complete(cls, created_featurestore, "featurestore")
+
+ featurestore_obj = cls(
+ featurestore_name=created_featurestore.name,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+ return featurestore_obj
+
+ def create_entity_type(
+ self,
+ entity_type_id: str,
+ description: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ sync: bool = True,
+ create_request_timeout: Optional[float] = None,
+ ) -> "featurestore.EntityType":
+ """Creates an EntityType resource in this Featurestore.
+
+ Example Usage:
+
+ my_featurestore = aiplatform.Featurestore.create(
+ featurestore_id='my_featurestore_id'
+ )
+ my_entity_type = my_featurestore.create_entity_type(
+ entity_type_id='my_entity_type_id',
+ )
+
+ Args:
+ entity_type_id (str):
+ Required. The ID to use for the EntityType, which will
+ become the final component of the EntityType's resource
+ name.
+
+ This value may be up to 60 characters, and valid characters
+ are ``[a-z0-9_]``. The first character cannot be a number.
+
+ The value must be unique within a featurestore.
+ description (str):
+ Optional. Description of the EntityType.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined
+ metadata to organize your EntityTypes.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ on and examples of labels. No more than 64 user
+ labels can be associated with one EntityType
+ (System labels are excluded)."
+ System reserved label keys are prefixed with
+ "aiplatform.googleapis.com/" and are immutable.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+ sync (bool):
+ Optional. Whether to execute this creation synchronously. If False, this method
+                will be executed in a concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
+
+ Returns:
+ featurestore.EntityType - EntityType resource object
+
+ """
+ self.wait()
+ return featurestore.EntityType.create(
+ entity_type_id=entity_type_id,
+ featurestore_name=self.resource_name,
+ description=description,
+ labels=labels,
+ request_metadata=request_metadata,
+ sync=sync,
+ create_request_timeout=create_request_timeout,
+ )
+
+ def _batch_read_feature_values(
+ self,
+ batch_read_feature_values_request: gca_featurestore_service.BatchReadFeatureValuesRequest,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ serve_request_timeout: Optional[float] = None,
+ ) -> "Featurestore":
+ """Batch read Feature values from the Featurestore to a destination storage.
+
+ Args:
+ batch_read_feature_values_request (gca_featurestore_service.BatchReadFeatureValuesRequest):
+ Required. Request of batch read feature values.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ serve_request_timeout (float):
+ Optional. The timeout for the serve request in seconds.
+
+ Returns:
+            Featurestore: The Featurestore resource object that feature values were batch read from.
+ """
+
+ _LOGGER.log_action_start_against_resource(
+ "Serving",
+ "feature values",
+ self,
+ )
+
+ batch_read_lro = self.api_client.batch_read_feature_values(
+ request=batch_read_feature_values_request,
+ metadata=request_metadata,
+ timeout=serve_request_timeout,
+ )
+
+ _LOGGER.log_action_started_against_resource_with_lro(
+ "Serve", "feature values", self.__class__, batch_read_lro
+ )
+
+ batch_read_lro.result()
+
+ _LOGGER.log_action_completed_against_resource("feature values", "served", self)
+
+ return self
+
+ @staticmethod
+ def _validate_and_get_read_instances(
+ read_instances_uri: str,
+ ) -> Union[gca_io.BigQuerySource, gca_io.CsvSource]:
+ """Gets read_instances
+
+ Args:
+ read_instances_uri (str):
+                Required. The read_instances_uri can be either a BigQuery URI to an input table
+                or a Google Cloud Storage URI to a CSV file.
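+
+                Example:
+                    'bq://project.dataset.table_name'
+                    or
+                    "gs://my_bucket/my_file.csv"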
+
+ Returns:
+ Union[gca_io.BigQuerySource, gca_io.CsvSource]:
+ BigQuery source or Csv source for read instances. The Csv source contains exactly 1 URI.
+
+ Raises:
+            ValueError: If read_instances_uri does not start with 'bq://' or 'gs://'.
+ """
+ if not (
+ read_instances_uri.startswith("bq://")
+ or read_instances_uri.startswith("gs://")
+ ):
+ raise ValueError(
+ "The read_instances_uri should be a single uri starts with either 'bq://' or 'gs://'."
+ )
+
+ if read_instances_uri.startswith("bq://"):
+ return gca_io.BigQuerySource(input_uri=read_instances_uri)
+ if read_instances_uri.startswith("gs://"):
+ return gca_io.CsvSource(
+ gcs_source=gca_io.GcsSource(uris=[read_instances_uri])
+ )
+
+ def _validate_and_get_batch_read_feature_values_request(
+ self,
+ featurestore_name: str,
+ serving_feature_ids: Dict[str, List[str]],
+ destination: Union[
+ gca_io.BigQueryDestination,
+ gca_io.CsvDestination,
+ gca_io.TFRecordDestination,
+ ],
+ read_instances: Union[gca_io.BigQuerySource, gca_io.CsvSource],
+ pass_through_fields: Optional[List[str]] = None,
+ feature_destination_fields: Optional[Dict[str, str]] = None,
+ ) -> gca_featurestore_service.BatchReadFeatureValuesRequest:
+ """Validates and gets batch_read_feature_values_request
+
+ Args:
+ featurestore_name (str):
+ Required. A fully-qualified featurestore resource name.
+ serving_feature_ids (Dict[str, List[str]]):
+                Required. A user-defined dictionary to define the entity_types and their features for batch serve/read.
+ The keys of the dictionary are the serving entity_type ids and
+ the values are lists of serving feature ids in each entity_type.
+
+ Example:
+ serving_feature_ids = {
+ 'my_entity_type_id_1': ['feature_id_1_1', 'feature_id_1_2'],
+ 'my_entity_type_id_2': ['feature_id_2_1', 'feature_id_2_2'],
+ }
+
+ destination (Union[gca_io.BigQueryDestination, gca_io.CsvDestination, gca_io.TFRecordDestination]):
+ Required. BigQuery destination, Csv destination or TFRecord destination.
+ read_instances (Union[gca_io.BigQuerySource, gca_io.CsvSource]):
+ Required. BigQuery source or Csv source for read instances.
+ The Csv source must contain exactly 1 URI.
+ pass_through_fields (List[str]):
+ Optional. When not empty, the specified fields in the
+ read_instances source will be joined as-is in the output,
+ in addition to those fields from the Featurestore Entity.
+
+ For BigQuery source, the type of the pass-through values
+ will be automatically inferred. For CSV source, the
+ pass-through values will be passed as opaque bytes.
+ feature_destination_fields (Dict[str, str]):
+                Optional. A user-defined dictionary to map a feature's fully qualified resource name to
+ its destination field name. If the destination field name is not defined,
+ the feature ID will be used as its destination field name.
+
+ Example:
+ feature_destination_fields = {
+ 'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id1/features/f_id11': 'foo',
+ 'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id2/features/f_id22': 'bar',
+ }
+
+ Returns:
+ gca_featurestore_service.BatchReadFeatureValuesRequest: batch read feature values request
+ """
+ featurestore_name_components = self._parse_resource_name(featurestore_name)
+
+ feature_destination_fields = feature_destination_fields or {}
+
+ entity_type_specs = []
+ for entity_type_id, feature_ids in serving_feature_ids.items():
+ destination_feature_settings = []
+ for feature_id in feature_ids:
+ feature_resource_name = featurestore.Feature._format_resource_name(
+ project=featurestore_name_components["project"],
+ location=featurestore_name_components["location"],
+ featurestore=featurestore_name_components["featurestore"],
+ entity_type=entity_type_id,
+ feature=feature_id,
+ )
+
+ feature_destination_field = feature_destination_fields.get(
+ feature_resource_name
+ )
+ if feature_destination_field:
+ destination_feature_setting_proto = (
+ gca_featurestore_service.DestinationFeatureSetting(
+ feature_id=feature_id,
+ destination_field=feature_destination_field,
+ )
+ )
+ destination_feature_settings.append(
+ destination_feature_setting_proto
+ )
+
+ entity_type_spec = (
+ gca_featurestore_service.BatchReadFeatureValuesRequest.EntityTypeSpec(
+ entity_type_id=entity_type_id,
+ feature_selector=gca_feature_selector.FeatureSelector(
+ id_matcher=gca_feature_selector.IdMatcher(ids=feature_ids)
+ ),
+ settings=destination_feature_settings or None,
+ )
+ )
+ entity_type_specs.append(entity_type_spec)
+
+ batch_read_feature_values_request = (
+ gca_featurestore_service.BatchReadFeatureValuesRequest(
+ featurestore=featurestore_name,
+ entity_type_specs=entity_type_specs,
+ )
+ )
+
+ if isinstance(destination, gca_io.BigQueryDestination):
+ batch_read_feature_values_request.destination = (
+ gca_featurestore_service.FeatureValueDestination(
+ bigquery_destination=destination
+ )
+ )
+ elif isinstance(destination, gca_io.CsvDestination):
+ batch_read_feature_values_request.destination = (
+ gca_featurestore_service.FeatureValueDestination(
+ csv_destination=destination
+ )
+ )
+ elif isinstance(destination, gca_io.TFRecordDestination):
+ batch_read_feature_values_request.destination = (
+ gca_featurestore_service.FeatureValueDestination(
+ tfrecord_destination=destination
+ )
+ )
+
+ if isinstance(read_instances, gca_io.BigQuerySource):
+ batch_read_feature_values_request.bigquery_read_instances = read_instances
+ elif isinstance(read_instances, gca_io.CsvSource):
+ batch_read_feature_values_request.csv_read_instances = read_instances
+
+ if pass_through_fields is not None:
+ batch_read_feature_values_request.pass_through_fields = [
+ gca_featurestore_service.BatchReadFeatureValuesRequest.PassThroughField(
+ field_name=pass_through_field
+ )
+ for pass_through_field in pass_through_fields
+ ]
+
+ return batch_read_feature_values_request
+
+ @base.optional_sync(return_input_arg="self")
+ def batch_serve_to_bq(
+ self,
+ bq_destination_output_uri: str,
+ serving_feature_ids: Dict[str, List[str]],
+ read_instances_uri: str,
+ pass_through_fields: Optional[List[str]] = None,
+ feature_destination_fields: Optional[Dict[str, str]] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ serve_request_timeout: Optional[float] = None,
+ sync: bool = True,
+ ) -> "Featurestore":
+ """Batch serves feature values to BigQuery destination
+
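+        Example Usage (an illustrative sketch; the resource IDs, feature IDs,
+        and URIs below are placeholders):
+
+            my_featurestore = aiplatform.Featurestore.create(
+                featurestore_id='my_featurestore_id'
+            )
+            my_featurestore.batch_serve_to_bq(
+                bq_destination_output_uri='bq://project.dataset.table_name',
+                serving_feature_ids={
+                    'my_entity_type_id_1': ['feature_id_1_1', 'feature_id_1_2'],
+                },
+                read_instances_uri='bq://project.dataset.read_instances_table',
+            )
+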
+ Args:
+ bq_destination_output_uri (str):
+                Required. BigQuery URI to the destination table.
+
+ Example:
+ 'bq://project.dataset.table_name'
+
+ It requires an existing BigQuery destination Dataset, under the same project as the Featurestore.
+
+ serving_feature_ids (Dict[str, List[str]]):
+                Required. A user-defined dictionary to define the entity_types and their features for batch serve/read.
+ The keys of the dictionary are the serving entity_type ids and
+ the values are lists of serving feature ids in each entity_type.
+
+ Example:
+ serving_feature_ids = {
+ 'my_entity_type_id_1': ['feature_id_1_1', 'feature_id_1_2'],
+ 'my_entity_type_id_2': ['feature_id_2_1', 'feature_id_2_2'],
+ }
+
+ read_instances_uri (str):
+                Required. The read_instances_uri can be either a BigQuery URI to an input table
+                or a Google Cloud Storage URI to a CSV file.
+
+ Example:
+ 'bq://project.dataset.table_name'
+ or
+ "gs://my_bucket/my_file.csv"
+
+ Each read instance should consist of exactly one read timestamp
+ and one or more entity IDs identifying entities of the
+ corresponding EntityTypes whose Features are requested.
+
+ Each output instance contains Feature values of requested
+ entities concatenated together as of the read time.
+
+ An example read instance may be
+ ``foo_entity_id, bar_entity_id, 2020-01-01T10:00:00.123Z``.
+
+ An example output instance may be
+ ``foo_entity_id, bar_entity_id, 2020-01-01T10:00:00.123Z, foo_entity_feature1_value, bar_entity_feature2_value``.
+
+ Timestamp in each read instance must be millisecond-aligned.
+
+ The columns can be in any order.
+
+ Values in the timestamp column must use the RFC 3339 format,
+ e.g. ``2012-07-30T10:43:17.123Z``.
+
+ pass_through_fields (List[str]):
+ Optional. When not empty, the specified fields in the
+ read_instances source will be joined as-is in the output,
+ in addition to those fields from the Featurestore Entity.
+
+ For BigQuery source, the type of the pass-through values
+ will be automatically inferred. For CSV source, the
+ pass-through values will be passed as opaque bytes.
+
+ feature_destination_fields (Dict[str, str]):
+                Optional. A user-defined dictionary to map a feature's fully qualified resource name to
+ its destination field name. If the destination field name is not defined,
+ the feature ID will be used as its destination field name.
+
+ Example:
+ feature_destination_fields = {
+ 'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id1/features/f_id11': 'foo',
+ 'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id2/features/f_id22': 'bar',
+ }
+ serve_request_timeout (float):
+ Optional. The timeout for the serve request in seconds.
+
+        Returns:
+            Featurestore: The Featurestore resource object that feature values were batch read from.
+
+ Raises:
+ NotFound: if the BigQuery destination Dataset does not exist.
+ FailedPrecondition: if the BigQuery destination Dataset/Table is in a different project.
+ """
+ read_instances = self._validate_and_get_read_instances(read_instances_uri)
+
+ batch_read_feature_values_request = (
+ self._validate_and_get_batch_read_feature_values_request(
+ featurestore_name=self.resource_name,
+ serving_feature_ids=serving_feature_ids,
+ destination=gca_io.BigQueryDestination(
+ output_uri=bq_destination_output_uri
+ ),
+ feature_destination_fields=feature_destination_fields,
+ read_instances=read_instances,
+ pass_through_fields=pass_through_fields,
+ )
+ )
+
+ return self._batch_read_feature_values(
+ batch_read_feature_values_request=batch_read_feature_values_request,
+ request_metadata=request_metadata,
+ serve_request_timeout=serve_request_timeout,
+ )
+
+ @base.optional_sync(return_input_arg="self")
+ def batch_serve_to_gcs(
+ self,
+ gcs_destination_output_uri_prefix: str,
+ gcs_destination_type: str,
+ serving_feature_ids: Dict[str, List[str]],
+ read_instances_uri: str,
+ pass_through_fields: Optional[List[str]] = None,
+ feature_destination_fields: Optional[Dict[str, str]] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ sync: bool = True,
+ serve_request_timeout: Optional[float] = None,
+ ) -> "Featurestore":
+ """Batch serves feature values to GCS destination
+
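+        Example Usage (an illustrative sketch; the bucket, prefix, and
+        feature IDs below are placeholders):
+
+            my_featurestore.batch_serve_to_gcs(
+                gcs_destination_output_uri_prefix='gs://bucket/path/to/prefix',
+                gcs_destination_type='csv',
+                serving_feature_ids={
+                    'my_entity_type_id_1': ['feature_id_1_1', 'feature_id_1_2'],
+                },
+                read_instances_uri='gs://my_bucket/my_file.csv',
+            )
+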
+ Args:
+ gcs_destination_output_uri_prefix (str):
+ Required. Google Cloud Storage URI to output
+                directory. If the URI doesn't end with '/', a
+ '/' will be automatically appended. The
+ directory is created if it doesn't exist.
+
+ Example:
+ "gs://bucket/path/to/prefix"
+
+ gcs_destination_type (str):
+                Required. The type of the destination file(s);
+                the value of gcs_destination_type can only be either `csv` or `tfrecord`.
+
+                For the CSV format: Array Feature value types are not allowed.
+
+                For the TFRecord format, below is the mapping from Feature
+                value type in Featurestore to Feature value type in TFRecord:
+
+ ::
+
+                    Value type in Featurestore     | Value type in TFRecord
+                    DOUBLE, DOUBLE_ARRAY           | FLOAT_LIST
+                    INT64, INT64_ARRAY             | INT64_LIST
+                    STRING, STRING_ARRAY, BYTES    | BYTES_LIST
+                    BOOL, BOOL_ARRAY (true, false) | BYTES_LIST
+                        (true -> byte_string("true"), false -> byte_string("false"))
+
+ serving_feature_ids (Dict[str, List[str]]):
+                Required. A user-defined dictionary to define the entity_types and their features for batch serve/read.
+ The keys of the dictionary are the serving entity_type ids and
+ the values are lists of serving feature ids in each entity_type.
+
+ Example:
+ serving_feature_ids = {
+ 'my_entity_type_id_1': ['feature_id_1_1', 'feature_id_1_2'],
+ 'my_entity_type_id_2': ['feature_id_2_1', 'feature_id_2_2'],
+ }
+
+ read_instances_uri (str):
+                Required. The read_instances_uri can be either a BigQuery URI to an input table
+                or a Google Cloud Storage URI to a CSV file.
+
+ Example:
+ 'bq://project.dataset.table_name'
+ or
+ "gs://my_bucket/my_file.csv"
+
+ Each read instance should consist of exactly one read timestamp
+ and one or more entity IDs identifying entities of the
+ corresponding EntityTypes whose Features are requested.
+
+ Each output instance contains Feature values of requested
+ entities concatenated together as of the read time.
+
+ An example read instance may be
+ ``foo_entity_id, bar_entity_id, 2020-01-01T10:00:00.123Z``.
+
+ An example output instance may be
+ ``foo_entity_id, bar_entity_id, 2020-01-01T10:00:00.123Z, foo_entity_feature1_value, bar_entity_feature2_value``.
+
+ Timestamp in each read instance must be millisecond-aligned.
+
+ The columns can be in any order.
+
+ Values in the timestamp column must use the RFC 3339 format,
+ e.g. ``2012-07-30T10:43:17.123Z``.
+
+ pass_through_fields (List[str]):
+ Optional. When not empty, the specified fields in the
+ read_instances source will be joined as-is in the output,
+ in addition to those fields from the Featurestore Entity.
+
+ For BigQuery source, the type of the pass-through values
+ will be automatically inferred. For CSV source, the
+ pass-through values will be passed as opaque bytes.
+
+ feature_destination_fields (Dict[str, str]):
+                Optional. A user-defined dictionary to map a feature's fully qualified resource name to
+ its destination field name. If the destination field name is not defined,
+ the feature ID will be used as its destination field name.
+
+ Example:
+ feature_destination_fields = {
+ 'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id1/features/f_id11': 'foo',
+ 'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id2/features/f_id22': 'bar',
+ }
+ serve_request_timeout (float):
+ Optional. The timeout for the serve request in seconds.
+
+ Returns:
+            Featurestore: The Featurestore resource object that feature values were batch read from.
+
+ Raises:
+            ValueError: If gcs_destination_type is not supported.
+
+ """
+ destination = None
+ if gcs_destination_type not in featurestore_utils.GCS_DESTINATION_TYPE:
+ raise ValueError(
+ "Only %s are supported gcs_destination_type, not `%s`. "
+ % (
+ "`" + "`, `".join(featurestore_utils.GCS_DESTINATION_TYPE) + "`",
+ gcs_destination_type,
+ )
+ )
+
+ gcs_destination = gca_io.GcsDestination(
+ output_uri_prefix=gcs_destination_output_uri_prefix
+ )
+ if gcs_destination_type == "csv":
+ destination = gca_io.CsvDestination(gcs_destination=gcs_destination)
+ if gcs_destination_type == "tfrecord":
+ destination = gca_io.TFRecordDestination(gcs_destination=gcs_destination)
+
+ read_instances = self._validate_and_get_read_instances(read_instances_uri)
+
+ batch_read_feature_values_request = (
+ self._validate_and_get_batch_read_feature_values_request(
+ featurestore_name=self.resource_name,
+ serving_feature_ids=serving_feature_ids,
+ destination=destination,
+ feature_destination_fields=feature_destination_fields,
+ read_instances=read_instances,
+ pass_through_fields=pass_through_fields,
+ )
+ )
+
+ return self._batch_read_feature_values(
+ batch_read_feature_values_request=batch_read_feature_values_request,
+ request_metadata=request_metadata,
+ serve_request_timeout=serve_request_timeout,
+ )
+
+ def batch_serve_to_df(
+ self,
+ serving_feature_ids: Dict[str, List[str]],
+ read_instances_df: "pd.DataFrame", # noqa: F821 - skip check for undefined name 'pd'
+ pass_through_fields: Optional[List[str]] = None,
+ feature_destination_fields: Optional[Dict[str, str]] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ serve_request_timeout: Optional[float] = None,
+ ) -> "pd.DataFrame": # noqa: F821 - skip check for undefined name 'pd'
+ """Batch serves feature values to pandas DataFrame
+
+ Note:
+            Calling this method will automatically create and delete a temporary
+            BigQuery dataset in the same GCP project, which is used
+            as the intermediary storage for batch serving feature values
+            from the Featurestore to the DataFrame.
+
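+        Example Usage (an illustrative sketch; the feature IDs are placeholders
+        and `read_instances_df` is a pandas DataFrame as described below):
+
+            batch_serve_output_df = my_featurestore.batch_serve_to_df(
+                serving_feature_ids={
+                    'my_entity_type_id_1': ['feature_id_1_1', 'feature_id_1_2'],
+                },
+                read_instances_df=read_instances_df,
+            )
+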
+ Args:
+ serving_feature_ids (Dict[str, List[str]]):
+                Required. A user-defined dictionary to define the entity_types and their features for batch serve/read.
+ The keys of the dictionary are the serving entity_type ids and
+ the values are lists of serving feature ids in each entity_type.
+
+ Example:
+ serving_feature_ids = {
+ 'my_entity_type_id_1': ['feature_id_1_1', 'feature_id_1_2'],
+ 'my_entity_type_id_2': ['feature_id_2_1', 'feature_id_2_2'],
+ }
+
+ read_instances_df (pd.DataFrame):
+ Required. Read_instances_df is a pandas DataFrame containing the read instances.
+
+ Each read instance should consist of exactly one read timestamp
+ and one or more entity IDs identifying entities of the
+ corresponding EntityTypes whose Features are requested.
+
+ Each output instance contains Feature values of requested
+ entities concatenated together as of the read time.
+
+ An example read_instances_df may be
+ pd.DataFrame(
+ data=[
+ {
+ "my_entity_type_id_1": "my_entity_type_id_1_entity_1",
+ "my_entity_type_id_2": "my_entity_type_id_2_entity_1",
+ "timestamp": "2020-01-01T10:00:00.123Z"
+                        },
+                    ],
+                )
+
+ An example batch_serve_output_df may be
+ pd.DataFrame(
+ data=[
+ {
+ "my_entity_type_id_1": "my_entity_type_id_1_entity_1",
+ "my_entity_type_id_2": "my_entity_type_id_2_entity_1",
+ "foo": "feature_id_1_1_feature_value",
+ "feature_id_1_2": "feature_id_1_2_feature_value",
+ "feature_id_2_1": "feature_id_2_1_feature_value",
+ "bar": "feature_id_2_2_feature_value",
+ "timestamp": "2020-01-01T10:00:00.123Z"
+                        },
+                    ],
+                )
+
+ Timestamp in each read instance must be millisecond-aligned.
+
+ The columns can be in any order.
+
+ Values in the timestamp column must use the RFC 3339 format,
+ e.g. ``2012-07-30T10:43:17.123Z``.
+
+ pass_through_fields (List[str]):
+ Optional. When not empty, the specified fields in the
+ read_instances source will be joined as-is in the output,
+ in addition to those fields from the Featurestore Entity.
+
+ For BigQuery source, the type of the pass-through values
+ will be automatically inferred. For CSV source, the
+ pass-through values will be passed as opaque bytes.
+
+ feature_destination_fields (Dict[str, str]):
+                Optional. A user-defined dictionary to map a feature's fully qualified resource name to
+ its destination field name. If the destination field name is not defined,
+ the feature ID will be used as its destination field name.
+
+ Example:
+ feature_destination_fields = {
+ 'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id1/features/f_id11': 'foo',
+ 'projects/123/locations/us-central1/featurestores/fs_id/entityTypes/et_id2/features/f_id22': 'bar',
+ }
+ serve_request_timeout (float):
+ Optional. The timeout for the serve request in seconds.
+
+ Returns:
+ pd.DataFrame: The pandas DataFrame containing feature values from batch serving.
+
+ """
+ try:
+ from google.cloud import bigquery_storage
+ except ImportError:
+ raise ImportError(
+ f"Google-Cloud-Bigquery-Storage is not installed. Please install google-cloud-bigquery-storage to use "
+ f"{self.batch_serve_to_df.__name__}"
+ )
+
+ try:
+ import pyarrow # noqa: F401 - skip check for 'pyarrow' which is required when using 'google.cloud.bigquery'
+ except ImportError:
+ raise ImportError(
+ f"Pyarrow is not installed. Please install pyarrow to use "
+ f"{self.batch_serve_to_df.__name__}"
+ )
+
+ try:
+ import pandas as pd
+ except ImportError:
+ raise ImportError(
+ f"Pandas is not installed. Please install pandas to use "
+ f"{self.batch_serve_to_df.__name__}"
+ )
+
+ bigquery_client = bigquery.Client(
+ project=self.project, credentials=self.credentials
+ )
+
+ self.wait()
+ featurestore_name_components = self._parse_resource_name(self.resource_name)
+ featurestore_id = featurestore_name_components["featurestore"]
+
+ temp_bq_dataset_name = f"temp_{featurestore_id}_{uuid.uuid4()}".replace(
+ "-", "_"
+ )
+
+ project_id = resource_manager_utils.get_project_id(
+ project_number=featurestore_name_components["project"],
+ credentials=self.credentials,
+ )
+ temp_bq_dataset_id = f"{project_id}.{temp_bq_dataset_name}"[:1024]
+ temp_bq_dataset = bigquery.Dataset(dataset_ref=temp_bq_dataset_id)
+ temp_bq_dataset.location = self.location
+ temp_bq_dataset = bigquery_client.create_dataset(temp_bq_dataset)
+
+ temp_bq_batch_serve_table_name = "batch_serve"
+ temp_bq_read_instances_table_name = "read_instances"
+ temp_bq_batch_serve_table_id = (
+ f"{temp_bq_dataset_id}.{temp_bq_batch_serve_table_name}"
+ )
+ temp_bq_read_instances_table_id = (
+ f"{temp_bq_dataset_id}.{temp_bq_read_instances_table_name}"
+ )
+
+ try:
+
+ job = bigquery_client.load_table_from_dataframe(
+ dataframe=read_instances_df, destination=temp_bq_read_instances_table_id
+ )
+ job.result()
+
+ self.batch_serve_to_bq(
+ bq_destination_output_uri=f"bq://{temp_bq_batch_serve_table_id}",
+ serving_feature_ids=serving_feature_ids,
+ read_instances_uri=f"bq://{temp_bq_read_instances_table_id}",
+ pass_through_fields=pass_through_fields,
+ feature_destination_fields=feature_destination_fields,
+ request_metadata=request_metadata,
+ serve_request_timeout=serve_request_timeout,
+ )
+
+ bigquery_storage_read_client = bigquery_storage.BigQueryReadClient(
+ credentials=self.credentials
+ )
+ read_session_proto = bigquery_storage_read_client.create_read_session(
+ parent=f"projects/{self.project}",
+ read_session=bigquery_storage.types.ReadSession(
+ table="projects/{project}/datasets/{dataset}/tables/{table}".format(
+ project=self.project,
+ dataset=temp_bq_dataset_name,
+ table=temp_bq_batch_serve_table_name,
+ ),
+ data_format=bigquery_storage.types.DataFormat.ARROW,
+ ),
+ )
+
+ frames = []
+ for stream in read_session_proto.streams:
+ reader = bigquery_storage_read_client.read_rows(stream.name)
+ for message in reader.rows().pages:
+ frames.append(message.to_dataframe())
+
+ finally:
+ bigquery_client.delete_dataset(
+ dataset=temp_bq_dataset.dataset_id,
+ delete_contents=True,
+ )
+
+ return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame(frames)
diff --git a/google/cloud/aiplatform/gapic/schema/__init__.py b/google/cloud/aiplatform/gapic/schema/__init__.py
index e726749c77..5d31a70f1f 100644
--- a/google/cloud/aiplatform/gapic/schema/__init__.py
+++ b/google/cloud/aiplatform/gapic/schema/__init__.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from google.cloud.aiplatform.helpers import _decorators
+from google.cloud.aiplatform.utils.enhanced_library import _decorators
from google.cloud.aiplatform.v1.schema import predict
from google.cloud.aiplatform.v1.schema import trainingjob
from google.cloud.aiplatform.v1beta1.schema import predict as predict_v1beta1
diff --git a/google/cloud/aiplatform/helpers/__init__.py b/google/cloud/aiplatform/helpers/__init__.py
index 3f031f2bb4..e5fa8f665d 100644
--- a/google/cloud/aiplatform/helpers/__init__.py
+++ b/google/cloud/aiplatform/helpers/__init__.py
@@ -1,3 +1,21 @@
-from google.cloud.aiplatform.helpers import value_converter
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
-__all__ = (value_converter,)
+from google.cloud.aiplatform.helpers import container_uri_builders
+
+get_prebuilt_prediction_container_uri = (
+ container_uri_builders.get_prebuilt_prediction_container_uri
+)
+
+__all__ = "get_prebuilt_prediction_container_uri"
diff --git a/google/cloud/aiplatform/helpers/container_uri_builders.py b/google/cloud/aiplatform/helpers/container_uri_builders.py
new file mode 100644
index 0000000000..6b49d3e230
--- /dev/null
+++ b/google/cloud/aiplatform/helpers/container_uri_builders.py
@@ -0,0 +1,109 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from google.cloud.aiplatform.constants import prediction
+from google.cloud.aiplatform import initializer
+
+
+def get_prebuilt_prediction_container_uri(
+ framework: str,
+ framework_version: str,
+ region: Optional[str] = None,
+ accelerator: str = "cpu",
+) -> str:
+ """
+ Get a Vertex AI pre-built prediction Docker container URI for
+ a given framework, version, region, and accelerator use.
+
+ Example usage:
+ ```
+ uri = aiplatform.helpers.get_prebuilt_prediction_container_uri(
+ framework="tensorflow",
+ framework_version="2.6",
+ accelerator="gpu"
+ )
+
+ model = aiplatform.Model.upload(
+ display_name="boston_housing_",
+ artifact_uri="gs://my-bucket/my-model/",
+ serving_container_image_uri=uri
+ )
+ ```
+
+ Args:
+ framework (str):
+ Required. The ML framework of the pre-built container. For example,
+ `"tensorflow"`, `"xgboost"`, or `"sklearn"`
+ framework_version (str):
+ Required. The version of the specified ML framework as a string.
+ region (str):
+            Optional. Vertex AI region or multi-region. Used to select the correct
+ Artifact Registry multi-region repository and reduce latency.
+ Must start with `"us"`, `"asia"` or `"europe"`.
+ Default is location set by `aiplatform.init()`.
+ accelerator (str):
+ Optional. The type of accelerator support provided by container. For
+ example: `"cpu"` or `"gpu"`
+ Default is `"cpu"`.
+
+ Returns:
+ uri (str):
+ A Vertex AI prediction container URI
+
+ Raises:
+        ValueError: If containers for the provided framework are unavailable or the
+ container does not support the specified version, accelerator, or region.
+ """
+ URI_MAP = prediction._SERVING_CONTAINER_URI_MAP
+ DOCS_URI_MESSAGE = (
+ f"See {prediction._SERVING_CONTAINER_DOCUMENTATION_URL} "
+ "for complete list of supported containers"
+ )
+
+ # If region not provided, use initializer location
+ region = region or initializer.global_config.location
+ region = region.split("-", 1)[0]
+ framework = framework.lower()
+
+ if not URI_MAP.get(region):
+ raise ValueError(
+ f"Unsupported container region `{region}`, supported regions are "
+ f"{', '.join(URI_MAP.keys())}. "
+ f"{DOCS_URI_MESSAGE}"
+ )
+
+ if not URI_MAP[region].get(framework):
+ raise ValueError(
+ f"No containers found for framework `{framework}`. Supported frameworks are "
+ f"{', '.join(URI_MAP[region].keys())} {DOCS_URI_MESSAGE}"
+ )
+
+ if not URI_MAP[region][framework].get(accelerator):
+ raise ValueError(
+ f"{framework} containers do not support `{accelerator}` accelerator. Supported accelerators "
+ f"are {', '.join(URI_MAP[region][framework].keys())}. {DOCS_URI_MESSAGE}"
+ )
+
+ final_uri = URI_MAP[region][framework][accelerator].get(framework_version)
+
+ if not final_uri:
+ raise ValueError(
+ f"No serving container for `{framework}` version `{framework_version}` "
+ f"with accelerator `{accelerator}` found. Supported versions "
+ f"include {', '.join(URI_MAP[region][framework][accelerator].keys())}. {DOCS_URI_MESSAGE}"
+ )
+
+ return final_uri
diff --git a/google/cloud/aiplatform/hyperparameter_tuning.py b/google/cloud/aiplatform/hyperparameter_tuning.py
index a7a0e641cd..a43f1c39fd 100644
--- a/google/cloud/aiplatform/hyperparameter_tuning.py
+++ b/google/cloud/aiplatform/hyperparameter_tuning.py
@@ -101,7 +101,10 @@ class DoubleParameterSpec(_ParameterSpec):
_parameter_spec_value_key = "double_value_spec"
def __init__(
- self, min: float, max: float, scale: str,
+ self,
+ min: float,
+ max: float,
+ scale: str,
):
"""
Value specification for a parameter in ``DOUBLE`` type.
@@ -135,7 +138,10 @@ class IntegerParameterSpec(_ParameterSpec):
_parameter_spec_value_key = "integer_value_spec"
def __init__(
- self, min: int, max: int, scale: str,
+ self,
+ min: int,
+ max: int,
+ scale: str,
):
"""
Value specification for a parameter in ``INTEGER`` type.
@@ -169,7 +175,8 @@ class CategoricalParameterSpec(_ParameterSpec):
_parameter_spec_value_key = "categorical_value_spec"
def __init__(
- self, values: Sequence[str],
+ self,
+ values: Sequence[str],
):
"""Value specification for a parameter in ``CATEGORICAL`` type.
@@ -192,7 +199,9 @@ class DiscreteParameterSpec(_ParameterSpec):
_parameter_spec_value_key = "discrete_value_spec"
def __init__(
- self, values: Sequence[float], scale: str,
+ self,
+ values: Sequence[float],
+ scale: str,
):
"""Value specification for a parameter in ``DISCRETE`` type.
diff --git a/google/cloud/aiplatform/initializer.py b/google/cloud/aiplatform/initializer.py
index 18341bde46..9f0afd9e70 100644
--- a/google/cloud/aiplatform/initializer.py
+++ b/google/cloud/aiplatform/initializer.py
@@ -29,9 +29,11 @@
from google.auth.exceptions import GoogleAuthError
from google.cloud.aiplatform import compat
-from google.cloud.aiplatform import constants
+from google.cloud.aiplatform.constants import base as constants
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.metadata import metadata
+from google.cloud.aiplatform.utils import resource_manager_utils
+from google.cloud.aiplatform.tensorboard import tensorboard_resource
from google.cloud.aiplatform.compat.types import (
encryption_spec as gca_encryption_spec_compat,
@@ -57,6 +59,9 @@ def init(
location: Optional[str] = None,
experiment: Optional[str] = None,
experiment_description: Optional[str] = None,
+ experiment_tensorboard: Optional[
+ Union[str, tensorboard_resource.Tensorboard]
+ ] = None,
staging_bucket: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
encryption_spec_key_name: Optional[str] = None,
@@ -67,8 +72,15 @@ def init(
project (str): The default project to use when making API calls.
location (str): The default location to use when making API calls. If not
set defaults to us-central-1.
- experiment (str): The experiment name.
- experiment_description (str): The description of the experiment.
+ experiment (str): Optional. The experiment name.
+ experiment_description (str): Optional. The description of the experiment.
+ experiment_tensorboard (Union[str, tensorboard_resource.Tensorboard]):
+ Optional. The Vertex AI TensorBoard instance, Tensorboard resource name,
+ or Tensorboard resource ID to use as a backing Tensorboard for the provided
+ experiment.
+
+ Example tensorboard resource name format:
+ "projects/123/locations/us-central1/tensorboards/456"
staging_bucket (str): The default staging bucket to use to stage artifacts
when making API calls. In the form gs://...
credentials (google.auth.credentials.Credentials): The default custom
@@ -83,28 +95,35 @@ def init(
resource is created.
If set, this resource and all sub-resources will be secured by this key.
+ Raises:
+ ValueError:
+ If experiment_description is provided but experiment is not.
+        If experiment_tensorboard is provided but experiment is not.
"""
+ if experiment_description and experiment is None:
+ raise ValueError(
+ "Experiment needs to be set in `init` in order to add experiment descriptions."
+ )
+
+ if experiment_tensorboard and experiment is None:
+ raise ValueError(
+ "Experiment needs to be set in `init` in order to add experiment_tensorboard."
+ )
+
# reset metadata_service config if project or location is updated.
if (project and project != self._project) or (
location and location != self._location
):
- if metadata.metadata_service.experiment_name:
- logging.info("project/location updated, reset Metadata config.")
- metadata.metadata_service.reset()
+ if metadata._experiment_tracker.experiment_name:
+ logging.info("project/location updated, reset Experiment config.")
+ metadata._experiment_tracker.reset()
+
if project:
self._project = project
if location:
utils.validate_region(location)
self._location = location
- if experiment:
- metadata.metadata_service.set_experiment(
- experiment=experiment, description=experiment_description
- )
- if experiment_description and experiment is None:
- raise ValueError(
- "Experiment name needs to be set in `init` in order to add experiment descriptions."
- )
if staging_bucket:
self._staging_bucket = staging_bucket
if credentials:
@@ -112,6 +131,13 @@ def init(
if encryption_spec_key_name:
self._encryption_spec_key_name = encryption_spec_key_name
+ if experiment:
+ metadata._experiment_tracker.set_experiment(
+ experiment=experiment,
+ description=experiment_description,
+ backing_tensorboard=experiment_tensorboard,
+ )
+
def get_encryption_spec(
self,
encryption_spec_key_name: Optional[str],
@@ -147,6 +173,26 @@ def project(self) -> str:
if self._project:
return self._project
+ # Project is not set. Trying to get it from the environment.
+ # See https://github.com/googleapis/python-aiplatform/issues/852
+ # See https://github.com/googleapis/google-auth-library-python/issues/924
+ # TODO: Remove when google.auth.default() learns the
+ # CLOUD_ML_PROJECT_ID env variable or Vertex AI starts setting GOOGLE_CLOUD_PROJECT env variable.
+ project_number = os.environ.get("CLOUD_ML_PROJECT_ID")
+ if project_number:
+ # Try to convert project number to project ID which is more readable.
+ try:
+ project_id = resource_manager_utils.get_project_id(
+ project_number=project_number,
+ credentials=self.credentials,
+ )
+ return project_id
+ except Exception:
+ logging.getLogger(__name__).warning(
+ "Failed to convert project number to project ID.", exc_info=True
+ )
+ return project_number
+
project_not_found_exception_str = (
"Unable to find your project. Please provide a project ID by:"
"\n- Passing a constructor argument"
@@ -191,17 +237,26 @@ def encryption_spec_key_name(self) -> Optional[str]:
"""Default encryption spec key name, if provided."""
return self._encryption_spec_key_name
+ @property
+ def experiment_name(self) -> Optional[str]:
+ """Default experiment name, if provided."""
+ return metadata._experiment_tracker.experiment_name
+
def get_client_options(
- self, location_override: Optional[str] = None
+ self,
+ location_override: Optional[str] = None,
+ prediction_client: bool = False,
+ api_base_path_override: Optional[str] = None,
) -> client_options.ClientOptions:
"""Creates GAPIC client_options using location and type.
Args:
location_override (str):
- Set this parameter to get client options for a location different from
- location set by initializer. Must be a GCP region supported by AI
- Platform (Unified).
-
+ Optional. Set this parameter to get client options for a location different
+ from location set by initializer. Must be a GCP region supported by
+ Vertex AI.
+            prediction_client (bool): Optional. Flag to use a prediction endpoint.
+ api_base_path_override (str): Optional. Override default API base path.
Returns:
clients_options (google.api_core.client_options.ClientOptions):
A ClientOptions object set with regionalized API endpoint, i.e.
@@ -218,8 +273,14 @@ def get_client_options(
utils.validate_region(region)
+ service_base_path = api_base_path_override or (
+ constants.PREDICTION_API_BASE_PATH
+ if prediction_client
+ else constants.API_BASE_PATH
+ )
+
return client_options.ClientOptions(
- api_endpoint=f"{region}-{constants.API_BASE_PATH}"
+ api_endpoint=f"{region}-{service_base_path}"
)
def common_location_path(
@@ -251,17 +312,19 @@ def create_client(
credentials: Optional[auth_credentials.Credentials] = None,
location_override: Optional[str] = None,
prediction_client: bool = False,
+ api_base_path_override: Optional[str] = None,
) -> utils.VertexAiServiceClientWithOverride:
"""Instantiates a given VertexAiServiceClient with optional
overrides.
Args:
client_class (utils.VertexAiServiceClientWithOverride):
- (Required) An Vertex AI Service Client with optional overrides.
+ Required. A Vertex AI Service Client with optional overrides.
credentials (auth_credentials.Credentials):
- Custom auth credentials. If not provided will use the current config.
- location_override (str): Optional location override.
- prediction_client (str): Optional flag to use a prediction endpoint.
+ Optional. Custom auth credentials. If not provided will use the current config.
+            location_override (str): Optional. Location override.
+            prediction_client (bool): Optional. Flag to use a prediction endpoint.
+            api_base_path_override (str): Optional. Override the default API base path.
Returns:
client: Instantiated Vertex AI Service client with optional overrides
"""
@@ -276,7 +339,9 @@ def create_client(
kwargs = {
"credentials": credentials or self.credentials,
"client_options": self.get_client_options(
- location_override=location_override
+ location_override=location_override,
+ prediction_client=prediction_client,
+ api_base_path_override=api_base_path_override,
),
"client_info": client_info,
}
diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py
index c37530a78f..ab24afb171 100644
--- a/google/cloud/aiplatform/jobs.py
+++ b/google/cloud/aiplatform/jobs.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,45 +19,39 @@
import abc
import copy
-import sys
+import datetime
import time
-import logging
from google.cloud import storage
from google.cloud import bigquery
from google.auth import credentials as auth_credentials
from google.protobuf import duration_pb2 # type: ignore
+from google.rpc import status_pb2
from google.cloud import aiplatform
from google.cloud.aiplatform import base
-from google.cloud.aiplatform import compat
-from google.cloud.aiplatform import constants
-from google.cloud.aiplatform import initializer
-from google.cloud.aiplatform import hyperparameter_tuning
-from google.cloud.aiplatform import utils
-from google.cloud.aiplatform.utils import source_utils
-from google.cloud.aiplatform.utils import worker_spec_utils
-
-from google.cloud.aiplatform.compat.services import job_service_client
from google.cloud.aiplatform.compat.types import (
batch_prediction_job as gca_bp_job_compat,
- batch_prediction_job_v1 as gca_bp_job_v1,
- batch_prediction_job_v1beta1 as gca_bp_job_v1beta1,
+ completion_stats as gca_completion_stats,
custom_job as gca_custom_job_compat,
- custom_job_v1beta1 as gca_custom_job_v1beta1,
- explanation_v1beta1 as gca_explanation_v1beta1,
+ explanation as gca_explanation_compat,
io as gca_io_compat,
- io_v1beta1 as gca_io_v1beta1,
job_state as gca_job_state,
hyperparameter_tuning_job as gca_hyperparameter_tuning_job_compat,
- hyperparameter_tuning_job_v1beta1 as gca_hyperparameter_tuning_job_v1beta1,
machine_resources as gca_machine_resources_compat,
- machine_resources_v1beta1 as gca_machine_resources_v1beta1,
+ manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat,
study as gca_study_compat,
)
+from google.cloud.aiplatform.constants import base as constants
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform import hyperparameter_tuning
+from google.cloud.aiplatform import utils
+from google.cloud.aiplatform.utils import console_utils
+from google.cloud.aiplatform.utils import source_utils
+from google.cloud.aiplatform.utils import worker_spec_utils
+
-logging.basicConfig(level=logging.INFO, stream=sys.stdout)
_LOGGER = base.Logger(__name__)
_JOB_COMPLETE_STATES = (
@@ -72,8 +66,14 @@
gca_job_state.JobState.JOB_STATE_CANCELLED,
)
+# _block_until_complete wait times
+_JOB_WAIT_TIME = 5 # start at five seconds
+_LOG_WAIT_TIME = 5
+_MAX_WAIT_TIME = 60 * 5 # 5 minute wait
+_WAIT_TIME_MULTIPLIER = 2 # scale wait by 2 every iteration
+
-class _Job(base.VertexAiResourceNounWithFutureManager):
+class _Job(base.VertexAiStatefulResource):
"""Class that represents a general Job resource in Vertex AI.
Cannot be directly instantiated.
@@ -89,7 +89,9 @@ class _Job(base.VertexAiResourceNounWithFutureManager):
"""
client_class = utils.JobClientWithOverride
- _is_client_prediction_client = False
+
+ # Required by the done() method
+ _valid_done_states = _JOB_COMPLETE_STATES
def __init__(
self,
@@ -98,7 +100,7 @@ def __init__(
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
- """Retrives Job subclass resource by calling a subclass-specific getter
+ """Retrieves Job subclass resource by calling a subclass-specific getter
method.
Args:
@@ -138,6 +140,27 @@ def state(self) -> gca_job_state.JobState:
return self._gca_resource.state
+ @property
+ def start_time(self) -> Optional[datetime.datetime]:
+ """Time when the Job resource entered the `JOB_STATE_RUNNING` for the
+ first time."""
+ self._sync_gca_resource()
+ return getattr(self._gca_resource, "start_time")
+
+ @property
+ def end_time(self) -> Optional[datetime.datetime]:
+ """Time when the Job resource entered the `JOB_STATE_SUCCEEDED`,
+ `JOB_STATE_FAILED`, or `JOB_STATE_CANCELLED` state."""
+ self._sync_gca_resource()
+ return getattr(self._gca_resource, "end_time")
+
+ @property
+ def error(self) -> Optional[status_pb2.Status]:
+ """Detailed error info for this Job resource. Only populated when the
+ Job's state is `JOB_STATE_FAILED` or `JOB_STATE_CANCELLED`."""
+ self._sync_gca_resource()
+ return getattr(self._gca_resource, "error")
+
@property
@abc.abstractmethod
def _job_type(cls) -> str:
@@ -153,10 +176,24 @@ def _cancel_method(cls) -> str:
def _dashboard_uri(self) -> Optional[str]:
"""Helper method to compose the dashboard uri where job can be
viewed."""
- fields = utils.extract_fields_from_resource_name(self.resource_name)
- url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/{self._job_type}/{fields.id}?project={fields.project}"
+ fields = self._parse_resource_name(self.resource_name)
+ location = fields.pop("location")
+ project = fields.pop("project")
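+        # After popping location and project, the only remaining parsed
+        # field is the job's resource ID.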
+ job = list(fields.values())[0]
+ url = f"https://console.cloud.google.com/ai/platform/locations/{location}/{self._job_type}/{job}?project={project}"
return url
+ def _log_job_state(self):
+ """Helper method to log job state."""
+ _LOGGER.info(
+ "%s %s current state:\n%s"
+ % (
+ self.__class__.__name__,
+ self._gca_resource.name,
+ self._gca_resource.state,
+ )
+ )
+
def _block_until_complete(self):
"""Helper method to block and check on job until complete.
@@ -164,36 +201,19 @@ def _block_until_complete(self):
RuntimeError: If job failed or cancelled.
"""
- # Used these numbers so failures surface fast
- wait = 5 # start at five seconds
- log_wait = 5
- max_wait = 60 * 5 # 5 minute wait
- multiplier = 2 # scale wait by 2 every iteration
+ log_wait = _LOG_WAIT_TIME
previous_time = time.time()
while self.state not in _JOB_COMPLETE_STATES:
current_time = time.time()
if current_time - previous_time >= log_wait:
- _LOGGER.info(
- "%s %s current state:\n%s"
- % (
- self.__class__.__name__,
- self._gca_resource.name,
- self._gca_resource.state,
- )
- )
- log_wait = min(log_wait * multiplier, max_wait)
+ self._log_job_state()
+ log_wait = min(log_wait * _WAIT_TIME_MULTIPLIER, _MAX_WAIT_TIME)
previous_time = current_time
- time.sleep(wait)
+ time.sleep(_JOB_WAIT_TIME)
+
+ self._log_job_state()
- _LOGGER.info(
- "%s %s current state:\n%s"
- % (
- self.__class__.__name__,
- self._gca_resource.name,
- self._gca_resource.state,
- )
- )
# Error is only populated when the job state is
# JOB_STATE_FAILED or JOB_STATE_CANCELLED.
if self._gca_resource.state in _JOB_ERROR_STATES:
@@ -267,6 +287,8 @@ class BatchPredictionJob(_Job):
_cancel_method = "cancel_batch_prediction_job"
_delete_method = "delete_batch_prediction_job"
_job_type = "batch-predictions"
+ _parse_resource_name_method = "parse_batch_prediction_job_path"
+ _format_resource_name_method = "batch_prediction_job_path"
def __init__(
self,
@@ -301,11 +323,38 @@ def __init__(
credentials=credentials,
)
+ @property
+ def output_info(
+ self,
+ ) -> Optional[aiplatform.gapic.BatchPredictionJob.OutputInfo]:
+ """Information describing the output of this job, including output location
+ into which prediction output is written.
+
+ This is only available for batch prediction jobs that have run successfully.
+ """
+ self._assert_gca_resource_is_available()
+ return self._gca_resource.output_info
+
+ @property
+ def partial_failures(self) -> Optional[Sequence[status_pb2.Status]]:
+ """Partial failures encountered. For example, single files that can't be read.
+ This field never exceeds 20 entries. Status details fields contain standard
+ GCP error details."""
+ self._assert_gca_resource_is_available()
+ return getattr(self._gca_resource, "partial_failures")
+
+ @property
+ def completion_stats(self) -> Optional[gca_completion_stats.CompletionStats]:
+ """Statistics on completed and failed prediction instances."""
+ self._assert_gca_resource_is_available()
+ return getattr(self._gca_resource, "completion_stats")
+
@classmethod
def create(
cls,
+ # TODO(b/223262536): Make the job_display_name parameter optional in the next major release
job_display_name: str,
- model_name: str,
+ model_name: Union[str, "aiplatform.Model"],
instances_format: str = "jsonl",
predictions_format: str = "jsonl",
gcs_source: Optional[Union[str, Sequence[str]]] = None,
@@ -323,12 +372,14 @@ def create(
explanation_parameters: Optional[
"aiplatform.explain.ExplanationParameters"
] = None,
- labels: Optional[dict] = None,
+ labels: Optional[Dict[str, str]] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
encryption_spec_key_name: Optional[str] = None,
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
+ batch_size: Optional[int] = None,
) -> "BatchPredictionJob":
"""Create a batch prediction job.
@@ -337,29 +388,31 @@ def create(
Required. The user-defined name of the BatchPredictionJob.
The name can be up to 128 characters long and can be consist
of any UTF-8 characters.
- model_name (str):
+ model_name (Union[str, aiplatform.Model]):
Required. A fully-qualified model resource name or model ID.
Example: "projects/123/locations/us-central1/models/456" or
"456" when project and location are initialized or passed.
+
+ Or an instance of aiplatform.Model.
instances_format (str):
- Required. The format in which instances are given, must be one
- of "jsonl", "csv", "bigquery", "tf-record", "tf-record-gzip",
- or "file-list". Default is "jsonl" when using `gcs_source`. If a
- `bigquery_source` is provided, this is overriden to "bigquery".
+ Required. The format in which instances are provided. Must be one
+ of the formats listed in `Model.supported_input_storage_formats`.
+ Default is "jsonl" when using `gcs_source`. If a `bigquery_source`
+ is provided, this is overridden to "bigquery".
predictions_format (str):
- Required. The format in which Vertex AI gives the
- predictions, must be one of "jsonl", "csv", or "bigquery".
+ Required. The format in which Vertex AI outputs the
+ predictions, must be one of the formats specified in
+ `Model.supported_output_storage_formats`.
Default is "jsonl" when using `gcs_destination_prefix`. If a
- `bigquery_destination_prefix` is provided, this is overriden to
+ `bigquery_destination_prefix` is provided, this is overridden to
"bigquery".
gcs_source (Optional[Sequence[str]]):
Google Cloud Storage URI(-s) to your instances to run
batch prediction on. They must match `instances_format`.
- May contain wildcards. For more information on wildcards, see
- https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
+
bigquery_source (Optional[str]):
BigQuery URI to a table, up to 2000 characters long. For example:
- `projectId.bqDatasetId.bqTableId`
+ `bq://projectId.bqDatasetId.bqTableId`
gcs_destination_prefix (Optional[str]):
The Google Cloud Storage location of the directory where the
output is to be written to. In the given directory a new
@@ -383,24 +436,27 @@ def create(
which as value has ```google.rpc.Status`` `__
containing only ``code`` and ``message`` fields.
bigquery_destination_prefix (Optional[str]):
- The BigQuery project location where the output is to be
- written to. In the given project a new dataset is created
- with name
- ``prediction__`` where
- is made BigQuery-dataset-name compatible (for example, most
- special characters become underscores), and timestamp is in
- YYYY_MM_DDThh_mm_ss_sssZ "based on ISO-8601" format. In the
- dataset two tables will be created, ``predictions``, and
- ``errors``. If the Model has both ``instance`` and ``prediction``
- schemata defined then the tables have columns as follows:
- The ``predictions`` table contains instances for which the
- prediction succeeded, it has columns as per a concatenation
- of the Model's instance and prediction schemata. The
- ``errors`` table contains rows for which the prediction has
- failed, it has instance columns, as per the instance schema,
- followed by a single "errors" column, which as values has
- ```google.rpc.Status`` `__ represented as a STRUCT,
- and containing only ``code`` and ``message``.
+ The BigQuery URI to a project or table, up to 2000 characters long.
+                When only the project is specified, the Dataset and Table are created.
+ When the full table reference is specified, the Dataset must exist and
+ table must not exist. Accepted forms: ``bq://projectId`` or
+ ``bq://projectId.bqDatasetId`` or
+ ``bq://projectId.bqDatasetId.bqTableId``. If no Dataset is specified,
+ a new one is created with the name
+ ``prediction__``
+ where the table name is made BigQuery-dataset-name compatible
+ (for example, most special characters become underscores), and
+ timestamp is in YYYY_MM_DDThh_mm_ss_sssZ "based on ISO-8601"
+ format. In the dataset two tables will be created, ``predictions``,
+ and ``errors``. If the Model has both ``instance`` and
+ ``prediction`` schemata defined then the tables have columns as
+ follows: The ``predictions`` table contains instances for which
+ the prediction succeeded, it has columns as per a concatenation
+ of the Model's instance and prediction schemata. The ``errors``
+ table contains rows for which the prediction has failed, it has
+ instance columns, as per the instance schema, followed by a single
+ "errors" column, which as values has ```google.rpc.Status`` `__
+ represented as a STRUCT, and containing only ``code`` and ``message``.
model_parameters (Optional[Dict]):
The parameters that govern the predictions. The schema of
the parameters may be specified via the Model's `parameters_schema_uri`.
@@ -452,8 +508,8 @@ def create(
a field of the `explanation_parameters` object is not populated, the
corresponding field of the `Model.explanation_parameters` object is inherited.
For more details, see `Ref docs `
- labels (Optional[dict]):
- The labels with user-defined metadata to organize your
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to organize your
BatchPredictionJobs. Label keys and values can be no longer than
64 characters (Unicode codepoints), can only contain lowercase
letters, numeric characters, underscores and dashes.
@@ -479,20 +535,36 @@ def create(
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
-
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+ batch_size (int):
+                Optional. The number of records (e.g. instances) of the operation given in each batch
+                to a machine replica. Machine type and size of a single record should be considered
+                when setting this parameter; a higher value speeds up the batch operation's execution,
+                but a value that is too high will result in a whole batch not fitting in a machine's
+                memory, and the whole operation will fail.
+ The default value is 64.
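+
+        Example Usage (an illustrative sketch; the model resource name and
+        the URIs below are placeholders):
+
+            batch_prediction_job = aiplatform.BatchPredictionJob.create(
+                job_display_name='my_batch_prediction_job',
+                model_name='projects/123/locations/us-central1/models/456',
+                gcs_source='gs://my_bucket/instances.jsonl',
+                gcs_destination_prefix='gs://my_bucket/output',
+            )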
Returns:
(jobs.BatchPredictionJob):
Instantiated representation of the created batch prediction job.
"""
+ if not job_display_name:
+ job_display_name = cls._generate_display_name()
utils.validate_display_name(job_display_name)
- model_name = utils.full_resource_name(
- resource_name=model_name,
- resource_noun="models",
- project=project,
- location=location,
- )
+ if labels:
+ utils.validate_labels(labels)
+
+ if isinstance(model_name, str):
+ model_name = utils.full_resource_name(
+ resource_name=model_name,
+ resource_noun="models",
+ parse_resource_name_method=aiplatform.Model._parse_resource_name,
+ format_resource_name_method=aiplatform.Model._format_resource_name,
+ project=project,
+ location=location,
+ )
# Raise error if both or neither source URIs are provided
if bool(gcs_source) == bool(bigquery_source):
@@ -521,38 +593,28 @@ def create(
f"{predictions_format} is not an accepted prediction format "
f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}"
)
- gca_bp_job = gca_bp_job_compat
- gca_io = gca_io_compat
- gca_machine_resources = gca_machine_resources_compat
- select_version = compat.DEFAULT_VERSION
- if generate_explanation:
- gca_bp_job = gca_bp_job_v1beta1
- gca_io = gca_io_v1beta1
- gca_machine_resources = gca_machine_resources_v1beta1
- select_version = compat.V1BETA1
- gapic_batch_prediction_job = gca_bp_job.BatchPredictionJob()
+ gapic_batch_prediction_job = gca_bp_job_compat.BatchPredictionJob()
# Required Fields
gapic_batch_prediction_job.display_name = job_display_name
- gapic_batch_prediction_job.model = model_name
- input_config = gca_bp_job.BatchPredictionJob.InputConfig()
- output_config = gca_bp_job.BatchPredictionJob.OutputConfig()
+ input_config = gca_bp_job_compat.BatchPredictionJob.InputConfig()
+ output_config = gca_bp_job_compat.BatchPredictionJob.OutputConfig()
if bigquery_source:
input_config.instances_format = "bigquery"
- input_config.bigquery_source = gca_io.BigQuerySource()
+ input_config.bigquery_source = gca_io_compat.BigQuerySource()
input_config.bigquery_source.input_uri = bigquery_source
else:
input_config.instances_format = instances_format
- input_config.gcs_source = gca_io.GcsSource(
+ input_config.gcs_source = gca_io_compat.GcsSource(
uris=gcs_source if type(gcs_source) == list else [gcs_source]
)
if bigquery_destination_prefix:
output_config.predictions_format = "bigquery"
- output_config.bigquery_destination = gca_io.BigQueryDestination()
+ output_config.bigquery_destination = gca_io_compat.BigQueryDestination()
bq_dest_prefix = bigquery_destination_prefix
@@ -562,7 +624,7 @@ def create(
output_config.bigquery_destination.output_uri = bq_dest_prefix
else:
output_config.predictions_format = predictions_format
- output_config.gcs_destination = gca_io.GcsDestination(
+ output_config.gcs_destination = gca_io_compat.GcsDestination(
output_uri_prefix=gcs_destination_prefix
)
@@ -570,9 +632,10 @@ def create(
gapic_batch_prediction_job.output_config = output_config
# Optional Fields
- gapic_batch_prediction_job.encryption_spec = initializer.global_config.get_encryption_spec(
- encryption_spec_key_name=encryption_spec_key_name,
- select_version=select_version,
+ gapic_batch_prediction_job.encryption_spec = (
+ initializer.global_config.get_encryption_spec(
+ encryption_spec_key_name=encryption_spec_key_name
+ )
)
if model_parameters:
@@ -581,12 +644,12 @@ def create(
# Custom Compute
if machine_type:
- machine_spec = gca_machine_resources.MachineSpec()
+ machine_spec = gca_machine_resources_compat.MachineSpec()
machine_spec.machine_type = machine_type
machine_spec.accelerator_type = accelerator_type
machine_spec.accelerator_count = accelerator_count
- dedicated_resources = gca_machine_resources.BatchDedicatedResources()
+ dedicated_resources = gca_machine_resources_compat.BatchDedicatedResources()
dedicated_resources.machine_spec = machine_spec
dedicated_resources.starting_replica_count = starting_replica_count
@@ -594,7 +657,14 @@ def create(
gapic_batch_prediction_job.dedicated_resources = dedicated_resources
- gapic_batch_prediction_job.manual_batch_tuning_parameters = None
+ manual_batch_tuning_parameters = (
+ gca_manual_batch_tuning_parameters_compat.ManualBatchTuningParameters()
+ )
+ manual_batch_tuning_parameters.batch_size = batch_size
+
+ gapic_batch_prediction_job.manual_batch_tuning_parameters = (
+ manual_batch_tuning_parameters
+ )
# User Labels
gapic_batch_prediction_job.labels = labels
@@ -604,67 +674,53 @@ def create(
gapic_batch_prediction_job.generate_explanation = generate_explanation
if explanation_metadata or explanation_parameters:
- gapic_batch_prediction_job.explanation_spec = gca_explanation_v1beta1.ExplanationSpec(
- metadata=explanation_metadata, parameters=explanation_parameters
+ gapic_batch_prediction_job.explanation_spec = (
+ gca_explanation_compat.ExplanationSpec(
+ metadata=explanation_metadata, parameters=explanation_parameters
+ )
)
- # TODO (b/174502913): Support private feature once released
-
- api_client = cls._instantiate_client(location=location, credentials=credentials)
+ empty_batch_prediction_job = cls._empty_constructor(
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
return cls._create(
- api_client=api_client,
- parent=initializer.global_config.common_location_path(
- project=project, location=location
- ),
- batch_prediction_job=gapic_batch_prediction_job,
+ empty_batch_prediction_job=empty_batch_prediction_job,
+ model_or_model_name=model_name,
+ gca_batch_prediction_job=gapic_batch_prediction_job,
generate_explanation=generate_explanation,
- project=project or initializer.global_config.project,
- location=location or initializer.global_config.location,
- credentials=credentials or initializer.global_config.credentials,
sync=sync,
+ create_request_timeout=create_request_timeout,
)
@classmethod
- @base.optional_sync()
+ @base.optional_sync(return_input_arg="empty_batch_prediction_job")
def _create(
cls,
- api_client: job_service_client.JobServiceClient,
- parent: str,
- batch_prediction_job: Union[
- gca_bp_job_v1beta1.BatchPredictionJob, gca_bp_job_v1.BatchPredictionJob
- ],
+ empty_batch_prediction_job: "BatchPredictionJob",
+ model_or_model_name: Union[str, "aiplatform.Model"],
+ gca_batch_prediction_job: gca_bp_job_compat.BatchPredictionJob,
generate_explanation: bool,
- project: str,
- location: str,
- credentials: Optional[auth_credentials.Credentials],
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
) -> "BatchPredictionJob":
"""Create a batch prediction job.
Args:
- api_client (dataset_service_client.DatasetServiceClient):
- Required. An instance of DatasetServiceClient with the correct api_endpoint
- already set based on user's preferences.
- batch_prediction_job (gca_bp_job.BatchPredictionJob):
+ empty_batch_prediction_job (BatchPredictionJob):
+ Required. BatchPredictionJob without _gca_resource populated.
+ model_or_model_name (Union[str, aiplatform.Model]):
+ Required. A fully-qualified model resource name or
+ an instance of aiplatform.Model.
+ gca_batch_prediction_job (gca_bp_job.BatchPredictionJob):
Required. A batch prediction job proto for creating a batch prediction job on Vertex AI.
generate_explanation (bool):
Required. Generate explanation along with the batch prediction
results.
- parent (str):
- Required. Also known as common location path, that usually contains the
- project and location that the user provided to the upstream method.
- Example: "projects/my-prj/locations/us-central1"
- project (str):
- Required. Project to upload this model to. Overrides project set in
- aiplatform.init.
- location (str):
- Required. Location to upload this model to. Overrides location set in
- aiplatform.init.
- credentials (Optional[auth_credentials.Credentials]):
- Custom credentials to use to upload this model. Overrides
- credentials set in aiplatform.init.
-
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
Returns:
(jobs.BatchPredictionJob):
Instantiated representation of the created batch prediction job.
@@ -676,21 +732,33 @@ def _create(
by Vertex AI.
"""
- # select v1beta1 if explain else use default v1
- if generate_explanation:
- api_client = api_client.select_version(compat.V1BETA1)
+
+ parent = initializer.global_config.common_location_path(
+ project=empty_batch_prediction_job.project,
+ location=empty_batch_prediction_job.location,
+ )
+
+ model_resource_name = (
+ model_or_model_name
+ if isinstance(model_or_model_name, str)
+ else model_or_model_name.resource_name
+ )
+
+ gca_batch_prediction_job.model = model_resource_name
+
+ api_client = empty_batch_prediction_job.api_client
_LOGGER.log_create_with_lro(cls)
gca_batch_prediction_job = api_client.create_batch_prediction_job(
- parent=parent, batch_prediction_job=batch_prediction_job
+ parent=parent,
+ batch_prediction_job=gca_batch_prediction_job,
+ timeout=create_request_timeout,
)
- batch_prediction_job = cls(
- batch_prediction_job_name=gca_batch_prediction_job.name,
- project=project,
- location=location,
- credentials=credentials,
- )
+ empty_batch_prediction_job._gca_resource = gca_batch_prediction_job
+
+ batch_prediction_job = empty_batch_prediction_job
_LOGGER.log_create_complete(cls, batch_prediction_job._gca_resource, "bpj")
@@ -729,6 +797,8 @@ def iter_outputs(
GCS or BQ output provided.
"""
+ self._assert_gca_resource_is_available()
+
if self.state != gca_job_state.JobState.JOB_STATE_SUCCEEDED:
raise RuntimeError(
f"Cannot read outputs until BatchPredictionJob has succeeded, "
@@ -757,23 +827,27 @@ def iter_outputs(
# BigQuery Destination, return RowIterator
elif output_info.bigquery_output_dataset:
- # Build a BigQuery Client using the same credentials as JobServiceClient
- bq_client = bigquery.Client(
- project=self.project,
- credentials=self.api_client._transport._credentials,
- )
-
- # Format from service is `bq://projectId.bqDatasetId`
+ # Format of `bigquery_output_dataset` from service is `bq://projectId.bqDatasetId`
bq_dataset = output_info.bigquery_output_dataset
+ bq_table = output_info.bigquery_output_table
+
+ if not bq_table:
+ raise RuntimeError(
+ "A BigQuery table with predictions was not found, this "
+ f"might be due to errors. Visit {self._dashboard_uri()} for details."
+ )
if bq_dataset.startswith("bq://"):
bq_dataset = bq_dataset[5:]
- # # Split project ID and BQ dataset ID
- _, bq_dataset_id = bq_dataset.split(".", 1)
+ # Build a BigQuery Client using the same credentials as JobServiceClient
+ bq_client = bigquery.Client(
+ project=self.project,
+ credentials=self.api_client._transport._credentials,
+ )
row_iterator = bq_client.list_rows(
- table=f"{bq_dataset_id}.predictions", max_results=bq_max_results
+ table=f"{bq_dataset}.{bq_table}", max_results=bq_max_results
)
return row_iterator
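A short usage sketch of the BigQuery path through iter_outputs (the job resource name is invented): once the job has succeeded, the returned RowIterator pages through the `{dataset}.{table}` destination assembled above.

    bp_job = aiplatform.BatchPredictionJob(
        "projects/123/locations/us-central1/batchPredictionJobs/789"
    )
    bp_job.wait()  # assumes the future-manager wait() used elsewhere in the SDK

    for row in bp_job.iter_outputs(bq_max_results=10):
        print(row)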
@@ -785,6 +859,10 @@ def iter_outputs(
f"on your prediction output:\n{output_info}"
)
+ def wait_for_resource_creation(self) -> None:
+ """Waits until resource has been created."""
+ self._wait_for_resource_creation()
+
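The new wait_for_resource_creation() hook is most useful with sync=False, where create() returns the empty-constructed job immediately; a hedged sketch (names invented):

    bp_job = aiplatform.BatchPredictionJob.create(
        job_display_name="my-async-bpj",  # hypothetical
        model_name="projects/123/locations/us-central1/models/456",
        gcs_source="gs://my-bucket/instances.jsonl",
        gcs_destination_prefix="gs://my-bucket/output",
        sync=False,  # returns immediately; creation runs on a Future
    )

    # Block only until the backing resource exists, not until the job finishes.
    bp_job.wait_for_resource_creation()
    print(bp_job.resource_name)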
class _RunnableJob(_Job):
"""ABC to interface job as a runnable training class."""
@@ -800,7 +878,7 @@ def __init__(
Args:
project(str): Project of the resource noun.
location(str): The location of the resource noun.
- credentials(google.auth.crendentials.Crendentials): Optional custom
+ credentials(google.auth.credentials.Credentials): Optional custom
credentials to use when interacting with the resource noun.
"""
@@ -812,26 +890,95 @@ def __init__(
project=project, location=location
)
+ self._logged_web_access_uris = set()
+
+ @classmethod
+ def _empty_constructor(
+ cls,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ resource_name: Optional[str] = None,
+ ) -> "_RunnableJob":
+ """Initializes with all attributes set to None.
+
+ The attributes should be populated after a future is complete. This allows
+ scheduling of additional API calls before the resource is created.
+
+ Args:
+ project (str): Optional. Project of the resource noun.
+ location (str): Optional. The location of the resource noun.
+ credentials(google.auth.credentials.Credentials):
+ Optional. Custom credentials to use when interacting with the
+ resource noun.
+ resource_name(str): Optional. A fully-qualified resource name or ID.
+ Returns:
+ An instance of this class with attributes set to None.
+ """
+ self = super()._empty_constructor(
+ project=project,
+ location=location,
+ credentials=credentials,
+ resource_name=resource_name,
+ )
+
+ self._logged_web_access_uris = set()
+ return self
+
+ @property
+ def web_access_uris(self) -> Dict[str, Union[str, Dict[str, str]]]:
+ """Fetch the runnable job again and return the latest web access uris.
+
+ Returns:
+ (Dict[str, Union[str, Dict[str, str]]]):
+ Web access uris of the runnable job.
+ """
+
+ # Fetch the Job again for most up-to-date web access uris
+ self._sync_gca_resource()
+ return self._get_web_access_uris()
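For a concrete CustomJob (run() gains enable_web_access further below), the property can be polled while the job executes; a hedged sketch assuming a prepared worker_pool_specs list:

    job = aiplatform.CustomJob(
        display_name="my-job",                # hypothetical
        worker_pool_specs=worker_pool_specs,  # assumed to be defined
        staging_bucket="gs://my-bucket",
    )
    job.run(enable_web_access=True, sync=False)
    job.wait_for_resource_creation()

    # e.g. {"workerpool0-0": "https://..."} once the job is running
    print(job.web_access_uris)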
+
@abc.abstractmethod
- def run(self) -> None:
+ def _get_web_access_uris(self):
+ """Helper method to get the web access uris of the runnable job"""
pass
- @property
- def _has_run(self) -> bool:
- """Property returns true if this class has a resource name."""
- return bool(self._gca_resource.name)
+ @abc.abstractmethod
+ def _log_web_access_uris(self):
+ """Helper method to log the web access uris of the runnable job"""
+ pass
- @property
- def state(self) -> gca_job_state.JobState:
- """Current state of job.
+ def _block_until_complete(self):
+ """Helper method to block and check on runnable job until complete.
Raises:
- RuntimeError if job run has not been called.
+ RuntimeError: If job failed or cancelled.
"""
- if not self._has_run:
- raise RuntimeError("Job has not run. No state available.")
- return super().state
+ log_wait = _LOG_WAIT_TIME
+
+ previous_time = time.time()
+ while self.state not in _JOB_COMPLETE_STATES:
+ current_time = time.time()
+ if current_time - previous_time >= log_wait:
+ self._log_job_state()
+ log_wait = min(log_wait * _WAIT_TIME_MULTIPLIER, _MAX_WAIT_TIME)
+ previous_time = current_time
+ self._log_web_access_uris()
+ time.sleep(_JOB_WAIT_TIME)
+
+ self._log_job_state()
+
+ # Error is only populated when the job state is
+ # JOB_STATE_FAILED or JOB_STATE_CANCELLED.
+ if self._gca_resource.state in _JOB_ERROR_STATES:
+ raise RuntimeError("Job failed with:\n%s" % self._gca_resource.error)
+ else:
+ _LOGGER.log_action_completed_against_resource("run", "completed", self)
+
+ @abc.abstractmethod
+ def run(self) -> None:
+ pass
@classmethod
def get(
@@ -841,7 +988,7 @@ def get(
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "_RunnableJob":
- """Get an Vertex AI Job for the given resource_name.
+ """Get a Vertex AI Job for the given resource_name.
Args:
resource_name (str):
@@ -857,7 +1004,7 @@ def get(
credentials set in aiplatform.init.
Returns:
- An Vertex AI Job.
+ A Vertex AI Job.
"""
self = cls._empty_constructor(
project=project,
@@ -870,6 +1017,10 @@ def get(
return self
+ def wait_for_resource_creation(self) -> None:
+ """Waits until resource has been created."""
+ self._wait_for_resource_creation()
+
class DataLabelingJob(_Job):
_resource_noun = "dataLabelingJobs"
@@ -878,6 +1029,8 @@ class DataLabelingJob(_Job):
_cancel_method = "cancel_data_labeling_job"
_delete_method = "delete_data_labeling_job"
_job_type = "labeling-tasks"
+ _parse_resource_name_method = "parse_data_labeling_job_path"
+ _format_resource_name_method = "data_labeling_job_path"
pass
@@ -886,22 +1039,27 @@ class CustomJob(_RunnableJob):
_resource_noun = "customJobs"
_getter_method = "get_custom_job"
- _list_method = "list_custom_job"
+ _list_method = "list_custom_jobs"
_cancel_method = "cancel_custom_job"
_delete_method = "delete_custom_job"
+ _parse_resource_name_method = "parse_custom_job_path"
+ _format_resource_name_method = "custom_job_path"
_job_type = "training"
def __init__(
self,
+ # TODO(b/223262536): Make display_name parameter fully optional in next major release
display_name: str,
worker_pool_specs: Union[List[Dict], List[aiplatform.gapic.WorkerPoolSpec]],
+ base_output_dir: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
+ labels: Optional[Dict[str, str]] = None,
encryption_spec_key_name: Optional[str] = None,
staging_bucket: Optional[str] = None,
):
- """Cosntruct a Custom Job with Worker Pool Specs.
+ """Constructs a Custom Job with Worker Pool Specs.
```
Example usage:
@@ -923,7 +1081,8 @@ def __init__(
my_job = aiplatform.CustomJob(
display_name='my_job',
- worker_pool_specs=worker_pool_specs
+ worker_pool_specs=worker_pool_specs,
+ labels={'my_key': 'my_value'},
)
my_job.run()
@@ -942,6 +1101,9 @@ def __init__(
worker_pool_specs (Union[List[Dict], List[aiplatform.gapic.WorkerPoolSpec]]):
Required. The spec of the worker pools including machine type and Docker image.
Can be provided as a list of dictionaries or a list of WorkerPoolSpec proto messages.
+ base_output_dir (str):
+ Optional. GCS output directory of job. If not provided, a
+ timestamped directory in the staging directory will be used.
project (str):
Optional. Project to run the custom job in. Overrides project set in aiplatform.init.
location (str):
@@ -949,6 +1111,16 @@ def __init__(
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use when calling the custom job service. Overrides
credentials set in aiplatform.init.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize CustomJobs.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
encryption_spec_key_name (str):
Optional. Customer-managed encryption key name for a
CustomJob. If this is set, then all resources
@@ -959,7 +1131,7 @@ def __init__(
staging_bucket set in aiplatform.init.
Raises:
- RuntimeError is not staging bucket was set using aiplatfrom.init and a staging
+ RuntimeError: If staging bucket was not set using aiplatform.init and a staging
bucket was not passed in.
"""
@@ -973,35 +1145,95 @@ def __init__(
"should be set using aiplatform.init(staging_bucket='gs://my-bucket')"
)
+ if labels:
+ utils.validate_labels(labels)
+
+ # default directory if not given
+ base_output_dir = base_output_dir or utils._timestamped_gcs_dir(
+ staging_bucket, "aiplatform-custom-job"
+ )
+
+ if not display_name:
+ display_name = self.__class__._generate_display_name()
+
self._gca_resource = gca_custom_job_compat.CustomJob(
display_name=display_name,
job_spec=gca_custom_job_compat.CustomJobSpec(
worker_pool_specs=worker_pool_specs,
base_output_directory=gca_io_compat.GcsDestination(
- output_uri_prefix=staging_bucket
+ output_uri_prefix=base_output_dir
),
),
+ labels=labels,
encryption_spec=initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name
),
)
+ @property
+ def network(self) -> Optional[str]:
+ """The full name of the Google Compute Engine
+ [network](https://cloud.google.com/vpc/docs/vpc#networks) to which this
+ CustomJob should be peered.
+
+ Takes the format `projects/{project}/global/networks/{network}`. Where
+ {project} is a project number, as in `12345`, and {network} is a network name.
+
+ Private services access must already be configured for the network. If left
+ unspecified, the CustomJob is not peered with any network.
+ """
+ self._assert_gca_resource_is_available()
+ return self._gca_resource.job_spec.network
+
+ def _get_web_access_uris(self) -> Dict[str, str]:
+ """Helper method to get the web access uris of the custom job
+
+ Returns:
+ (Dict[str, str]):
+ Web access uris of the custom job.
+ """
+ return dict(self._gca_resource.web_access_uris)
+
+ def _log_web_access_uris(self):
+ """Helper method to log the web access uris of the custom job"""
+
+ for worker, uri in self._get_web_access_uris().items():
+ if uri not in self._logged_web_access_uris:
+ _LOGGER.info(
+ "%s %s access the interactive shell terminals for the custom job:\n%s:\n%s"
+ % (
+ self.__class__.__name__,
+ self._gca_resource.name,
+ worker,
+ uri,
+ ),
+ )
+ self._logged_web_access_uris.add(uri)
+
@classmethod
def from_local_script(
cls,
+ # TODO(b/223262536): Make display_name parameter fully optional in next major release
display_name: str,
script_path: str,
container_uri: str,
- args: Optional[List[Union[str, float, int]]] = None,
+ args: Optional[Sequence[str]] = None,
requirements: Optional[Sequence[str]] = None,
environment_variables: Optional[Dict[str, str]] = None,
replica_count: int = 1,
machine_type: str = "n1-standard-4",
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
accelerator_count: int = 0,
+ boot_disk_type: str = "pd-ssd",
+ boot_disk_size_gb: int = 100,
+ reduction_server_replica_count: int = 0,
+ reduction_server_machine_type: Optional[str] = None,
+ reduction_server_container_uri: Optional[str] = None,
+ base_output_dir: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
+ labels: Optional[Dict[str, str]] = None,
encryption_spec_key_name: Optional[str] = None,
staging_bucket: Optional[str] = None,
) -> "CustomJob":
@@ -1017,6 +1249,7 @@ def from_local_script(
replica_count=1,
args=['--dataset', 'gs://my-bucket/my-dataset',
'--model_output_uri', 'gs://my-bucket/model']
+ labels={'my_key': 'my_value'},
)
job.run()
@@ -1029,7 +1262,7 @@ def from_local_script(
Required. Local path to training script.
container_uri (str):
Required. URI of the training container image to use for the custom job.
- args (Optional[List[Union[str, float, int]]]):
+ args (Optional[Sequence[str]]):
Optional. Command line arguments to be passed to the Python task.
requirements (Sequence[str]):
Optional. List of python package dependencies of the script.
@@ -1055,6 +1288,23 @@ def from_local_script(
NVIDIA_TESLA_T4
accelerator_count (int):
Optional. The number of accelerators to attach to a worker replica.
+ boot_disk_type (str):
+ Optional. Type of the boot disk, default is `pd-ssd`.
+ Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
+ `pd-standard` (Persistent Disk Hard Disk Drive).
+ boot_disk_size_gb (int):
+ Optional. Size in GB of the boot disk, default is 100GB.
+ Boot disk size must be within the range of [100, 64000].
+ reduction_server_replica_count (int):
+ Optional. The number of reduction server replicas, default is 0.
+ reduction_server_machine_type (str):
+ Optional. The type of machine to use for reduction server.
+ reduction_server_container_uri (str):
+ Optional. The URI of the reduction server container image.
+ See details: https://cloud.google.com/vertex-ai/docs/training/distributed-training#reduce_training_time_with_reduction_server
+ base_output_dir (str):
+ Optional. GCS output directory of job. If not provided, a
+ timestamped directory in the staging directory will be used.
project (str):
Optional. Project to run the custom job in. Overrides project set in aiplatform.init.
location (str):
@@ -1062,6 +1312,16 @@ def from_local_script(
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to run call custom job service. Overrides
credentials set in aiplatform.init.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize CustomJobs.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
encryption_spec_key_name (str):
Optional. Customer-managed encryption key name for a
CustomJob. If this is set, then all resources
@@ -1072,7 +1332,7 @@ def from_local_script(
staging_bucket set in aiplatform.init.
Raises:
- RuntimeError is not staging bucket was set using aiplatfrom.init and a staging
+ RuntimeError: If staging bucket was not set using aiplatform.init and a staging
bucket was not passed in.
"""
@@ -1086,43 +1346,68 @@ def from_local_script(
"should be set using aiplatform.init(staging_bucket='gs://my-bucket')"
)
- worker_pool_specs = worker_spec_utils._DistributedTrainingSpec.chief_worker_pool(
- replica_count=replica_count,
- machine_type=machine_type,
- accelerator_count=accelerator_count,
- accelerator_type=accelerator_type,
- ).pool_specs
+ if labels:
+ utils.validate_labels(labels)
+
+ worker_pool_specs = (
+ worker_spec_utils._DistributedTrainingSpec.chief_worker_pool(
+ replica_count=replica_count,
+ machine_type=machine_type,
+ accelerator_count=accelerator_count,
+ accelerator_type=accelerator_type,
+ boot_disk_type=boot_disk_type,
+ boot_disk_size_gb=boot_disk_size_gb,
+ reduction_server_replica_count=reduction_server_replica_count,
+ reduction_server_machine_type=reduction_server_machine_type,
+ ).pool_specs
+ )
python_packager = source_utils._TrainingScriptPythonPackager(
script_path=script_path, requirements=requirements
)
package_gcs_uri = python_packager.package_and_copy_to_gcs(
- gcs_staging_dir=staging_bucket, project=project, credentials=credentials,
+ gcs_staging_dir=staging_bucket,
+ project=project,
+ credentials=credentials,
)
- for spec in worker_pool_specs:
- spec["python_package_spec"] = {
- "executor_image_uri": container_uri,
- "python_module": python_packager.module_name,
- "package_uris": [package_gcs_uri],
- }
+ for spec_order, spec in enumerate(worker_pool_specs):
+
+ if not spec:
+ continue
- if args:
- spec["python_package_spec"]["args"] = args
+ if (
+ spec_order == worker_spec_utils._SPEC_ORDERS["server_spec"]
+ and reduction_server_replica_count > 0
+ ):
+ spec["container_spec"] = {
+ "image_uri": reduction_server_container_uri,
+ }
+ else:
+ spec["python_package_spec"] = {
+ "executor_image_uri": container_uri,
+ "python_module": python_packager.module_name,
+ "package_uris": [package_gcs_uri],
+ }
- if environment_variables:
- spec["python_package_spec"]["env"] = [
- {"name": key, "value": value}
- for key, value in environment_variables.items()
- ]
+ if args:
+ spec["python_package_spec"]["args"] = args
+
+ if environment_variables:
+ spec["python_package_spec"]["env"] = [
+ {"name": key, "value": value}
+ for key, value in environment_variables.items()
+ ]
return cls(
display_name=display_name,
worker_pool_specs=worker_pool_specs,
+ base_output_dir=base_output_dir,
project=project,
location=location,
credentials=credentials,
+ labels=labels,
encryption_spec_key_name=encryption_spec_key_name,
staging_bucket=staging_bucket,
)
@@ -1134,8 +1419,10 @@ def run(
network: Optional[str] = None,
timeout: Optional[int] = None,
restart_job_on_worker_restart: bool = False,
+ enable_web_access: bool = False,
tensorboard: Optional[str] = None,
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
) -> None:
"""Run this configured CustomJob.
@@ -1155,8 +1442,12 @@ def run(
gets restarted. This feature can be used by
distributed training jobs that are not resilient
to workers leaving and joining a job.
+ enable_web_access (bool):
+ Whether you want Vertex AI to enable interactive shell access
+ to training containers.
+ https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
tensorboard (str):
- Optional. The name of an Vertex AI
+ Optional. The name of a Vertex AI
[Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
resource to which this CustomJob will upload Tensorboard
logs. Format:
@@ -1173,6 +1464,8 @@ def run(
sync (bool):
Whether to execute this method synchronously. If False, this method
will unblock and it will be executed in a concurrent Future.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
"""
if service_account:
@@ -1188,19 +1481,18 @@ def run(
restart_job_on_worker_restart=restart_job_on_worker_restart,
)
+ if enable_web_access:
+ self._gca_resource.job_spec.enable_web_access = enable_web_access
+
if tensorboard:
- v1beta1_gca_resource = gca_custom_job_v1beta1.CustomJob()
- v1beta1_gca_resource._pb.MergeFromString(
- self._gca_resource._pb.SerializeToString()
- )
- self._gca_resource = v1beta1_gca_resource
self._gca_resource.job_spec.tensorboard = tensorboard
_LOGGER.log_create_with_lro(self.__class__)
- version = "v1beta1" if tensorboard else "v1"
- self._gca_resource = self.api_client.select_version(version).create_custom_job(
- parent=self._parent, custom_job=self._gca_resource
+ self._gca_resource = self.api_client.create_custom_job(
+ parent=self._parent,
+ custom_job=self._gca_resource,
+ timeout=create_request_timeout,
)
_LOGGER.log_create_complete_with_getter(
@@ -1209,6 +1501,14 @@ def run(
_LOGGER.info("View Custom Job:\n%s" % self._dashboard_uri())
+ if tensorboard:
+ _LOGGER.info(
+ "View Tensorboard:\n%s"
+ % console_utils.custom_job_tensorboard_console_uri(
+ tensorboard, self.resource_name
+ )
+ )
+
self._block_until_complete()
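Pulling the new run() parameters together in one hedged call (the service account and Tensorboard resource names are invented):

    job.run(
        service_account="trainer@my-project.iam.gserviceaccount.com",
        enable_web_access=True,  # interactive shell terminals per worker
        tensorboard="projects/123/locations/us-central1/tensorboards/456",
        create_request_timeout=300.0,  # bounds only the create RPC
        sync=True,
    )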
@property
@@ -1236,10 +1536,13 @@ class HyperparameterTuningJob(_RunnableJob):
_list_method = "list_hyperparameter_tuning_jobs"
_cancel_method = "cancel_hyperparameter_tuning_job"
_delete_method = "delete_hyperparameter_tuning_job"
+ _parse_resource_name_method = "parse_hyperparameter_tuning_job_path"
+ _format_resource_name_method = "hyperparameter_tuning_job_path"
_job_type = "training"
def __init__(
self,
+ # TODO(b/223262536): Make display_name parameter fully optional in next major release
display_name: str,
custom_job: CustomJob,
metric_spec: Dict[str, str],
@@ -1252,6 +1555,7 @@ def __init__(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
+ labels: Optional[Dict[str, str]] = None,
encryption_spec_key_name: Optional[str] = None,
):
"""
@@ -1280,7 +1584,8 @@ def __init__(
custom_job = aiplatform.CustomJob(
display_name='my_job',
- worker_pool_specs=worker_pool_specs
+ worker_pool_specs=worker_pool_specs,
+ labels={'my_key': 'my_value'},
)
@@ -1298,6 +1603,7 @@ def __init__(
},
max_trial_count=128,
parallel_trial_count=8,
+ labels={'my_key': 'my_value'},
)
hp_job.run()
@@ -1318,7 +1624,7 @@ def __init__(
Required. Configured CustomJob. The worker pool spec from this custom job
applies to the CustomJobs created in all the trials.
metric_spec: Dict[str, str]
- Required. Dicionary representing metrics to optimize. The dictionary key is the metric_id,
+ Required. Dictionary representing metrics to optimize. The dictionary key is the metric_id,
which is reported by your training job, and the dictionary value is the
optimization goal of the metric ('minimize' or 'maximize'). Example:
@@ -1326,7 +1632,7 @@ def __init__(
parameter_spec (Dict[str, hyperparameter_tuning._ParameterSpec]):
Required. Dictionary representing parameters to optimize. The dictionary key is the parameter_id,
- which is passed into your training job as a command line key word arguemnt, and the
+ which is passed into your training job as a command line keyword argument, and the
dictionary value is the parameter specification of the metric.
@@ -1343,7 +1649,7 @@ def __init__(
DoubleParameterSpec, IntegerParameterSpec, CategoricalParameterSpec, DiscreteParameterSpec
max_trial_count (int):
- Reuired. The desired total number of Trials.
+ Required. The desired total number of Trials.
parallel_trial_count (int):
Required. The desired number of Trials to run in parallel.
max_failed_trial_count (int):
@@ -1393,6 +1699,16 @@ def __init__(
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to run call HyperparameterTuning service. Overrides
credentials set in aiplatform.init.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize HyperparameterTuningJobs.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
encryption_spec_key_name (str):
Optional. Customer-managed encryption key options for a
HyperparameterTuningJob. If this is set, then
@@ -1423,18 +1739,71 @@ def __init__(
],
)
- self._gca_resource = gca_hyperparameter_tuning_job_compat.HyperparameterTuningJob(
- display_name=display_name,
- study_spec=study_spec,
- max_trial_count=max_trial_count,
- parallel_trial_count=parallel_trial_count,
- max_failed_trial_count=max_failed_trial_count,
- trial_job_spec=copy.deepcopy(custom_job.job_spec),
- encryption_spec=initializer.global_config.get_encryption_spec(
- encryption_spec_key_name=encryption_spec_key_name
- ),
+ if not display_name:
+ display_name = self.__class__._generate_display_name()
+
+ self._gca_resource = (
+ gca_hyperparameter_tuning_job_compat.HyperparameterTuningJob(
+ display_name=display_name,
+ study_spec=study_spec,
+ max_trial_count=max_trial_count,
+ parallel_trial_count=parallel_trial_count,
+ max_failed_trial_count=max_failed_trial_count,
+ trial_job_spec=copy.deepcopy(custom_job.job_spec),
+ labels=labels,
+ encryption_spec=initializer.global_config.get_encryption_spec(
+ encryption_spec_key_name=encryption_spec_key_name
+ ),
+ )
)
+ @property
+ def network(self) -> Optional[str]:
+ """The full name of the Google Compute Engine
+ [network](https://cloud.google.com/vpc/docs/vpc#networks) to which this
+ HyperparameterTuningJob should be peered.
+
+ Takes the format `projects/{project}/global/networks/{network}`. Where
+ {project} is a project number, as in `12345`, and {network} is a network name.
+
+ Private services access must already be configured for the network. If left
+ unspecified, the HyperparameterTuningJob is not peered with any network.
+ """
+ self._assert_gca_resource_is_available()
+ return getattr(self._gca_resource.trial_job_spec, "network")
+
+ def _get_web_access_uris(self) -> Dict[str, Dict[str, str]]:
+ """Helper method to get the web access uris of the hyperparameter job
+
+ Returns:
+ (Dict[str, Dict[str, str]]):
+ Web access uris of the hyperparameter job.
+ """
+ web_access_uris = dict()
+ for trial in self.trials:
+ web_access_uris[trial.id] = web_access_uris.get(trial.id, dict())
+ for worker, uri in trial.web_access_uris.items():
+ web_access_uris[trial.id][worker] = uri
+ return web_access_uris
+
+ def _log_web_access_uris(self):
+ """Helper method to log the web access uris of the hyperparameter job"""
+
+ for trial_id, trial_web_access_uris in self._get_web_access_uris().items():
+ for worker, uri in trial_web_access_uris.items():
+ if uri not in self._logged_web_access_uris:
+ _LOGGER.info(
+ "%s %s access the interactive shell terminals for trial - %s:\n%s:\n%s"
+ % (
+ self.__class__.__name__,
+ self._gca_resource.name,
+ trial_id,
+ worker,
+ uri,
+ ),
+ )
+ self._logged_web_access_uris.add(uri)
+
@base.optional_sync()
def run(
self,
@@ -1442,8 +1811,10 @@ def run(
network: Optional[str] = None,
timeout: Optional[int] = None, # seconds
restart_job_on_worker_restart: bool = False,
+ enable_web_access: bool = False,
tensorboard: Optional[str] = None,
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
) -> None:
"""Run this configured CustomJob.
@@ -1463,8 +1834,12 @@ def run(
gets restarted. This feature can be used by
distributed training jobs that are not resilient
to workers leaving and joining a job.
+ enable_web_access (bool):
+ Whether you want Vertex AI to enable interactive shell access
+ to training containers.
+ https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
tensorboard (str):
- Optional. The name of an Vertex AI
+ Optional. The name of a Vertex AI
[Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
resource to which this CustomJob will upload Tensorboard
logs. Format:
@@ -1481,6 +1856,8 @@ def run(
sync (bool):
Whether to execute this method synchronously. If False, this method
will unblock and it will be executed in a concurrent Future.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
"""
if service_account:
@@ -1491,28 +1868,25 @@ def run(
if timeout or restart_job_on_worker_restart:
duration = duration_pb2.Duration(seconds=timeout) if timeout else None
- self._gca_resource.trial_job_spec.scheduling = gca_custom_job_compat.Scheduling(
- timeout=duration,
- restart_job_on_worker_restart=restart_job_on_worker_restart,
+ self._gca_resource.trial_job_spec.scheduling = (
+ gca_custom_job_compat.Scheduling(
+ timeout=duration,
+ restart_job_on_worker_restart=restart_job_on_worker_restart,
+ )
)
+ if enable_web_access:
+ self._gca_resource.trial_job_spec.enable_web_access = enable_web_access
+
if tensorboard:
- v1beta1_gca_resource = (
- gca_hyperparameter_tuning_job_v1beta1.HyperparameterTuningJob()
- )
- v1beta1_gca_resource._pb.MergeFromString(
- self._gca_resource._pb.SerializeToString()
- )
- self._gca_resource = v1beta1_gca_resource
self._gca_resource.trial_job_spec.tensorboard = tensorboard
_LOGGER.log_create_with_lro(self.__class__)
- version = "v1beta1" if tensorboard else "v1"
- self._gca_resource = self.api_client.select_version(
- version
- ).create_hyperparameter_tuning_job(
- parent=self._parent, hyperparameter_tuning_job=self._gca_resource
+ self._gca_resource = self.api_client.create_hyperparameter_tuning_job(
+ parent=self._parent,
+ hyperparameter_tuning_job=self._gca_resource,
+ timeout=create_request_timeout,
)
_LOGGER.log_create_complete_with_getter(
@@ -1521,8 +1895,17 @@ def run(
_LOGGER.info("View HyperparameterTuningJob:\n%s" % self._dashboard_uri())
+ if tensorboard:
+ _LOGGER.info(
+ "View Tensorboard:\n%s"
+ % console_utils.custom_job_tensorboard_console_uri(
+ tensorboard, self.resource_name
+ )
+ )
+
self._block_until_complete()
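The same flags flow through HyperparameterTuningJob.run(); continuing the hp_job example from the constructor docstring, a hedged sketch:

    hp_job.run(
        enable_web_access=True,  # shell access for each trial's workers
        create_request_timeout=300.0,
        sync=False,
    )
    hp_job.wait_for_resource_creation()

    # Nested mapping for tuning jobs: {trial_id: {worker: uri}}
    print(hp_job.web_access_uris)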
@property
def trials(self) -> List[gca_study_compat.Trial]:
+ self._assert_gca_resource_is_available()
return list(self._gca_resource.trials)
diff --git a/google/cloud/aiplatform/matching_engine/__init__.py b/google/cloud/aiplatform/matching_engine/__init__.py
new file mode 100644
index 0000000000..4616d01bbd
--- /dev/null
+++ b/google/cloud/aiplatform/matching_engine/__init__.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from google.cloud.aiplatform.matching_engine.matching_engine_index import (
+ MatchingEngineIndex,
+)
+from google.cloud.aiplatform.matching_engine.matching_engine_index_config import (
+ BruteForceConfig as MatchingEngineBruteForceAlgorithmConfig,
+ MatchingEngineIndexConfig as MatchingEngineIndexConfig,
+ TreeAhConfig as MatchingEngineTreeAhAlgorithmConfig,
+)
+from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
+ MatchingEngineIndexEndpoint,
+)
+
+__all__ = (
+ "MatchingEngineIndex",
+ "MatchingEngineIndexEndpoint",
+ "MatchingEngineIndexConfig",
+ "MatchingEngineBruteForceAlgorithmConfig",
+ "MatchingEngineTreeAhAlgorithmConfig",
+)
diff --git a/google/cloud/aiplatform/matching_engine/_protos/match_service.proto b/google/cloud/aiplatform/matching_engine/_protos/match_service.proto
new file mode 100644
index 0000000000..158b0f146a
--- /dev/null
+++ b/google/cloud/aiplatform/matching_engine/_protos/match_service.proto
@@ -0,0 +1,136 @@
+syntax = "proto3";
+
+package google.cloud.aiplatform.container.v1beta1;
+
+import "google/rpc/status.proto";
+
+// MatchService is a Google managed service for efficient vector similarity
+// search at scale.
+service MatchService {
+ // Returns the nearest neighbors for the query. If it is a sharded
+ // deployment, calls the other shards and aggregates the responses.
+ rpc Match(MatchRequest) returns (MatchResponse) {}
+
+ // Returns the nearest neighbors for batch queries. If it is a sharded
+ // deployment, calls the other shards and aggregates the responses.
+ rpc BatchMatch(BatchMatchRequest) returns (BatchMatchResponse) {}
+}
+
+// Parameters for a match query.
+message MatchRequest {
+ // The ID of the DeployedIndex that will serve the request.
+ // This MatchRequest is sent to a specific IndexEndpoint of the Control API,
+ // as per the IndexEndpoint.network. That IndexEndpoint also has
+ // IndexEndpoint.deployed_indexes, and each such index has an
+ // DeployedIndex.id field.
+ // The value of the field below must equal one of the DeployedIndex.id
+ // fields of the IndexEndpoint that is being called for this request.
+ string deployed_index_id = 1;
+
+ // The embedding values.
+ repeated float float_val = 2;
+
+ // The number of nearest neighbors to be retrieved from database for
+ // each query. If not set, will use the default from
+ // the service configuration.
+ int32 num_neighbors = 3;
+
+ // The list of restricts.
+ repeated Namespace restricts = 4;
+
+ // Crowding is a constraint on a neighbor list produced by nearest neighbor
+ // search requiring that no more than some value k' of the k neighbors
+ // returned have the same value of crowding_attribute.
+ // It's used for improving result diversity.
+ // This field is the maximum number of matches with the same crowding tag.
+ int32 per_crowding_attribute_num_neighbors = 5;
+
+ // The number of neighbors to find via approximate search before
+ // exact reordering is performed. If not set, the default value from the
+ // ScaNN config is used; if set, this value must be > 0.
+ int32 approx_num_neighbors = 6;
+
+ // The fraction of the number of leaves to search, set at query time, allows
+ // the user to tune search performance. Increasing this value increases both
+ // search accuracy and latency. The value should be between 0.0 and 1.0. If
+ // not set or set to 0.0, the query uses the default value specified in
+ // NearestNeighborSearchConfig.TreeAHConfig.leaf_nodes_to_search_percent.
+ int32 leaf_nodes_to_search_percent_override = 7;
+}
+
+// Response of a match query.
+message MatchResponse {
+ message Neighbor {
+ // The ids of the matches.
+ string id = 1;
+
+ // The distances of the matches.
+ double distance = 2;
+ }
+ // All its neighbors.
+ repeated Neighbor neighbor = 1;
+}
+
+// Parameters for a batch match query.
+message BatchMatchRequest {
+ // Batched requests against one index.
+ message BatchMatchRequestPerIndex {
+ // The ID of the DeployedIndex that will serve the request.
+ string deployed_index_id = 1;
+
+ // The requests against the index identified by the above deployed_index_id.
+ repeated MatchRequest requests = 2;
+
+ // Selects the optimal batch size to use for low-level batching. Queries
+ // within each low level batch are executed sequentially while low level
+ // batches are executed in parallel.
+ // This field is optional, defaults to 0 if not set. A non-positive number
+ // disables low level batching, i.e. all queries are executed sequentially.
+ int32 low_level_batch_size = 3;
+ }
+
+ // The batch requests grouped by indexes.
+ repeated BatchMatchRequestPerIndex requests = 1;
+}
+
+// Response of a batch match query.
+message BatchMatchResponse {
+ // Batched responses for one index.
+ message BatchMatchResponsePerIndex {
+ // The ID of the DeployedIndex that produced the responses.
+ string deployed_index_id = 1;
+
+ // The match responses produced by the index identified by the above
+ // deployed_index_id. This field is set only when the query against that
+ // index succeeds.
+ repeated MatchResponse responses = 2;
+
+ // The status of response for the batch query identified by the above
+ // deployed_index_id.
+ google.rpc.Status status = 3;
+ }
+
+ // The batched responses grouped by indexes.
+ repeated BatchMatchResponsePerIndex responses = 1;
+}
+
+// Namespace specifies the rules for determining the datapoints that are
+// eligible for each matching query; the overall query is an AND across
+// namespaces.
+message Namespace {
+ // The string name of the namespace that this proto is specifying,
+ // such as "color", "shape", "geo", or "tags".
+ string name = 1;
+
+ // The allowed tokens in the namespace.
+ repeated string allow_tokens = 2;
+
+ // The denied tokens in the namespace.
+ // The denied tokens have exactly the same format as the token fields, but
+ // represent a negation. When a token is denied, matches will be
+ // excluded whenever the other datapoint has that token.
+ //
+ // For example, if a query specifies {color: red, blue, !purple}, then that
+ // query will match datapoints that are red or blue, but if those points are
+ // also purple, then they will be excluded even if they are red/blue.
+ repeated string deny_tokens = 3;
+}
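To make these message shapes concrete, a hedged sketch of constructing a restricted match query with the generated Python bindings from the next file (ids and values invented):

    from google.cloud.aiplatform.matching_engine._protos import match_service_pb2

    request = match_service_pb2.MatchRequest(
        deployed_index_id="my_deployed_index",  # must equal a DeployedIndex.id
        float_val=[0.1, 0.2, 0.3],              # the query embedding
        num_neighbors=10,
        restricts=[
            match_service_pb2.Namespace(
                name="color",
                allow_tokens=["red", "blue"],
                deny_tokens=["purple"],  # the {color: red, blue, !purple} example
            )
        ],
    )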
diff --git a/google/cloud/aiplatform/matching_engine/_protos/match_service_pb2.py b/google/cloud/aiplatform/matching_engine/_protos/match_service_pb2.py
new file mode 100644
index 0000000000..6b3ab988b2
--- /dev/null
+++ b/google/cloud/aiplatform/matching_engine/_protos/match_service_pb2.py
@@ -0,0 +1,158 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: google/cloud/aiplatform/matching_engine/_protos/match_service.proto
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import message as _message
+from google.protobuf import reflection as _reflection
+from google.protobuf import symbol_database as _symbol_database
+
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+from google.rpc import status_pb2 as google_dot_rpc_dot_status__pb2
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+ b'\nCgoogle/cloud/aiplatform/matching_engine/_protos/match_service.proto\x12)google.cloud.aiplatform.container.v1beta1\x1a\x17google/rpc/status.proto"\x97\x02\n\x0cMatchRequest\x12\x19\n\x11\x64\x65ployed_index_id\x18\x01 \x01(\t\x12\x11\n\tfloat_val\x18\x02 \x03(\x02\x12\x15\n\rnum_neighbors\x18\x03 \x01(\x05\x12G\n\trestricts\x18\x04 \x03(\x0b\x32\x34.google.cloud.aiplatform.container.v1beta1.Namespace\x12,\n$per_crowding_attribute_num_neighbors\x18\x05 \x01(\x05\x12\x1c\n\x14\x61pprox_num_neighbors\x18\x06 \x01(\x05\x12-\n%leaf_nodes_to_search_percent_override\x18\x07 \x01(\x05"\x8e\x01\n\rMatchResponse\x12S\n\x08neighbor\x18\x01 \x03(\x0b\x32\x41.google.cloud.aiplatform.container.v1beta1.MatchResponse.Neighbor\x1a(\n\x08Neighbor\x12\n\n\x02id\x18\x01 \x01(\t\x12\x10\n\x08\x64istance\x18\x02 \x01(\x01"\x9f\x02\n\x11\x42\x61tchMatchRequest\x12h\n\x08requests\x18\x01 \x03(\x0b\x32V.google.cloud.aiplatform.container.v1beta1.BatchMatchRequest.BatchMatchRequestPerIndex\x1a\x9f\x01\n\x19\x42\x61tchMatchRequestPerIndex\x12\x19\n\x11\x64\x65ployed_index_id\x18\x01 \x01(\t\x12I\n\x08requests\x18\x02 \x03(\x0b\x32\x37.google.cloud.aiplatform.container.v1beta1.MatchRequest\x12\x1c\n\x14low_level_batch_size\x18\x03 \x01(\x05"\xac\x02\n\x12\x42\x61tchMatchResponse\x12k\n\tresponses\x18\x01 \x03(\x0b\x32X.google.cloud.aiplatform.container.v1beta1.BatchMatchResponse.BatchMatchResponsePerIndex\x1a\xa8\x01\n\x1a\x42\x61tchMatchResponsePerIndex\x12\x19\n\x11\x64\x65ployed_index_id\x18\x01 \x01(\t\x12K\n\tresponses\x18\x02 \x03(\x0b\x32\x38.google.cloud.aiplatform.container.v1beta1.MatchResponse\x12"\n\x06status\x18\x03 \x01(\x0b\x32\x12.google.rpc.Status"D\n\tNamespace\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x14\n\x0c\x61llow_tokens\x18\x02 \x03(\t\x12\x13\n\x0b\x64\x65ny_tokens\x18\x03 \x03(\t2\x9a\x02\n\x0cMatchService\x12|\n\x05Match\x12\x37.google.cloud.aiplatform.container.v1beta1.MatchRequest\x1a\x38.google.cloud.aiplatform.container.v1beta1.MatchResponse"\x00\x12\x8b\x01\n\nBatchMatch\x12<.google.cloud.aiplatform.container.v1beta1.BatchMatchRequest\x1a=.google.cloud.aiplatform.container.v1beta1.BatchMatchResponse"\x00\x62\x06proto3'
+)
+
+
+_MATCHREQUEST = DESCRIPTOR.message_types_by_name["MatchRequest"]
+_MATCHRESPONSE = DESCRIPTOR.message_types_by_name["MatchResponse"]
+_MATCHRESPONSE_NEIGHBOR = _MATCHRESPONSE.nested_types_by_name["Neighbor"]
+_BATCHMATCHREQUEST = DESCRIPTOR.message_types_by_name["BatchMatchRequest"]
+_BATCHMATCHREQUEST_BATCHMATCHREQUESTPERINDEX = _BATCHMATCHREQUEST.nested_types_by_name[
+ "BatchMatchRequestPerIndex"
+]
+_BATCHMATCHRESPONSE = DESCRIPTOR.message_types_by_name["BatchMatchResponse"]
+_BATCHMATCHRESPONSE_BATCHMATCHRESPONSEPERINDEX = (
+ _BATCHMATCHRESPONSE.nested_types_by_name["BatchMatchResponsePerIndex"]
+)
+_NAMESPACE = DESCRIPTOR.message_types_by_name["Namespace"]
+MatchRequest = _reflection.GeneratedProtocolMessageType(
+ "MatchRequest",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _MATCHREQUEST,
+ "__module__": "google.cloud.aiplatform.matching_engine._protos.match_service_pb2"
+ # @@protoc_insertion_point(class_scope:google.cloud.aiplatform.container.v1beta1.MatchRequest)
+ },
+)
+_sym_db.RegisterMessage(MatchRequest)
+
+MatchResponse = _reflection.GeneratedProtocolMessageType(
+ "MatchResponse",
+ (_message.Message,),
+ {
+ "Neighbor": _reflection.GeneratedProtocolMessageType(
+ "Neighbor",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _MATCHRESPONSE_NEIGHBOR,
+ "__module__": "google.cloud.aiplatform.matching_engine._protos.match_service_pb2"
+ # @@protoc_insertion_point(class_scope:google.cloud.aiplatform.container.v1beta1.MatchResponse.Neighbor)
+ },
+ ),
+ "DESCRIPTOR": _MATCHRESPONSE,
+ "__module__": "google.cloud.aiplatform.matching_engine._protos.match_service_pb2"
+ # @@protoc_insertion_point(class_scope:google.cloud.aiplatform.container.v1beta1.MatchResponse)
+ },
+)
+_sym_db.RegisterMessage(MatchResponse)
+_sym_db.RegisterMessage(MatchResponse.Neighbor)
+
+BatchMatchRequest = _reflection.GeneratedProtocolMessageType(
+ "BatchMatchRequest",
+ (_message.Message,),
+ {
+ "BatchMatchRequestPerIndex": _reflection.GeneratedProtocolMessageType(
+ "BatchMatchRequestPerIndex",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _BATCHMATCHREQUEST_BATCHMATCHREQUESTPERINDEX,
+ "__module__": "google.cloud.aiplatform.matching_engine._protos.match_service_pb2"
+ # @@protoc_insertion_point(class_scope:google.cloud.aiplatform.container.v1beta1.BatchMatchRequest.BatchMatchRequestPerIndex)
+ },
+ ),
+ "DESCRIPTOR": _BATCHMATCHREQUEST,
+ "__module__": "google.cloud.aiplatform.matching_engine._protos.match_service_pb2"
+ # @@protoc_insertion_point(class_scope:google.cloud.aiplatform.container.v1beta1.BatchMatchRequest)
+ },
+)
+_sym_db.RegisterMessage(BatchMatchRequest)
+_sym_db.RegisterMessage(BatchMatchRequest.BatchMatchRequestPerIndex)
+
+BatchMatchResponse = _reflection.GeneratedProtocolMessageType(
+ "BatchMatchResponse",
+ (_message.Message,),
+ {
+ "BatchMatchResponsePerIndex": _reflection.GeneratedProtocolMessageType(
+ "BatchMatchResponsePerIndex",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _BATCHMATCHRESPONSE_BATCHMATCHRESPONSEPERINDEX,
+ "__module__": "google.cloud.aiplatform.matching_engine._protos.match_service_pb2"
+ # @@protoc_insertion_point(class_scope:google.cloud.aiplatform.container.v1beta1.BatchMatchResponse.BatchMatchResponsePerIndex)
+ },
+ ),
+ "DESCRIPTOR": _BATCHMATCHRESPONSE,
+ "__module__": "google.cloud.aiplatform.matching_engine._protos.match_service_pb2"
+ # @@protoc_insertion_point(class_scope:google.cloud.aiplatform.container.v1beta1.BatchMatchResponse)
+ },
+)
+_sym_db.RegisterMessage(BatchMatchResponse)
+_sym_db.RegisterMessage(BatchMatchResponse.BatchMatchResponsePerIndex)
+
+Namespace = _reflection.GeneratedProtocolMessageType(
+ "Namespace",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _NAMESPACE,
+ "__module__": "google.cloud.aiplatform.matching_engine._protos.match_service_pb2"
+ # @@protoc_insertion_point(class_scope:google.cloud.aiplatform.container.v1beta1.Namespace)
+ },
+)
+_sym_db.RegisterMessage(Namespace)
+
+_MATCHSERVICE = DESCRIPTOR.services_by_name["MatchService"]
+if _descriptor._USE_C_DESCRIPTORS == False:
+
+ DESCRIPTOR._options = None
+ _MATCHREQUEST._serialized_start = 140
+ _MATCHREQUEST._serialized_end = 419
+ _MATCHRESPONSE._serialized_start = 422
+ _MATCHRESPONSE._serialized_end = 564
+ _MATCHRESPONSE_NEIGHBOR._serialized_start = 524
+ _MATCHRESPONSE_NEIGHBOR._serialized_end = 564
+ _BATCHMATCHREQUEST._serialized_start = 567
+ _BATCHMATCHREQUEST._serialized_end = 854
+ _BATCHMATCHREQUEST_BATCHMATCHREQUESTPERINDEX._serialized_start = 695
+ _BATCHMATCHREQUEST_BATCHMATCHREQUESTPERINDEX._serialized_end = 854
+ _BATCHMATCHRESPONSE._serialized_start = 857
+ _BATCHMATCHRESPONSE._serialized_end = 1157
+ _BATCHMATCHRESPONSE_BATCHMATCHRESPONSEPERINDEX._serialized_start = 989
+ _BATCHMATCHRESPONSE_BATCHMATCHRESPONSEPERINDEX._serialized_end = 1157
+ _NAMESPACE._serialized_start = 1159
+ _NAMESPACE._serialized_end = 1227
+ _MATCHSERVICE._serialized_start = 1230
+ _MATCHSERVICE._serialized_end = 1512
+# @@protoc_insertion_point(module_scope)
diff --git a/google/cloud/aiplatform/matching_engine/_protos/match_service_pb2_grpc.py b/google/cloud/aiplatform/matching_engine/_protos/match_service_pb2_grpc.py
new file mode 100644
index 0000000000..2c0c14f8ed
--- /dev/null
+++ b/google/cloud/aiplatform/matching_engine/_protos/match_service_pb2_grpc.py
@@ -0,0 +1,166 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+
+"""Client and server classes corresponding to protobuf-defined services."""
+from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
+
+import grpc
+
+
+class MatchServiceStub(object):
+ """MatchService is a Google managed service for efficient vector similarity
+ search at scale.
+ """
+
+ def __init__(self, channel):
+ """Constructor.
+
+ Args:
+ channel: A grpc.Channel.
+ """
+ self.Match = channel.unary_unary(
+ "/google.cloud.aiplatform.container.v1beta1.MatchService/Match",
+ request_serializer=match_service_pb2.MatchRequest.SerializeToString,
+ response_deserializer=match_service_pb2.MatchResponse.FromString,
+ )
+ self.BatchMatch = channel.unary_unary(
+ "/google.cloud.aiplatform.container.v1beta1.MatchService/BatchMatch",
+ request_serializer=match_service_pb2.BatchMatchRequest.SerializeToString,
+ response_deserializer=match_service_pb2.BatchMatchResponse.FromString,
+ )
+
+
+class MatchServiceServicer(object):
+ """MatchService is a Google managed service for efficient vector similarity
+ search at scale.
+ """
+
+ def Match(self, request, context):
+ """Returns the nearest neighbors for the query. If it is a sharded
+ deployment, calls the other shards and aggregates the responses.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
+
+ def BatchMatch(self, request, context):
+ """Returns the nearest neighbors for batch queries. If it is a sharded
+ deployment, calls the other shards and aggregates the responses.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
+
+
+def add_MatchServiceServicer_to_server(servicer, server):
+ rpc_method_handlers = {
+ "Match": grpc.unary_unary_rpc_method_handler(
+ servicer.Match,
+ request_deserializer=match_service_pb2.MatchRequest.FromString,
+ response_serializer=match_service_pb2.MatchResponse.SerializeToString,
+ ),
+ "BatchMatch": grpc.unary_unary_rpc_method_handler(
+ servicer.BatchMatch,
+ request_deserializer=match_service_pb2.BatchMatchRequest.FromString,
+ response_serializer=match_service_pb2.BatchMatchResponse.SerializeToString,
+ ),
+ }
+ generic_handler = grpc.method_handlers_generic_handler(
+ "google.cloud.aiplatform.container.v1beta1.MatchService", rpc_method_handlers
+ )
+ server.add_generic_rpc_handlers((generic_handler,))
+
+
+# This class is part of an EXPERIMENTAL API.
+class MatchService(object):
+ """MatchService is a Google managed service for efficient vector similarity
+ search at scale.
+ """
+
+ @staticmethod
+ def Match(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ "/google.cloud.aiplatform.container.v1beta1.MatchService/Match",
+ match_service_pb2.MatchRequest.SerializeToString,
+ match_service_pb2.MatchResponse.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
+
+ @staticmethod
+ def BatchMatch(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ "/google.cloud.aiplatform.container.v1beta1.MatchService/BatchMatch",
+ match_service_pb2.BatchMatchRequest.SerializeToString,
+ match_service_pb2.BatchMatchResponse.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
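An end-to-end hedged sketch of calling the service through the stub; the address is deployment-specific (typically the deployed index endpoint's private IP on the peered VPC) and invented here:

    import grpc

    from google.cloud.aiplatform.matching_engine._protos import (
        match_service_pb2,
        match_service_pb2_grpc,
    )

    channel = grpc.insecure_channel("10.0.0.2:10000")  # placeholder address
    stub = match_service_pb2_grpc.MatchServiceStub(channel)

    response = stub.Match(
        match_service_pb2.MatchRequest(
            deployed_index_id="my_deployed_index",
            float_val=[0.1, 0.2, 0.3],
            num_neighbors=5,
        )
    )
    for neighbor in response.neighbor:
        print(neighbor.id, neighbor.distance)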
diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index.py b/google/cloud/aiplatform/matching_engine/matching_engine_index.py
new file mode 100644
index 0000000000..d382a126f1
--- /dev/null
+++ b/google/cloud/aiplatform/matching_engine/matching_engine_index.py
@@ -0,0 +1,606 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Dict, List, Optional, Sequence, Tuple
+
+from google.auth import credentials as auth_credentials
+from google.protobuf import field_mask_pb2
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform.compat.types import (
+ matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref,
+ matching_engine_index as gca_matching_engine_index,
+)
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform.matching_engine import matching_engine_index_config
+from google.cloud.aiplatform import utils
+
+_LOGGER = base.Logger(__name__)
+
+
+class MatchingEngineIndex(base.VertexAiResourceNounWithFutureManager):
+ """Matching Engine index resource for Vertex AI."""
+
+ client_class = utils.IndexClientWithOverride
+
+ _resource_noun = "indexes"
+ _getter_method = "get_index"
+ _list_method = "list_indexes"
+ _delete_method = "delete_index"
+ _parse_resource_name_method = "parse_index_path"
+ _format_resource_name_method = "index_path"
+
+ def __init__(
+ self,
+ index_name: str,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ):
+ """Retrieves an existing index given an index name or ID.
+
+ Example Usage:
+
+ my_index = aiplatform.MatchingEngineIndex(
+ index_name='projects/123/locations/us-central1/indexes/my_index_id'
+ )
+ or
+ my_index = aiplatform.MatchingEngineIndex(
+ index_name='my_index_id'
+ )
+
+ Args:
+ index_name (str):
+ Required. A fully-qualified index resource name or an index ID.
+ Example: "projects/123/locations/us-central1/indexes/my_index_id"
+ or "my_index_id" when project and location are initialized or passed.
+ project (str):
+ Optional. Project to retrieve index from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve index from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve this Index. Overrides
+ credentials set in aiplatform.init.
+ """
+
+ super().__init__(
+ project=project,
+ location=location,
+ credentials=credentials,
+ resource_name=index_name,
+ )
+ self._gca_resource = self._get_gca_resource(resource_name=index_name)
+
+ @property
+ def description(self) -> str:
+ """Description of the index."""
+ self._assert_gca_resource_is_available()
+ return self._gca_resource.description
+
+ @classmethod
+ @base.optional_sync()
+ def _create(
+ cls,
+ display_name: str,
+ contents_delta_uri: str,
+ config: matching_engine_index_config.MatchingEngineIndexConfig,
+ description: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ sync: bool = True,
+ ) -> "MatchingEngineIndex":
+ """Creates a MatchingEngineIndex resource.
+
+ Args:
+ display_name (str):
+ Required. The display name of the Index.
+ The name can be up to 128 characters long and
+ can consist of any UTF-8 characters.
+ contents_delta_uri (str):
+ Required. Allows inserting, updating or deleting the contents of the Matching Engine Index.
+ The string must be a valid Google Cloud Storage directory path. If this
+ field is set when calling IndexService.UpdateIndex, then no other
+ Index field can also be updated as part of the same call.
+ The expected structure and format of the files this URI points to is
+ described at
+ https://docs.google.com/document/d/12DLVB6Nq6rdv8grxfBsPhUA283KWrQ9ZenPBp0zUC30
+ config (matching_engine_index_config.MatchingEngineIndexConfig):
+ Required. The configuration with regard to the algorithms used for efficient search.
+ description (str):
+ Optional. The description of the Index.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined
+ metadata to organize your Index.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ on and examples of labels. No more than 64 user
+ labels can be associated with one
+ Index (System labels are excluded).
+ System reserved label keys are prefixed with
+ "aiplatform.googleapis.com/" and are immutable.
+ project (str):
+ Optional. Project to create Index in. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to create Index in. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to create this Index. Overrides
+ credentials set in aiplatform.init.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ sync (bool):
+ Optional. Whether to execute this creation synchronously. If False, this method
+ will be executed in concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
+
+ Returns:
+ MatchingEngineIndex - Index resource object
+
+ """
+ gapic_index = gca_matching_engine_index.Index(
+ display_name=display_name,
+ description=description,
+ metadata={
+ "config": config.as_dict(),
+ "contentsDeltaUri": contents_delta_uri,
+ },
+ )
+
+ if labels:
+ utils.validate_labels(labels)
+ gapic_index.labels = labels
+
+ api_client = cls._instantiate_client(location=location, credentials=credentials)
+
+ create_lro = api_client.create_index(
+ parent=initializer.global_config.common_location_path(
+ project=project, location=location
+ ),
+ index=gapic_index,
+ metadata=request_metadata,
+ )
+
+ _LOGGER.log_create_with_lro(cls, create_lro)
+
+ created_index = create_lro.result()
+
+ _LOGGER.log_create_complete(cls, created_index, "index")
+
+ index_obj = cls(
+ index_name=created_index.name,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+ return index_obj
+
+ def update_metadata(
+ self,
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ ) -> "MatchingEngineIndex":
+ """Updates the metadata for this index.
+
+ Args:
+ display_name (str):
+ Optional. The display name of the Index.
+ The name can be up to 128 characters long and
+ can consist of any UTF-8 characters.
+ description (str):
+ Optional. The description of the Index.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined
+ metadata to organize your Indexes.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ on and examples of labels. No more than 64 user
+ labels can be associated with one Index
+ (System labels are excluded)."
+ System reserved label keys are prefixed with
+ "aiplatform.googleapis.com/" and are immutable.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+
+ Returns:
+ MatchingEngineIndex - The updated index resource object.
+ """
+
+ self.wait()
+
+ update_mask = list()
+
+ if labels:
+ utils.validate_labels(labels)
+ update_mask.append("labels")
+
+ if display_name is not None:
+ update_mask.append("display_name")
+
+ if description is not None:
+ update_mask.append("description")
+
+ update_mask = field_mask_pb2.FieldMask(paths=update_mask)
+
+ gapic_index = gca_matching_engine_index.Index(
+ name=self.resource_name,
+ display_name=display_name,
+ description=description,
+ labels=labels,
+ )
+
+ _LOGGER.log_action_start_against_resource(
+ "Updating",
+ "index",
+ self,
+ )
+
+ update_lro = self.api_client.update_index(
+ index=gapic_index,
+ update_mask=update_mask,
+ metadata=request_metadata,
+ )
+
+ _LOGGER.log_action_started_against_resource_with_lro(
+ "Update", "index", self.__class__, update_lro
+ )
+
+ self._gca_resource = update_lro.result()
+
+ _LOGGER.log_action_completed_against_resource("index", "Updated", self)
+
+ return self
+
+ def update_embeddings(
+ self,
+ contents_delta_uri: str,
+ is_complete_overwrite: Optional[bool] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ ) -> "MatchingEngineIndex":
+ """Updates the embeddings for this index.
+
+ Args:
+ contents_delta_uri (str):
+ Required. Allows inserting, updating or deleting the contents of the Matching Engine Index.
+ The string must be a valid Google Cloud Storage directory path. If this
+ field is set when calling IndexService.UpdateIndex, then no other
+ Index field can also be updated as part of the same call.
+ The expected structure and format of the files this URI points to is
+ described at
+ https://docs.google.com/document/d/12DLVB6Nq6rdv8grxfBsPhUA283KWrQ9ZenPBp0zUC30
+ is_complete_overwrite (bool):
+ Optional. If this field is set together with contentsDeltaUri when calling IndexService.UpdateIndex,
+ then existing content of the Index will be replaced by the data from the contentsDeltaUri.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+
+ Returns:
+ MatchingEngineIndex - The updated index resource object.
+ """
+
+ self.wait()
+
+ update_mask = list()
+
+ if contents_delta_uri or is_complete_overwrite:
+ update_mask.append("metadata")
+
+ update_mask = field_mask_pb2.FieldMask(paths=update_mask)
+
+ gapic_index = gca_matching_engine_index.Index(
+ name=self.resource_name,
+ metadata={
+ "contentsDeltaUri": contents_delta_uri,
+ "isCompleteOverwrite": is_complete_overwrite,
+ },
+ )
+
+ _LOGGER.log_action_start_against_resource(
+ "Updating",
+ "index",
+ self,
+ )
+
+ update_lro = self.api_client.update_index(
+ index=gapic_index,
+ update_mask=update_mask,
+ metadata=request_metadata,
+ )
+
+ _LOGGER.log_action_started_against_resource_with_lro(
+ "Update", "index", self.__class__, update_lro
+ )
+
+ self._gca_resource = update_lro.result()
+
+ _LOGGER.log_action_completed_against_resource("index", "Updated", self)
+
+ return self
+
+ @property
+ def deployed_indexes(
+ self,
+ ) -> List[gca_matching_engine_deployed_index_ref.DeployedIndexRef]:
+ """Returns a list of deployed index references that originate from this index.
+
+ Returns:
+ List[gca_matching_engine_deployed_index_ref.DeployedIndexRef] - Deployed index references
+ """
+
+ self.wait()
+
+ return self._gca_resource.deployed_indexes
+
+ @classmethod
+ def create_tree_ah_index(
+ cls,
+ display_name: str,
+ contents_delta_uri: str,
+ dimensions: int,
+ approximate_neighbors_count: int,
+ leaf_node_embedding_count: Optional[int] = None,
+ leaf_nodes_to_search_percent: Optional[float] = None,
+ distance_measure_type: Optional[
+ matching_engine_index_config.DistanceMeasureType
+ ] = None,
+ description: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ sync: bool = True,
+ ) -> "MatchingEngineIndex":
+ """Creates a MatchingEngineIndex resource that uses the tree-AH algorithm.
+
+ Example Usage:
+
+ my_index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
+ display_name="my_display_name",
+ contents_delta_uri="gs://my_bucket/embeddings",
+ dimensions=1,
+ approximate_neighbors_count=150,
+ distance_measure_type="SQUARED_L2_DISTANCE",
+ leaf_node_embedding_count=100,
+ leaf_nodes_to_search_percent=50,
+ description="my description",
+ labels={ "label_name": "label_value" },
+ )
+
+ Args:
+ display_name (str):
+ Required. The display name of the Index.
+ The name can be up to 128 characters long and
+ can consist of any UTF-8 characters.
+ contents_delta_uri (str):
+ Required. Allows inserting, updating or deleting the contents of the Matching Engine Index.
+ The string must be a valid Google Cloud Storage directory path. If this
+ field is set when calling IndexService.UpdateIndex, then no other
+ Index field can also be updated as part of the same call.
+ The expected structure and format of the files this URI points to is
+ described at
+ https://docs.google.com/document/d/12DLVB6Nq6rdv8grxfBsPhUA283KWrQ9ZenPBp0zUC30
+ dimensions (int):
+ Required. The number of dimensions of the input vectors.
+ approximate_neighbors_count (int):
+ Required. The default number of neighbors to find via approximate search before exact reordering is
+ performed. Exact reordering is a procedure where results returned by an
+ approximate search algorithm are reordered via a more expensive distance computation.
+ leaf_node_embedding_count (int):
+ Optional. Number of embeddings on each leaf node. The default value is 1000 if not set.
+ leaf_nodes_to_search_percent (float):
+ Optional. The default percentage of leaf nodes that any query may search. Must be in
+ range 1-100, inclusive. The default value is 10 (meaning 10%) if not set.
+ distance_measure_type (matching_engine_index_config.DistanceMeasureType):
+ Optional. The distance measure used in nearest neighbor search.
+ description (str):
+ Optional. The description of the Index.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined
+ metadata to organize your Index.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ on and examples of labels. No more than 64 user
+ labels can be associated with one
+ Index (System labels are excluded).
+ System reserved label keys are prefixed with
+ "aiplatform.googleapis.com/" and are immutable.
+ project (str):
+ Optional. Project to create Index in. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to create Index in. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to create this Index. Overrides
+ credentials set in aiplatform.init.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ sync (bool):
+ Optional. Whether to execute this creation synchronously. If False, this method
+ will be executed in concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
+
+ Returns:
+ MatchingEngineIndex - Index resource object
+
+ """
+
+ algorithm_config = matching_engine_index_config.TreeAhConfig(
+ leaf_node_embedding_count=leaf_node_embedding_count,
+ leaf_nodes_to_search_percent=leaf_nodes_to_search_percent,
+ )
+
+ config = matching_engine_index_config.MatchingEngineIndexConfig(
+ dimensions=dimensions,
+ algorithm_config=algorithm_config,
+ approximate_neighbors_count=approximate_neighbors_count,
+ distance_measure_type=distance_measure_type,
+ )
+
+ return cls._create(
+ display_name=display_name,
+ contents_delta_uri=contents_delta_uri,
+ config=config,
+ description=description,
+ labels=labels,
+ project=project,
+ location=location,
+ credentials=credentials,
+ request_metadata=request_metadata,
+ sync=sync,
+ )
+
+ @classmethod
+ def create_brute_force_index(
+ cls,
+ display_name: str,
+ contents_delta_uri: str,
+ dimensions: int,
+ distance_measure_type: Optional[
+ matching_engine_index_config.DistanceMeasureType
+ ] = None,
+ description: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ sync: bool = True,
+ ) -> "MatchingEngineIndex":
+ """Creates a MatchingEngineIndex resource that uses the brute force algorithm.
+
+ Example Usage:
+
+ my_index = aiplatform.MatchingEngineIndex.create_brute_force_index(
+ display_name="my_display_name",
+ contents_delta_uri="gs://my_bucket/embeddings",
+ dimensions=1,
+ distance_measure_type="SQUARED_L2_DISTANCE",
+ description="my description",
+ labels={ "label_name": "label_value" },
+ )
+
+ Args:
+ display_name (str):
+ Required. The display name of the Index.
+ The name can be up to 128 characters long and
+ can consist of any UTF-8 characters.
+ contents_delta_uri (str):
+ Required. Allows inserting, updating or deleting the contents of the Matching Engine Index.
+ The string must be a valid Google Cloud Storage directory path. If this
+ field is set when calling IndexService.UpdateIndex, then no other
+ Index field can also be updated as part of the same call.
+ The expected structure and format of the files this URI points to is
+ described at
+ https://docs.google.com/document/d/12DLVB6Nq6rdv8grxfBsPhUA283KWrQ9ZenPBp0zUC30
+ dimensions (int):
+ Required. The number of dimensions of the input vectors.
+ distance_measure_type (matching_engine_index_config.DistanceMeasureType):
+ Optional. The distance measure used in nearest neighbor search.
+ description (str):
+ Optional. The description of the Index.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined
+ metadata to organize your Index.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ on and examples of labels. No more than 64 user
+ labels can be associated with one
+ Index (System labels are excluded).
+ System reserved label keys are prefixed with
+ "aiplatform.googleapis.com/" and are immutable.
+ project (str):
+ Optional. Project to create Index in. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to create Index in. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to create this Index. Overrides
+ credentials set in aiplatform.init.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ sync (bool):
+ Optional. Whether to execute this creation synchronously. If False, this method
+ will be executed in concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
+
+ Returns:
+ MatchingEngineIndex - Index resource object
+
+ """
+
+ algorithm_config = matching_engine_index_config.BruteForceConfig()
+
+ config = matching_engine_index_config.MatchingEngineIndexConfig(
+ dimensions=dimensions,
+ algorithm_config=algorithm_config,
+ distance_measure_type=distance_measure_type,
+ )
+
+ return cls._create(
+ display_name=display_name,
+ contents_delta_uri=contents_delta_uri,
+ config=config,
+ description=description,
+ labels=labels,
+ project=project,
+ location=location,
+ credentials=credentials,
+ request_metadata=request_metadata,
+ sync=sync,
+ )
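Putting the docstring examples together, a typical create-and-update flow for the new class looks like the sketch below; the project, bucket, and dimension values are placeholders, and the top-level aiplatform.MatchingEngineIndex alias is assumed to be exported as the docstrings above suggest:

    from google.cloud import aiplatform

    aiplatform.init(project="my-project", location="us-central1")

    # Build a tree-AH index from embedding files stored in Cloud Storage.
    my_index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
        display_name="my_display_name",
        contents_delta_uri="gs://my_bucket/embeddings",
        dimensions=128,
        approximate_neighbors_count=150,
        leaf_node_embedding_count=1000,
        leaf_nodes_to_search_percent=10,
    )

    # Later, point the index at new or changed embedding files.
    my_index = my_index.update_embeddings(
        contents_delta_uri="gs://my_bucket/embeddings_update",
        is_complete_overwrite=False,
    )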
diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_config.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_config.py
new file mode 100644
index 0000000000..d0552f8de1
--- /dev/null
+++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_config.py
@@ -0,0 +1,142 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import abc
+import enum
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+
+# This file mirrors the configuration options as defined in gs://google-cloud-aiplatform/schema/matchingengine/metadata/nearest_neighbor_search_1.0.0.yaml
+class DistanceMeasureType(enum.Enum):
+ """The distance measure used in nearest neighbor search."""
+
+ # Dot Product Distance. Defined as the negative of the dot product.
+ DOT_PRODUCT_DISTANCE = "DOT_PRODUCT_DISTANCE"
+ # Euclidean (L_2) Distance
+ SQUARED_L2_DISTANCE = "SQUARED_L2_DISTANCE"
+ # Manhattan (L_1) Distance
+ L1_DISTANCE = "L1_DISTANCE"
+ # Cosine Distance. Defined as 1 - cosine similarity.
+ COSINE_DISTANCE = "COSINE_DISTANCE"
+
+
+class FeatureNormType(enum.Enum):
+ """Type of normalization to be carried out on each vector."""
+
+ # Unit L2 normalization type.
+ UNIT_L2_NORM = "UNIT_L2_NORM"
+ # No normalization type is specified.
+ NONE = "NONE"
+
+
+class AlgorithmConfig(abc.ABC):
+ """Base class for configuration options for matching algorithm."""
+
+ @abc.abstractmethod
+ def as_dict(self) -> Dict:
+ """Returns the configuration as a dictionary.
+
+ Returns:
+ Dict[str, Any]
+ """
+ pass
+
+
+@dataclass
+class TreeAhConfig(AlgorithmConfig):
+ """Configuration options for using the tree-AH algorithm (Shallow tree + Asymmetric Hashing).
+ Please refer to this paper for more details: https://arxiv.org/abs/1908.10396
+
+ Args:
+ leaf_node_embedding_count (int):
+ Optional. Number of embeddings on each leaf node. The default value is 1000 if not set.
+ leaf_nodes_to_search_percent (float):
+ Optional. The default percentage of leaf nodes that any query may search. Must be in
+ range 1-100, inclusive. The default value is 10 (meaning 10%) if not set.
+ """
+
+ leaf_node_embedding_count: Optional[int] = None
+ leaf_nodes_to_search_percent: Optional[float] = None
+
+ def as_dict(self) -> Dict:
+ """Returns the configuration as a dictionary.
+
+ Returns:
+ Dict[str, Any]
+ """
+
+ return {
+ "treeAhConfig": {
+ "leafNodeEmbeddingCount": self.leaf_node_embedding_count,
+ "leafNodesToSearchPercent": self.leaf_nodes_to_search_percent,
+ }
+ }
+
+
+@dataclass
+class BruteForceConfig(AlgorithmConfig):
+ """Configuration options for using brute force search, which simply
+ implements the standard linear search in the database for each query.
+ """
+
+ def as_dict(self) -> Dict:
+ """Returns the configuration as a dictionary.
+
+ Returns:
+ Dict[str, Any]
+ """
+ return {"bruteForceConfig": {}}
+
+
+@dataclass
+class MatchingEngineIndexConfig:
+ """Configuration options for using the tree-AH algorithm (Shallow tree + Asymmetric Hashing).
+ Please refer to this paper for more details: https://arxiv.org/abs/1908.10396
+
+ Args:
+ dimensions (int):
+ Required. The number of dimensions of the input vectors.
+ algorithm_config (AlgorithmConfig):
+ Required. The configuration with regard to the algorithms used for efficient search.
+ approximate_neighbors_count (int):
+ Optional. The default number of neighbors to find via approximate search before exact reordering is
+ performed. Exact reordering is a procedure where results returned by an
+ approximate search algorithm are reordered via a more expensive distance computation.
+
+ Required if tree-AH algorithm is used.
+ distance_measure_type (DistanceMeasureType):
+ Optional. The distance measure used in nearest neighbor search.
+ """
+
+ dimensions: int
+ algorithm_config: AlgorithmConfig
+ approximate_neighbors_count: Optional[int] = None
+ distance_measure_type: Optional[DistanceMeasureType] = None
+
+ def as_dict(self) -> Dict[str, Any]:
+ """Returns the configuration as a dictionary.
+
+ Returns:
+ Dict[str, Any]
+ """
+
+ return {
+ "dimensions": self.dimensions,
+ "algorithmConfig": self.algorithm_config.as_dict(),
+ "approximateNeighborsCount": self.approximate_neighbors_count,
+ "distanceMeasureType": self.distance_measure_type,
+ }
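To make the mapping concrete, the camelCase dictionary these dataclasses produce is what ends up in the Index's metadata field. A sketch for a tree-AH configuration (the docstring examples above pass the enum's string value directly, which is what serializes cleanly into the metadata proto):

    from google.cloud.aiplatform.matching_engine import matching_engine_index_config

    config = matching_engine_index_config.MatchingEngineIndexConfig(
        dimensions=128,
        algorithm_config=matching_engine_index_config.TreeAhConfig(
            leaf_node_embedding_count=1000,
            leaf_nodes_to_search_percent=10,
        ),
        approximate_neighbors_count=150,
        distance_measure_type="DOT_PRODUCT_DISTANCE",
    )

    # config.as_dict() returns:
    # {
    #     "dimensions": 128,
    #     "algorithmConfig": {
    #         "treeAhConfig": {
    #             "leafNodeEmbeddingCount": 1000,
    #             "leafNodesToSearchPercent": 10,
    #         }
    #     },
    #     "approximateNeighborsCount": 150,
    #     "distanceMeasureType": "DOT_PRODUCT_DISTANCE",
    # }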
diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py
new file mode 100644
index 0000000000..da155496ae
--- /dev/null
+++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py
@@ -0,0 +1,861 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Sequence, Tuple
+
+from google.auth import credentials as auth_credentials
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform import matching_engine
+from google.cloud.aiplatform import utils
+from google.cloud.aiplatform.compat.types import (
+ machine_resources as gca_machine_resources_compat,
+ matching_engine_index_endpoint as gca_matching_engine_index_endpoint,
+)
+from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
+from google.cloud.aiplatform.matching_engine._protos import match_service_pb2_grpc
+from google.protobuf import field_mask_pb2
+
+import grpc
+
+_LOGGER = base.Logger(__name__)
+
+
+@dataclass
+class MatchNeighbor:
+ """The id and distance of a nearest neighbor match for a given query embedding.
+
+ Args:
+ id (str):
+ Required. The id of the neighbor.
+ distance (float):
+ Required. The distance to the query embedding.
+ """
+
+ id: str
+ distance: float
+
+
+class MatchingEngineIndexEndpoint(base.VertexAiResourceNounWithFutureManager):
+ """Matching Engine index endpoint resource for Vertex AI."""
+
+ client_class = utils.IndexEndpointClientWithOverride
+
+ _resource_noun = "indexEndpoints"
+ _getter_method = "get_index_endpoint"
+ _list_method = "list_index_endpoints"
+ _delete_method = "delete_index_endpoint"
+ _parse_resource_name_method = "parse_index_endpoint_path"
+ _format_resource_name_method = "index_endpoint_path"
+
+ def __init__(
+ self,
+ index_endpoint_name: str,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ):
+ """Retrieves an existing index endpoint given a name or ID.
+
+ Example Usage:
+
+ my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
+ index_endpoint_name='projects/123/locations/us-central1/indexEndpoints/my_index_endpoint_id'
+ )
+ or
+ my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
+ index_endpoint_name='my_index_endpoint_id'
+ )
+
+ Args:
+ index_endpoint_name (str):
+ Required. A fully-qualified index endpoint resource name or an index endpoint ID.
+ Example: "projects/123/locations/us-central1/indexEndpoints/my_index_endpoint_id"
+ or "my_index_endpoint_id" when project and location are initialized or passed.
+ project (str):
+ Optional. Project to retrieve index endpoint from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve index endpoint from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve this IndexEndpoint. Overrides
+ credentials set in aiplatform.init.
+ """
+
+ super().__init__(
+ project=project,
+ location=location,
+ credentials=credentials,
+ resource_name=index_endpoint_name,
+ )
+ self._gca_resource = self._get_gca_resource(resource_name=index_endpoint_name)
+
+ @classmethod
+ @base.optional_sync()
+ def create(
+ cls,
+ display_name: str,
+ network: str,
+ description: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ sync: bool = True,
+ ) -> "MatchingEngineIndexEndpoint":
+ """Creates a MatchingEngineIndexEndpoint resource.
+
+ Example Usage:
+
+ my_index_endpoint = aiplatform.IndexEndpoint.create(
+ display_name='my_endpoint',
+ )
+
+ Args:
+ display_name (str):
+ Required. The display name of the IndexEndpoint.
+ The name can be up to 128 characters long and
+ can consist of any UTF-8 characters.
+ network (str):
+ Required. The full name of the Google Compute Engine
+ network to which the IndexEndpoint should be peered.
+
+ Private services access must already be configured for the
+ network. If left unspecified, the Endpoint is not peered
+ with any network.
+
+ Format:
+ projects/{project}/global/networks/{network}, where
+ {project} is a project number, as in '12345', and {network}
+ is the network name.
+ description (str):
+ Optional. The description of the IndexEndpoint.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined
+ metadata to organize your IndexEndpoint.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ on and examples of labels. No more than 64 user
+ labels can be associated with one
+ IndexEndpoint (System labels are excluded)."
+ System reserved label keys are prefixed with
+ "aiplatform.googleapis.com/" and are immutable.
+ project (str):
+ Optional. Project to create IndexEndpoint in. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to create IndexEndpoint in. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to create this IndexEndpoint. Overrides
+ credentials set in aiplatform.init.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ sync (bool):
+ Optional. Whether to execute this creation synchronously. If False, this method
+ will be executed in concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
+
+ Returns:
+ MatchingEngineIndexEndpoint - IndexEndpoint resource object
+
+ """
+ gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint(
+ display_name=display_name,
+ description=description,
+ network=network,
+ )
+
+ if labels:
+ utils.validate_labels(labels)
+ gapic_index_endpoint.labels = labels
+
+ api_client = cls._instantiate_client(location=location, credentials=credentials)
+
+ create_lro = api_client.create_index_endpoint(
+ parent=initializer.global_config.common_location_path(
+ project=project, location=location
+ ),
+ index_endpoint=gapic_index_endpoint,
+ metadata=request_metadata,
+ )
+
+ _LOGGER.log_create_with_lro(cls, create_lro)
+
+ created_index = create_lro.result()
+
+ _LOGGER.log_create_complete(cls, created_index, "index_endpoint")
+
+ index_obj = cls(
+ index_endpoint_name=created_index.name,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+ return index_obj
+
+ def update(
+ self,
+ display_name: str,
+ description: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ ) -> "MatchingEngineIndexEndpoint":
+ """Updates an existing index endpoint resource.
+
+ Args:
+ display_name (str):
+ Required. The display name of the IndexEndpoint.
+ The name can be up to 128 characters long and
+ can consist of any UTF-8 characters.
+ description (str):
+ Optional. The description of the IndexEndpoint.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined
+ metadata to organize your IndexEndpoints.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ on and examples of labels. No more than 64 user
+ labels can be associated with one IndexEndpoint
+ (System labels are excluded)."
+ System reserved label keys are prefixed with
+ "aiplatform.googleapis.com/" and are immutable.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+
+ Returns:
+ MatchingEngineIndexEndpoint - The updated index endpoint resource object.
+ """
+
+ self.wait()
+
+ update_mask = list()
+
+ if labels:
+ utils.validate_labels(labels)
+ update_mask.append("labels")
+
+ if display_name is not None:
+ update_mask.append("display_name")
+
+ if description is not None:
+ update_mask.append("description")
+
+ update_mask = field_mask_pb2.FieldMask(paths=update_mask)
+
+ gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint(
+ name=self.resource_name,
+ display_name=display_name,
+ description=description,
+ labels=labels,
+ )
+
+ self._gca_resource = self.api_client.update_index_endpoint(
+ index_endpoint=gapic_index_endpoint,
+ update_mask=update_mask,
+ metadata=request_metadata,
+ )
+
+ return self
+
+ @staticmethod
+ def _build_deployed_index(
+ deployed_index_id: str,
+ index_resource_name: Optional[str] = None,
+ display_name: Optional[str] = None,
+ machine_type: Optional[str] = None,
+ min_replica_count: Optional[int] = None,
+ max_replica_count: Optional[int] = None,
+ enable_access_logging: Optional[bool] = None,
+ reserved_ip_ranges: Optional[Sequence[str]] = None,
+ deployment_group: Optional[str] = None,
+ auth_config_audiences: Optional[Sequence[str]] = None,
+ auth_config_allowed_issuers: Optional[Sequence[str]] = None,
+ ) -> gca_matching_engine_index_endpoint.DeployedIndex:
+ """Builds a DeployedIndex.
+
+ Args:
+ deployed_index_id (str):
+ Required. The user specified ID of the
+ DeployedIndex. The ID can be up to 128
+ characters long and must start with a letter and
+ only contain letters, numbers, and underscores.
+ The ID must be unique within the project it is
+ created in.
+ index_resource_name (str):
+ Optional. A fully-qualified index resource name or an index ID.
+ Example: "projects/123/locations/us-central1/indexes/my_index_id"
+ display_name (str):
+ Optional. The display name of the DeployedIndex. If not provided upon
+ creation, the Index's display_name is used.
+ machine_type (str):
+ Optional. The type of machine. Not specifying machine type will
+ result in the index being deployed with automatic resources.
+ min_replica_count (int):
+ Optional. The minimum number of machine replicas this deployed
+ index will always be deployed on. If traffic against it increases,
+ it may dynamically be deployed onto more replicas, and as traffic
+ decreases, some of these extra replicas may be freed.
+
+ If this value is not provided, the value of 2 will be used.
+ max_replica_count (int):
+ Optional. The maximum number of replicas this deployed index may
+ be deployed on when the traffic against it increases. If the requested
+ value is too large, the deployment will error, but if deployment
+ succeeds then the ability to scale to that many replicas
+ is guaranteed (barring service outages). If traffic against the
+ deployed index increases beyond what its replicas at maximum can
+ handle, a portion of the traffic will be dropped. If this value
+ is not provided, the larger of min_replica_count or 2 will
+ be used. If the value provided is smaller than min_replica_count, it
+ will automatically be increased to min_replica_count.
+ enable_access_logging (bool):
+ Optional. If true, private endpoint's access
+ logs are sent to Stackdriver Logging.
+ These logs are like standard server access logs,
+ containing information like timestamp and
+ latency for each MatchRequest.
+ Note that Stackdriver logs may incur a cost,
+ especially if the deployed index receives a high
+ queries per second rate (QPS). Estimate your
+ costs before enabling this option.
+ reserved_ip_ranges (Sequence[str]):
+ Optional. A list of reserved ip ranges under
+ the VPC network that can be used for this
+ DeployedIndex.
+ If set, we will deploy the index within the
+ provided ip ranges. Otherwise, the index might
+ be deployed to any ip ranges under the provided
+ VPC network.
+
+ The value should be the name of the address
+ (https://cloud.google.com/compute/docs/reference/rest/v1/addresses)
+ Example: 'vertex-ai-ip-range'.
+ deployment_group (str):
+ Optional. The deployment group can be no longer than 64
+ characters (eg: 'test', 'prod'). If not set, we will use the
+ 'default' deployment group.
+
+ Creating ``deployment_groups`` with ``reserved_ip_ranges``
+ is a recommended practice when the peered network has
+ multiple peering ranges. This creates your deployments from
+ predictable IP spaces for easier traffic administration.
+ Also, one deployment_group (except 'default') can only be
+ used with the same reserved_ip_ranges which means if the
+ deployment_group has been used with reserved_ip_ranges: [a,
+ b, c], using it with [a, b] or [d, e] is disallowed.
+
+ Note: we only support up to 5 deployment groups (not
+ including 'default').
+ auth_config_audiences (Sequence[str]):
+ Optional. The list of JWT audiences
+ that are allowed to access. A JWT containing any of these
+ audiences will be accepted.
+ auth_config_allowed_issuers (Sequence[str]):
+ Optional. A list of allowed JWT issuers. Each entry must be a valid
+ Google service account, in the following format:
+
+ ``service-account-name@project-id.iam.gserviceaccount.com``
+ """
+
+ deployed_index = gca_matching_engine_index_endpoint.DeployedIndex(
+ id=deployed_index_id,
+ index=index_resource_name,
+ display_name=display_name,
+ enable_access_logging=enable_access_logging,
+ reserved_ip_ranges=reserved_ip_ranges,
+ deployment_group=deployment_group,
+ )
+
+ if auth_config_audiences and auth_config_allowed_issuers:
+ deployed_index.deployed_index_auth_config = gca_matching_engine_index_endpoint.DeployedIndexAuthConfig(
+ auth_provider=gca_matching_engine_index_endpoint.DeployedIndexAuthConfig.AuthProvider(
+ audiences=auth_config_audiences,
+ allowed_issuers=auth_config_allowed_issuers,
+ )
+ )
+
+ if machine_type:
+ machine_spec = gca_machine_resources_compat.MachineSpec(
+ machine_type=machine_type
+ )
+
+ deployed_index.dedicated_resources = (
+ gca_machine_resources_compat.DedicatedResources(
+ machine_spec=machine_spec,
+ min_replica_count=min_replica_count,
+ max_replica_count=max_replica_count,
+ )
+ )
+
+ else:
+ deployed_index.automatic_resources = (
+ gca_machine_resources_compat.AutomaticResources(
+ min_replica_count=min_replica_count,
+ max_replica_count=max_replica_count,
+ )
+ )
+ return deployed_index
+
+ def deploy_index(
+ self,
+ index: matching_engine.MatchingEngineIndex,
+ deployed_index_id: str,
+ display_name: Optional[str] = None,
+ machine_type: Optional[str] = None,
+ min_replica_count: Optional[int] = None,
+ max_replica_count: Optional[int] = None,
+ enable_access_logging: Optional[bool] = None,
+ reserved_ip_ranges: Optional[Sequence[str]] = None,
+ deployment_group: Optional[str] = None,
+ auth_config_audiences: Optional[Sequence[str]] = None,
+ auth_config_allowed_issuers: Optional[Sequence[str]] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ ) -> "MatchingEngineIndexEndpoint":
+ """Deploys an existing index resource to this endpoint resource.
+
+ Args:
+ index (MatchingEngineIndex):
+ Required. The Index this is the
+ deployment of. We may refer to this Index as the
+ DeployedIndex's "original" Index.
+ deployed_index_id (str):
+ Required. The user specified ID of the
+ DeployedIndex. The ID can be up to 128
+ characters long and must start with a letter and
+ only contain letters, numbers, and underscores.
+ The ID must be unique within the project it is
+ created in.
+ display_name (str):
+ Optional. The display name of the DeployedIndex. If not provided upon
+ creation, the Index's display_name is used.
+ machine_type (str):
+ Optional. The type of machine. Not specifying machine type will
+ result in the index being deployed with automatic resources.
+ min_replica_count (int):
+ Optional. The minimum number of machine replicas this deployed
+ index will always be deployed on. If traffic against it increases,
+ it may dynamically be deployed onto more replicas, and as traffic
+ decreases, some of these extra replicas may be freed.
+
+ If this value is not provided, the value of 2 will be used.
+ max_replica_count (int):
+ Optional. The maximum number of replicas this deployed index may
+ be deployed on when the traffic against it increases. If the requested
+ value is too large, the deployment will error, but if deployment
+ succeeds then the ability to scale to that many replicas
+ is guaranteed (barring service outages). If traffic against the
+ deployed index increases beyond what its replicas at maximum can
+ handle, a portion of the traffic will be dropped. If this value
+ is not provided, the larger of min_replica_count or 2 will
+ be used. If the value provided is smaller than min_replica_count, it
+ will automatically be increased to min_replica_count.
+ enable_access_logging (bool):
+ Optional. If true, private endpoint's access
+ logs are sent to Stackdriver Logging.
+ These logs are like standard server access logs,
+ containing information like timestamp and
+ latency for each MatchRequest.
+ Note that Stackdriver logs may incur a cost,
+ especially if the deployed index receives a high
+ queries per second rate (QPS). Estimate your
+ costs before enabling this option.
+ reserved_ip_ranges (Sequence[str]):
+ Optional. A list of reserved ip ranges under
+ the VPC network that can be used for this
+ DeployedIndex.
+ If set, we will deploy the index within the
+ provided ip ranges. Otherwise, the index might
+ be deployed to any ip ranges under the provided
+ VPC network.
+
+ The value should be the name of the address
+ (https://cloud.google.com/compute/docs/reference/rest/v1/addresses)
+ Example: 'vertex-ai-ip-range'.
+ deployment_group (str):
+ Optional. The deployment group can be no longer than 64
+ characters (eg: 'test', 'prod'). If not set, we will use the
+ 'default' deployment group.
+
+ Creating ``deployment_groups`` with ``reserved_ip_ranges``
+ is a recommended practice when the peered network has
+ multiple peering ranges. This creates your deployments from
+ predictable IP spaces for easier traffic administration.
+ Also, one deployment_group (except 'default') can only be
+ used with the same reserved_ip_ranges which means if the
+ deployment_group has been used with reserved_ip_ranges: [a,
+ b, c], using it with [a, b] or [d, e] is disallowed.
+
+ Note: we only support up to 5 deployment groups (not
+ including 'default').
+ auth_config_audiences (Sequence[str]):
+ Optional. The list of JWT audiences
+ that are allowed to access. A JWT containing any of these
+ audiences will be accepted.
+
+ auth_config_audiences and auth_config_allowed_issuers must be passed together.
+ auth_config_allowed_issuers (Sequence[str]):
+ Optional. A list of allowed JWT issuers. Each entry must be a valid
+ Google service account, in the following format:
+
+ ``service-account-name@project-id.iam.gserviceaccount.com``
+
+ auth_config_audiences and auth_config_allowed_issuers must be passed together.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ Returns:
+ MatchingEngineIndexEndpoint - IndexEndpoint resource object
+ """
+
+ self.wait()
+
+ _LOGGER.log_action_start_against_resource(
+ "Deploying index",
+ "index_endpoint",
+ self,
+ )
+
+ deployed_index = self._build_deployed_index(
+ deployed_index_id=deployed_index_id,
+ index_resource_name=index.resource_name,
+ display_name=display_name,
+ machine_type=machine_type,
+ min_replica_count=min_replica_count,
+ max_replica_count=max_replica_count,
+ enable_access_logging=enable_access_logging,
+ reserved_ip_ranges=reserved_ip_ranges,
+ deployment_group=deployment_group,
+ auth_config_audiences=auth_config_audiences,
+ auth_config_allowed_issuers=auth_config_allowed_issuers,
+ )
+
+ deploy_lro = self.api_client.deploy_index(
+ index_endpoint=self.resource_name,
+ deployed_index=deployed_index,
+ metadata=request_metadata,
+ )
+
+ _LOGGER.log_action_started_against_resource_with_lro(
+ "Deploy index", "index_endpoint", self.__class__, deploy_lro
+ )
+
+ deploy_lro.result()
+
+ _LOGGER.log_action_completed_against_resource(
+ "index_endpoint", "Deployed index", self
+ )
+
+ # update local resource
+ self._sync_gca_resource()
+
+ return self
+
+ def undeploy_index(
+ self,
+ deployed_index_id: str,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ ) -> "MatchingEngineIndexEndpoint":
+ """Undeploy a deployed index endpoint resource.
+
+ Args:
+ deployed_index_id (str):
+ Required. The ID of the DeployedIndex
+ to be undeployed from the IndexEndpoint.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ Returns:
+ MatchingEngineIndexEndpoint - IndexEndpoint resource object
+ """
+
+ self.wait()
+
+ _LOGGER.log_action_start_against_resource(
+ "Undeploying index",
+ "index_endpoint",
+ self,
+ )
+
+ undeploy_lro = self.api_client.undeploy_index(
+ index_endpoint=self.resource_name,
+ deployed_index_id=deployed_index_id,
+ metadata=request_metadata,
+ )
+
+ _LOGGER.log_action_started_against_resource_with_lro(
+ "Undeploy index", "index_endpoint", self.__class__, undeploy_lro
+ )
+
+ undeploy_lro.result()
+
+ _LOGGER.log_action_completed_against_resource(
+ "index_endpoint", "Undeployed index", self
+ )
+
+ return self
+
+ def mutate_deployed_index(
+ self,
+ deployed_index_id: str,
+ min_replica_count: int = 1,
+ max_replica_count: int = 1,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ ):
+ """Updates an existing deployed index under this endpoint resource.
+
+ Args:
+ deployed_index_id (str):
+ Required. The user specified ID of the
+ DeployedIndex. The ID can be up to 128
+ characters long and must start with a letter and
+ only contain letters, numbers, and underscores.
+ The ID must be unique within the project it is
+ created in.
+ min_replica_count (int):
+ Optional. The minimum number of machine replicas this deployed
+ index will always be deployed on. If traffic against it increases,
+ it may dynamically be deployed onto more replicas, and as traffic
+ decreases, some of these extra replicas may be freed.
+ max_replica_count (int):
+ Optional. The maximum number of replicas this deployed index may
+ be deployed on when the traffic against it increases. If the requested
+ value is too large, the deployment will error, but if deployment
+ succeeds then the ability to scale to that many replicas
+ is guaranteed (barring service outages). If traffic against the
+ deployed index increases beyond what its replicas at maximum can
+ handle, a portion of the traffic will be dropped. If this value
+ is not provided, the larger of min_replica_count or 1 will
+ be used. If the value provided is smaller than min_replica_count, it
+ will automatically be increased to min_replica_count.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ """
+
+ self.wait()
+
+ _LOGGER.log_action_start_against_resource(
+ "Mutating index",
+ "index_endpoint",
+ self,
+ )
+
+ deployed_index = self._build_deployed_index(
+ index_resource_name=None,
+ deployed_index_id=deployed_index_id,
+ min_replica_count=min_replica_count,
+ max_replica_count=max_replica_count,
+ )
+
+ deploy_lro = self.api_client.mutate_deployed_index(
+ index_endpoint=self.resource_name,
+ deployed_index=deployed_index,
+ metadata=request_metadata,
+ )
+
+ _LOGGER.log_action_started_against_resource_with_lro(
+ "Mutate index", "index_endpoint", self.__class__, deploy_lro
+ )
+
+ deploy_lro.result()
+
+ # update local resource
+ self._sync_gca_resource()
+
+ _LOGGER.log_action_completed_against_resource("index_endpoint", "Mutated", self)
+
+ return self
+
+ @property
+ def deployed_indexes(
+ self,
+ ) -> List[gca_matching_engine_index_endpoint.DeployedIndex]:
+ """Returns a list of deployed indexes on this endpoint.
+
+ Returns:
+ List[gca_matching_engine_index_endpoint.DeployedIndex] - Deployed indexes
+ """
+ self._assert_gca_resource_is_available()
+ return self._gca_resource.deployed_indexes
+
+ @base.optional_sync()
+ def _undeploy(
+ self,
+ deployed_index_id: str,
+ metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ sync=True,
+ ) -> None:
+ """Undeploys a deployed index.
+
+ Args:
+ deployed_index_id (str):
+ Required. The ID of the DeployedIndex to be undeployed from the
+ Endpoint.
+ metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as
+ metadata.
+ """
+ self._sync_gca_resource()
+
+ _LOGGER.log_action_start_against_resource("Undeploying", "index_endpoint", self)
+
+ operation_future = self.api_client.undeploy_index(
+ index_endpoint=self.resource_name,
+ deployed_index_id=deployed_index_id,
+ metadata=metadata,
+ )
+
+ _LOGGER.log_action_started_against_resource_with_lro(
+ "Undeploy", "index_endpoint", self.__class__, operation_future
+ )
+
+ # block before returning
+ operation_future.result()
+
+ # update local resource
+ self._sync_gca_resource()
+
+ _LOGGER.log_action_completed_against_resource(
+ "index_endpoint", "undeployed", self
+ )
+
+ def undeploy_all(self, sync: bool = True) -> "MatchingEngineIndexEndpoint":
+ """Undeploys every index deployed to this MatchingEngineIndexEndpoint.
+
+ Args:
+ sync (bool):
+ Whether to execute this method synchronously. If False, this method
+ will be executed in concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
+ """
+ self._sync_gca_resource()
+
+ for deployed_index in self.deployed_indexes:
+ self._undeploy(deployed_index_id=deployed_index.id, sync=sync)
+
+ return self
+
+ def delete(self, force: bool = False, sync: bool = True) -> None:
+ """Deletes this MatchingEngineIndexEndpoint resource. If force is set to True,
+ all indexes on this endpoint will be undeployed prior to deletion.
+
+ Args:
+ force (bool):
+ Optional. If force is set to True, all deployed indexes on this
+ endpoint will be undeployed first. Default is False.
+ sync (bool):
+ Whether to execute this method synchronously. If False, this method
+ will be executed in concurrent Future and any downstream object will
+ be immediately returned and synced when the Future has completed.
+ Raises:
+ FailedPrecondition: If indexes are deployed on this MatchingEngineIndexEndpoint and force = False.
+ """
+ if force:
+ self.undeploy_all(sync=sync)
+
+ super().delete(sync=sync)
+
+ @property
+ def description(self) -> str:
+ """Description of the index endpoint."""
+ self._assert_gca_resource_is_available()
+ return self._gca_resource.description
+
+ def match(
+ self, deployed_index_id: str, queries: List[List[float]], num_neighbors: int = 1
+ ) -> List[List[MatchNeighbor]]:
+ """Retrieves nearest neighbors for the given embedding queries on the specified deployed index.
+
+ Args:
+ deployed_index_id (str):
+ Required. The ID of the DeployedIndex to match the queries against.
+ queries (List[List[float]]):
+ Required. A list of queries. Each query is a list of floats, representing a single embedding.
+ num_neighbors (int):
+ Optional. The number of nearest neighbors to be retrieved from the database for
+ each query.
+
+ Returns:
+ List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
+ """
+
+ # Find the deployed index by id
+ deployed_indexes = [
+ deployed_index
+ for deployed_index in self.deployed_indexes
+ if deployed_index.id == deployed_index_id
+ ]
+
+ if not deployed_indexes:
+ raise RuntimeError(f"No deployed index with id '{deployed_index_id}' found")
+
+ # Retrieve server ip from deployed index
+ server_ip = deployed_indexes[0].private_endpoints.match_grpc_address
+
+ # Set up channel and stub
+ channel = grpc.insecure_channel("{}:10000".format(server_ip))
+ stub = match_service_pb2_grpc.MatchServiceStub(channel)
+
+ # Create the batch match request
+ batch_request = match_service_pb2.BatchMatchRequest()
+ batch_request_for_index = (
+ match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex()
+ )
+ batch_request_for_index.deployed_index_id = deployed_index_id
+ batch_request_for_index.requests.extend(
+ [
+ match_service_pb2.MatchRequest(
+ num_neighbors=num_neighbors,
+ deployed_index_id=deployed_index_id,
+ float_val=query,
+ )
+ for query in queries
+ ]
+ )
+ batch_request.requests.append(batch_request_for_index)
+
+ # Perform the request
+ response = stub.BatchMatch(batch_request)
+
+ # Wrap the results in MatchNeighbor objects and return
+ return [
+ [
+ MatchNeighbor(id=neighbor.id, distance=neighbor.distance)
+ for neighbor in embedding_neighbors.neighbor
+ ]
+ for embedding_neighbors in response.responses[0].responses
+ ]
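An end-to-end sketch tying the endpoint to an index created earlier in this change; the VPC network name and IDs are placeholders, and private services access must already be configured on that network:

    from google.cloud import aiplatform

    aiplatform.init(project="my-project", location="us-central1")

    # Create an endpoint peered to a VPC network (placeholder name).
    my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
        display_name="my_endpoint",
        network="projects/12345/global/networks/my-vpc",
    )

    # Deploy a previously created MatchingEngineIndex to the endpoint.
    my_index_endpoint = my_index_endpoint.deploy_index(
        index=my_index,
        deployed_index_id="my_deployed_index_id",
    )

    # Query the deployed index; each query is a single embedding vector.
    neighbors = my_index_endpoint.match(
        deployed_index_id="my_deployed_index_id",
        queries=[[0.1, 0.2, 0.3]],
        num_neighbors=10,
    )

    # Clean up: undeploy all indexes, then delete the endpoint.
    my_index_endpoint.delete(force=True)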
diff --git a/google/cloud/aiplatform/metadata/artifact.py b/google/cloud/aiplatform/metadata/artifact.py
index b3ef6e09a2..65ee2cb92b 100644
--- a/google/cloud/aiplatform/metadata/artifact.py
+++ b/google/cloud/aiplatform/metadata/artifact.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2021 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,21 +15,76 @@
# limitations under the License.
#
-from typing import Optional, Dict
+from typing import Optional, Dict, Union
import proto
+from google.auth import credentials as auth_credentials
+
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import models
from google.cloud.aiplatform import utils
-from google.cloud.aiplatform.metadata.resource import _Resource
-from google.cloud.aiplatform_v1beta1 import ListArtifactsRequest
-from google.cloud.aiplatform_v1beta1.types import artifact as gca_artifact
+from google.cloud.aiplatform.compat.types import artifact as gca_artifact
+from google.cloud.aiplatform.compat.types import (
+ metadata_service as gca_metadata_service,
+)
+from google.cloud.aiplatform.metadata import metadata_store
+from google.cloud.aiplatform.metadata import resource
+from google.cloud.aiplatform.metadata import utils as metadata_utils
+from google.cloud.aiplatform.metadata.schema import base_artifact
+from google.cloud.aiplatform.utils import rest_utils
+
+_LOGGER = base.Logger(__name__)
-class _Artifact(_Resource):
+
+class Artifact(resource._Resource):
"""Metadata Artifact resource for Vertex AI"""
+ def __init__(
+ self,
+ artifact_name: str,
+ *,
+ metadata_store_id: str = "default",
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ):
+ """Retrieves an existing Metadata Artifact given a resource name or ID.
+
+ Args:
+ artifact_name (str):
+ Required. A fully-qualified resource name or resource ID of the Artifact.
+                Example: "projects/123/locations/us-central1/metadataStores/default/artifacts/my-resource"
+ or "my-resource" when project and location are initialized or passed.
+ metadata_store_id (str):
+ Optional. MetadataStore to retrieve Artifact from. If not set, metadata_store_id is set to "default".
+ If artifact_name is a fully-qualified resource, its metadata_store_id overrides this one.
+ project (str):
+ Optional. Project to retrieve the artifact from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve the Artifact from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve this Artifact. Overrides
+ credentials set in aiplatform.init.
+ """
+
+ super().__init__(
+ resource_name=artifact_name,
+ metadata_store_id=metadata_store_id,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
_resource_noun = "artifacts"
_getter_method = "get_artifact"
+ _delete_method = "delete_artifact"
+ _parse_resource_name_method = "parse_artifact_path"
+ _format_resource_name_method = "artifact_path"
+ _list_method = "list_artifacts"
@classmethod
def _create_resource(
@@ -38,25 +93,125 @@ def _create_resource(
parent: str,
resource_id: str,
schema_title: str,
+ uri: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
- ) -> proto.Message:
+ state: gca_artifact.Artifact.State = gca_artifact.Artifact.State.LIVE,
+ ) -> gca_artifact.Artifact:
gapic_artifact = gca_artifact.Artifact(
+ uri=uri,
schema_title=schema_title,
schema_version=schema_version,
display_name=display_name,
description=description,
metadata=metadata if metadata else {},
+ state=state,
)
return client.create_artifact(
- parent=parent, artifact=gapic_artifact, artifact_id=resource_id,
+ parent=parent,
+ artifact=gapic_artifact,
+ artifact_id=resource_id,
)
+ @classmethod
+ def _create(
+ cls,
+ resource_id: str,
+ schema_title: str,
+ uri: Optional[str] = None,
+ display_name: Optional[str] = None,
+ schema_version: Optional[str] = None,
+ description: Optional[str] = None,
+ metadata: Optional[Dict] = None,
+ state: gca_artifact.Artifact.State = gca_artifact.Artifact.State.LIVE,
+ metadata_store_id: Optional[str] = "default",
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> "Artifact":
+ """Creates a new Metadata resource.
+
+ Args:
+ resource_id (str):
+ Required. The portion of the resource name with
+ the format:
+                projects/123/locations/us-central1/metadataStores/{metadata_store_id}/{resource_noun}/{resource_id}.
+ schema_title (str):
+ Required. schema_title identifies the schema title used by the resource.
+ display_name (str):
+ Optional. The user-defined name of the resource.
+ schema_version (str):
+ Optional. schema_version specifies the version used by the resource.
+ If not set, defaults to use the latest version.
+ description (str):
+ Optional. Describes the purpose of the resource to be created.
+ metadata (Dict):
+ Optional. Contains the metadata information that will be stored in the resource.
+ state (google.cloud.gapic.types.Artifact.State):
+ Optional. The state of this Artifact. This is a
+ property of the Artifact, and does not imply or
+ capture any ongoing process. This property is
+ managed by clients (such as Vertex AI
+ Pipelines), and the system does not prescribe or
+ check the validity of state transitions.
+ metadata_store_id (str):
+ The portion of the resource name with
+ the format:
+                projects/123/locations/us-central1/metadataStores/{metadata_store_id}/{resource_noun}/{resource_id}
+ If not provided, the MetadataStore's ID will be set to "default".
+ project (str):
+ Project used to create this resource. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Location used to create this resource. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Custom credentials used to create this resource. Overrides
+ credentials set in aiplatform.init.
+
+ Returns:
+ resource (_Resource):
+ Instantiated representation of the managed Metadata resource.
+
+ """
+ api_client = cls._instantiate_client(location=location, credentials=credentials)
+
+ parent = utils.full_resource_name(
+ resource_name=metadata_store_id,
+ resource_noun=metadata_store._MetadataStore._resource_noun,
+ parse_resource_name_method=metadata_store._MetadataStore._parse_resource_name,
+ format_resource_name_method=metadata_store._MetadataStore._format_resource_name,
+ project=project,
+ location=location,
+ )
+
+ resource = cls._create_resource(
+ client=api_client,
+ parent=parent,
+ resource_id=resource_id,
+ schema_title=schema_title,
+ uri=uri,
+ display_name=display_name,
+ schema_version=schema_version,
+ description=description,
+ metadata=metadata,
+ state=state,
+ )
+
+ self = cls._empty_constructor(
+ project=project, location=location, credentials=credentials
+ )
+ self._gca_resource = resource
+
+ return self
+
@classmethod
def _update_resource(
- cls, client: utils.MetadataClientWithOverride, resource: proto.Message,
+ cls,
+ client: utils.MetadataClientWithOverride,
+ resource: proto.Message,
) -> proto.Message:
"""Update Artifacts with given input.
@@ -86,5 +241,326 @@ def _list_resources(
filter (str):
Optional. filter string to restrict the list result
"""
- list_request = ListArtifactsRequest(parent=parent, filter=filter,)
+ list_request = gca_metadata_service.ListArtifactsRequest(
+ parent=parent,
+ filter=filter,
+ )
return client.list_artifacts(request=list_request)
+
+ @classmethod
+ def create(
+ cls,
+ schema_title: str,
+ *,
+ resource_id: Optional[str] = None,
+ uri: Optional[str] = None,
+ display_name: Optional[str] = None,
+ schema_version: Optional[str] = None,
+ description: Optional[str] = None,
+ metadata: Optional[Dict] = None,
+ state: gca_artifact.Artifact.State = gca_artifact.Artifact.State.LIVE,
+ metadata_store_id: Optional[str] = "default",
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> "Artifact":
+ """Creates a new Metadata Artifact.
+
+ Args:
+ schema_title (str):
+ Required. schema_title identifies the schema title used by the Artifact.
+
+ Please reference https://cloud.google.com/vertex-ai/docs/ml-metadata/system-schemas.
+ resource_id (str):
+                Optional. The {artifact} portion of the Artifact name with
+                the format below. This is globally unique in a metadataStore:
+                projects/123/locations/us-central1/metadataStores/{metadata_store_id}/artifacts/{artifact}.
+ uri (str):
+ Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
+ artifact file.
+ display_name (str):
+ Optional. The user-defined name of the Artifact.
+ schema_version (str):
+ Optional. schema_version specifies the version used by the Artifact.
+ If not set, defaults to use the latest version.
+ description (str):
+ Optional. Describes the purpose of the Artifact to be created.
+ metadata (Dict):
+ Optional. Contains the metadata information that will be stored in the Artifact.
+ state (google.cloud.gapic.types.Artifact.State):
+ Optional. The state of this Artifact. This is a
+ property of the Artifact, and does not imply or
+ capture any ongoing process. This property is
+ managed by clients (such as Vertex AI
+ Pipelines), and the system does not prescribe or
+ check the validity of state transitions.
+ metadata_store_id (str):
+ Optional. The portion of the resource name with
+ the format:
+                projects/123/locations/us-central1/metadataStores/{metadata_store_id}/artifacts/{artifact}
+ If not provided, the MetadataStore's ID will be set to "default".
+ project (str):
+ Optional. Project used to create this Artifact. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location used to create this Artifact. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials used to create this Artifact. Overrides
+ credentials set in aiplatform.init.
+
+ Returns:
+ Artifact: Instantiated representation of the managed Metadata Artifact.
+ """
+ return cls._create(
+ resource_id=resource_id,
+ schema_title=schema_title,
+ uri=uri,
+ display_name=display_name,
+ schema_version=schema_version,
+ description=description,
+ metadata=metadata,
+ state=state,
+ metadata_store_id=metadata_store_id,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
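
As a quick illustration of the `create` classmethod above, a minimal sketch; the project, bucket, and metadata values are placeholders, and `aiplatform.init` defaults are assumed.

```python
# Minimal sketch of Artifact.create; identifiers below are placeholders.
from google.cloud import aiplatform
from google.cloud.aiplatform.metadata.artifact import Artifact

aiplatform.init(project="my-project", location="us-central1")

model_artifact = Artifact.create(
    schema_title="system.Artifact",  # see the system-schemas doc referenced above
    uri="gs://my-bucket/models/model-1",
    display_name="model-1",
    metadata={"framework": "xgboost"},
)
print(model_artifact.resource_name)
```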
+ @classmethod
+ def create_from_base_artifact_schema(
+ cls,
+ *,
+ base_artifact_schema: "base_artifact.BaseArtifactSchema",
+ metadata_store_id: Optional[str] = "default",
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> "Artifact":
+ """Creates a new Metadata Artifact from a BaseArtifactSchema class instance.
+
+ Args:
+ base_artifact_schema (BaseArtifactSchema):
+                Required. An instance of the BaseArtifactSchema class that can be
+ provided instead of providing artifact specific parameters.
+ metadata_store_id (str):
+ Optional. The portion of the resource name with
+ the format:
+                projects/123/locations/us-central1/metadataStores/{metadata_store_id}/artifacts/{artifact}
+ If not provided, the MetadataStore's ID will be set to "default".
+ project (str):
+ Optional. Project used to create this Artifact. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location used to create this Artifact. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials used to create this Artifact. Overrides
+ credentials set in aiplatform.init.
+
+ Returns:
+ Artifact: Instantiated representation of the managed Metadata Artifact.
+ """
+
+ return cls.create(
+ resource_id=base_artifact_schema.artifact_id,
+ schema_title=base_artifact_schema.schema_title,
+ uri=base_artifact_schema.uri,
+ display_name=base_artifact_schema.display_name,
+ schema_version=base_artifact_schema.schema_version,
+ description=base_artifact_schema.description,
+ metadata=base_artifact_schema.metadata,
+ state=base_artifact_schema.state,
+ metadata_store_id=metadata_store_id,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+ @property
+ def uri(self) -> Optional[str]:
+        """Uri for this Artifact."""
+ return self.gca_resource.uri
+
+ @classmethod
+ def get_with_uri(
+ cls,
+ uri: str,
+ *,
+ metadata_store_id: Optional[str] = "default",
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> "Artifact":
+        """Get an Artifact by its uri.
+
+ If more than one Artifact with this uri is in the metadata store then the Artifact with the latest
+ create_time is returned.
+
+ Args:
+            uri (str):
+ Required. Uri of the Artifact to retrieve.
+ metadata_store_id (str):
+ Optional. MetadataStore to retrieve Artifact from. If not set, metadata_store_id is set to "default".
+ If artifact_name is a fully-qualified resource, its metadata_store_id overrides this one.
+ project (str):
+ Optional. Project to retrieve the artifact from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve the Artifact from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve this Artifact. Overrides
+ credentials set in aiplatform.init.
+ Returns:
+ Artifact: Artifact with given uri.
+ Raises:
+ ValueError: If no Artifact exists with the provided uri.
+
+ """
+
+ matched_artifacts = cls.list(
+ filter=f'uri = "{uri}"',
+ metadata_store_id=metadata_store_id,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+ if not matched_artifacts:
+ raise ValueError(
+ f"No artifact with uri {uri} is in the `{metadata_store_id}` MetadataStore."
+ )
+
+ if len(matched_artifacts) > 1:
+ matched_artifacts.sort(key=lambda a: a.create_time, reverse=True)
+ resource_names = "\n".join(a.resource_name for a in matched_artifacts)
+            _LOGGER.warning(
+                f"Multiple artifacts with uri {uri} were found: {resource_names}"
+            )
+            _LOGGER.warning(f"Returning {matched_artifacts[0].resource_name}")
+
+ return matched_artifacts[0]
+
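
A short sketch of the lookup above; when several Artifacts share a uri, the newest by create_time wins and a warning is logged. The uri is a placeholder.

```python
# Sketch: fetch the most recent Artifact for a given uri (placeholder value).
artifact = Artifact.get_with_uri(uri="gs://my-bucket/models/model-1")
print(artifact.resource_name)
```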
+ @property
+ def lineage_console_uri(self) -> str:
+ """Cloud console uri to view this Artifact Lineage."""
+ metadata_store = self._parse_resource_name(self.resource_name)["metadata_store"]
+ return f"https://console.cloud.google.com/vertex-ai/locations/{self.location}/metadata-stores/{metadata_store}/artifacts/{self.name}?project={self.project}"
+
+ def __repr__(self) -> str:
+ if self._gca_resource:
+ return f"{object.__repr__(self)} \nresource name: {self.resource_name}\nuri: {self.uri}\nschema_title:{self.gca_resource.schema_title}"
+
+ return base.FutureManager.__repr__(self)
+
+
+class _VertexResourceArtifactResolver:
+
+ # TODO(b/235594717) Add support for managed datasets
+ _resource_to_artifact_type = {models.Model: "google.VertexModel"}
+
+ @classmethod
+ def supports_metadata(cls, resource: base.VertexAiResourceNoun) -> bool:
+ """Returns True if Vertex resource is supported in Vertex Metadata otherwise False.
+
+ Args:
+ resource (base.VertexAiResourceNoun):
+                Required. Instance of Vertex AI Resource.
+ Returns:
+ True if Vertex resource is supported in Vertex Metadata otherwise False.
+ """
+ return type(resource) in cls._resource_to_artifact_type
+
+ @classmethod
+ def validate_resource_supports_metadata(cls, resource: base.VertexAiResourceNoun):
+ """Validates Vertex resource is supported in Vertex Metadata.
+
+ Args:
+ resource (base.VertexAiResourceNoun):
+ Required. Instance of Vertex AI Resource.
+ Raises:
+            ValueError: If Vertex AI Resource is not supported in Vertex Metadata.
+ """
+ if not cls.supports_metadata(resource):
+ raise ValueError(
+                f"Vertex {type(resource)} is not yet supported in Vertex Metadata. "
+                f"Only {list(cls._resource_to_artifact_type.keys())} are supported."
+ )
+
+ @classmethod
+ def resolve_vertex_resource(
+ cls, resource: Union[models.Model]
+ ) -> Optional[Artifact]:
+ """Resolves Vertex Metadata Artifact that represents this Vertex Resource.
+
+        If there are multiple Artifacts in the metadata store that represent the provided resource, the one with the
+ latest create_time is returned.
+
+ Args:
+ resource (base.VertexAiResourceNoun):
+ Required. Instance of Vertex AI Resource.
+ Returns:
+ Artifact: Artifact that represents this Vertex Resource. None if Resource not found in Metadata store.
+ """
+ cls.validate_resource_supports_metadata(resource)
+ resource.wait()
+ metadata_type = cls._resource_to_artifact_type[type(resource)]
+ uri = rest_utils.make_gcp_resource_rest_url(resource=resource)
+
+ artifacts = Artifact.list(
+ filter=metadata_utils._make_filter_string(
+ schema_title=metadata_type,
+ uri=uri,
+ ),
+ project=resource.project,
+ location=resource.location,
+ credentials=resource.credentials,
+ )
+
+ artifacts.sort(key=lambda a: a.create_time, reverse=True)
+ if artifacts:
+ # most recent
+ return artifacts[0]
+
+ @classmethod
+ def create_vertex_resource_artifact(cls, resource: Union[models.Model]) -> Artifact:
+ """Creates Vertex Metadata Artifact that represents this Vertex Resource.
+
+ Args:
+ resource (base.VertexAiResourceNoun):
+ Required. Instance of Vertex AI Resource.
+ Returns:
+ Artifact: Artifact that represents this Vertex Resource.
+ """
+ cls.validate_resource_supports_metadata(resource)
+ resource.wait()
+ metadata_type = cls._resource_to_artifact_type[type(resource)]
+ uri = rest_utils.make_gcp_resource_rest_url(resource=resource)
+
+ return Artifact.create(
+ schema_title=metadata_type,
+ display_name=getattr(resource.gca_resource, "display_name", None),
+ uri=uri,
+ metadata={"resourceName": resource.resource_name},
+ project=resource.project,
+ location=resource.location,
+ credentials=resource.credentials,
+ )
+
+ @classmethod
+ def resolve_or_create_resource_artifact(
+ cls, resource: Union[models.Model]
+ ) -> Artifact:
+        """Creates or gets the Vertex Metadata Artifact that represents this Vertex Resource.
+
+ Args:
+ resource (base.VertexAiResourceNoun):
+ Required. Instance of Vertex AI Resource.
+ Returns:
+ Artifact: Artifact that represents this Vertex Resource.
+ """
+ artifact = cls.resolve_vertex_resource(resource=resource)
+ if artifact:
+ return artifact
+ return cls.create_vertex_resource_artifact(resource=resource)
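
To make the resolver concrete, a hedged sketch follows; `_VertexResourceArtifactResolver` is private API and the model resource name is a placeholder.

```python
# Sketch: resolve (or lazily create) the google.VertexModel Artifact that
# mirrors a Vertex Model. Private API; the resource name is a placeholder.
from google.cloud import aiplatform
from google.cloud.aiplatform.metadata import artifact as artifact_module

model = aiplatform.Model("projects/123/locations/us-central1/models/456")
resolver = artifact_module._VertexResourceArtifactResolver
model_artifact = resolver.resolve_or_create_resource_artifact(resource=model)
print(model_artifact.uri)  # REST URL of the model, per make_gcp_resource_rest_url
```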
diff --git a/google/cloud/aiplatform/metadata/constants.py b/google/cloud/aiplatform/metadata/constants.py
index 62e7d6e075..8776aef3b1 100644
--- a/google/cloud/aiplatform/metadata/constants.py
+++ b/google/cloud/aiplatform/metadata/constants.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2021 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,21 +14,54 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+from google.cloud.aiplatform.compat.types import artifact
SYSTEM_RUN = "system.Run"
SYSTEM_EXPERIMENT = "system.Experiment"
+SYSTEM_EXPERIMENT_RUN = "system.ExperimentRun"
SYSTEM_PIPELINE = "system.Pipeline"
+SYSTEM_PIPELINE_RUN = "system.PipelineRun"
SYSTEM_METRICS = "system.Metrics"
+_EXPERIMENTS_V2_TENSORBOARD_RUN = "google.VertexTensorboardRun"
+
_DEFAULT_SCHEMA_VERSION = "0.0.1"
SCHEMA_VERSIONS = {
SYSTEM_RUN: _DEFAULT_SCHEMA_VERSION,
SYSTEM_EXPERIMENT: _DEFAULT_SCHEMA_VERSION,
+ SYSTEM_EXPERIMENT_RUN: _DEFAULT_SCHEMA_VERSION,
SYSTEM_PIPELINE: _DEFAULT_SCHEMA_VERSION,
SYSTEM_METRICS: _DEFAULT_SCHEMA_VERSION,
}
-# The EXPERIMENT_METADATA is needed until we support context deletion in backend service.
-# TODO: delete EXPERIMENT_METADATA once backend supports context deletion.
+_BACKING_TENSORBOARD_RESOURCE_KEY = "backing_tensorboard_resource"
+
+
+_PARAM_KEY = "_params"
+_METRIC_KEY = "_metrics"
+_STATE_KEY = "_state"
+
+_PARAM_PREFIX = "param"
+_METRIC_PREFIX = "metric"
+_TIME_SERIES_METRIC_PREFIX = "time_series_metric"
+
+# This is currently used to filter in the Console.
EXPERIMENT_METADATA = {"experiment_deleted": False}
+
+PIPELINE_PARAM_PREFIX = "input:"
+
+TENSORBOARD_CUSTOM_JOB_EXPERIMENT_FIELD = "tensorboard_link"
+
+GCP_ARTIFACT_RESOURCE_NAME_KEY = "resourceName"
+
+# constant to mark an Experiment context as originating from the SDK
+# TODO(b/235593750) Remove this field
+_VERTEX_EXPERIMENT_TRACKING_LABEL = "vertex_experiment_tracking"
+
+_TENSORBOARD_RUN_REFERENCE_ARTIFACT = artifact.Artifact(
+ name="google-vertex-tensorboard-run-v0-0-1",
+ schema_title=_EXPERIMENTS_V2_TENSORBOARD_RUN,
+ schema_version="0.0.1",
+ metadata={_VERTEX_EXPERIMENT_TRACKING_LABEL: True},
+)
diff --git a/google/cloud/aiplatform/metadata/context.py b/google/cloud/aiplatform/metadata/context.py
index ddd583bbdf..d072a6e047 100644
--- a/google/cloud/aiplatform/metadata/context.py
+++ b/google/cloud/aiplatform/metadata/context.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2021 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,21 +15,39 @@
# limitations under the License.
#
-from typing import Optional, Dict, Sequence
+from typing import Optional, Dict, List, Sequence
import proto
+from google.cloud.aiplatform import base
from google.cloud.aiplatform import utils
-from google.cloud.aiplatform.metadata.resource import _Resource
-from google.cloud.aiplatform_v1beta1 import ListContextsRequest
-from google.cloud.aiplatform_v1beta1.types import context as gca_context
-
-
-class _Context(_Resource):
+from google.cloud.aiplatform.metadata import utils as metadata_utils
+from google.cloud.aiplatform.compat.types import context as gca_context
+from google.cloud.aiplatform.compat.types import (
+ lineage_subgraph as gca_lineage_subgraph,
+)
+from google.cloud.aiplatform.compat.types import (
+ metadata_service as gca_metadata_service,
+)
+from google.cloud.aiplatform.metadata import artifact
+from google.cloud.aiplatform.metadata import execution
+from google.cloud.aiplatform.metadata import resource
+
+
+class _Context(resource._Resource):
"""Metadata Context resource for Vertex AI"""
_resource_noun = "contexts"
_getter_method = "get_context"
+ _delete_method = "delete_context"
+ _parse_resource_name_method = "parse_context_path"
+ _format_resource_name_method = "context_path"
+ _list_method = "list_contexts"
+
+ @property
+ def parent_contexts(self) -> Sequence[str]:
+ """The parent context resource names of this context."""
+ return self.gca_resource.parent_contexts
def add_artifacts_and_executions(
self,
@@ -50,6 +68,19 @@ def add_artifacts_and_executions(
executions=execution_resource_names,
)
+ def get_artifacts(self) -> List[artifact.Artifact]:
+        """Returns all Artifacts attributed to this Context.
+
+ Returns:
+            artifacts (List[Artifact]): All Artifacts under this context.
+ """
+ return artifact.Artifact.list(
+ filter=metadata_utils._make_filter_string(in_context=[self.resource_name]),
+ project=self.project,
+ location=self.location,
+ credentials=self.credentials,
+ )
+
@classmethod
def _create_resource(
cls,
@@ -70,12 +101,16 @@ def _create_resource(
metadata=metadata if metadata else {},
)
return client.create_context(
- parent=parent, context=gapic_context, context_id=resource_id,
+ parent=parent,
+ context=gapic_context,
+ context_id=resource_id,
)
@classmethod
def _update_resource(
- cls, client: utils.MetadataClientWithOverride, resource: proto.Message,
+ cls,
+ client: utils.MetadataClientWithOverride,
+ resource: proto.Message,
) -> proto.Message:
"""Update Contexts with given input.
@@ -106,5 +141,43 @@ def _list_resources(
Optional. filter string to restrict the list result
"""
- list_request = ListContextsRequest(parent=parent, filter=filter,)
+ list_request = gca_metadata_service.ListContextsRequest(
+ parent=parent,
+ filter=filter,
+ )
return client.list_contexts(request=list_request)
+
+ def add_context_children(self, contexts: List["_Context"]):
+ """Adds the provided contexts as children of this context.
+
+ Args:
+ contexts (List[_Context]): Contexts to add as children.
+ """
+ self.api_client.add_context_children(
+ context=self.resource_name,
+ child_contexts=[c.resource_name for c in contexts],
+ )
+
+ def query_lineage_subgraph(self) -> gca_lineage_subgraph.LineageSubgraph:
+ """Queries lineage subgraph of this context.
+
+ Returns:
+            lineage subgraph (gca_lineage_subgraph.LineageSubgraph): Lineage subgraph of this Context.
+ """
+
+ return self.api_client.query_context_lineage_subgraph(
+ context=self.resource_name, retry=base._DEFAULT_RETRY
+ )
+
+ def get_executions(self) -> List[execution.Execution]:
+ """Returns Executions associated to this context.
+
+ Returns:
+ executions (List[Executions]): Executions associated to this context.
+ """
+ return execution.Execution.list(
+ filter=metadata_utils._make_filter_string(in_context=[self.resource_name]),
+ project=self.project,
+ location=self.location,
+ credentials=self.credentials,
+ )
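
The new lineage helpers on `_Context` compose as in this sketch; `_Context` is private API and the context resource name is a placeholder.

```python
# Sketch of the _Context lineage helpers above (private API, placeholder name).
from google.cloud.aiplatform.metadata import context as context_module

ctx = context_module._Context(
    resource_name="projects/123/locations/us-central1/metadataStores/default/contexts/my-run"
)
print(ctx.parent_contexts)               # parent context resource names
artifacts = ctx.get_artifacts()          # Artifacts attributed to this context
executions = ctx.get_executions()        # Executions associated to this context
subgraph = ctx.query_lineage_subgraph()  # full LineageSubgraph proto
print(len(subgraph.artifacts), len(subgraph.executions), len(subgraph.events))
```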
diff --git a/google/cloud/aiplatform/metadata/execution.py b/google/cloud/aiplatform/metadata/execution.py
index 3605efdb4f..895417fc64 100644
--- a/google/cloud/aiplatform/metadata/execution.py
+++ b/google/cloud/aiplatform/metadata/execution.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2021 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,92 +14,400 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
-from typing import Optional, Dict, Sequence
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Union
import proto
-from google.api_core import exceptions
+from google.auth import credentials as auth_credentials
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import models
from google.cloud.aiplatform import utils
-from google.cloud.aiplatform.metadata.artifact import _Artifact
-from google.cloud.aiplatform.metadata.resource import _Resource
-from google.cloud.aiplatform_v1beta1 import Event
-from google.cloud.aiplatform_v1beta1.types import execution as gca_execution
-from google.cloud.aiplatform_v1beta1.types.metadata_service import ListExecutionsRequest
+from google.cloud.aiplatform.compat.types import event as gca_event
+from google.cloud.aiplatform.compat.types import execution as gca_execution
+from google.cloud.aiplatform.compat.types import (
+ metadata_service as gca_metadata_service,
+)
+from google.cloud.aiplatform.metadata import artifact
+from google.cloud.aiplatform.metadata import metadata_store
+from google.cloud.aiplatform.metadata import resource
+from google.cloud.aiplatform.metadata.schema import base_execution
-class _Execution(_Resource):
+class Execution(resource._Resource):
"""Metadata Execution resource for Vertex AI"""
_resource_noun = "executions"
_getter_method = "get_execution"
+ _delete_method = "delete_execution"
+ _parse_resource_name_method = "parse_execution_path"
+ _format_resource_name_method = "execution_path"
+ _list_method = "list_executions"
+
+ def __init__(
+ self,
+ execution_name: str,
+ *,
+ metadata_store_id: str = "default",
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ):
+ """Retrieves an existing Metadata Execution given a resource name or ID.
+
+ Args:
+ execution_name (str):
+ Required. A fully-qualified resource name or resource ID of the Execution.
+                Example: "projects/123/locations/us-central1/metadataStores/default/executions/my-resource"
+ or "my-resource" when project and location are initialized or passed.
+ metadata_store_id (str):
+ Optional. MetadataStore to retrieve Execution from. If not set, metadata_store_id is set to "default".
+ If execution_name is a fully-qualified resource, its metadata_store_id overrides this one.
+ project (str):
+                Optional. Project to retrieve the Execution from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve the Execution from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve this Execution. Overrides
+ credentials set in aiplatform.init.
+ """
+
+ super().__init__(
+ resource_name=execution_name,
+ metadata_store_id=metadata_store_id,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+ @property
+ def state(self) -> gca_execution.Execution.State:
+ """State of this Execution."""
+ return self._gca_resource.state
+
+ @classmethod
+ def create(
+ cls,
+ schema_title: str,
+ *,
+ state: gca_execution.Execution.State = gca_execution.Execution.State.RUNNING,
+ resource_id: Optional[str] = None,
+ display_name: Optional[str] = None,
+ schema_version: Optional[str] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+ description: Optional[str] = None,
+ metadata_store_id: str = "default",
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> "Execution":
+ """
+ Creates a new Metadata Execution.
+
+ Args:
+ schema_title (str):
+ Required. schema_title identifies the schema title used by the Execution.
+ state (gca_execution.Execution.State.RUNNING):
+ Optional. State of this Execution. Defaults to RUNNING.
+ resource_id (str):
+                Optional. The {execution} portion of the Execution name with
+                the format below. This is globally unique in a metadataStore:
+                projects/123/locations/us-central1/metadataStores/{metadata_store_id}/executions/{execution}.
+ display_name (str):
+ Optional. The user-defined name of the Execution.
+ schema_version (str):
+ Optional. schema_version specifies the version used by the Execution.
+ If not set, defaults to use the latest version.
+ metadata (Dict):
+ Optional. Contains the metadata information that will be stored in the Execution.
+ description (str):
+ Optional. Describes the purpose of the Execution to be created.
+ metadata_store_id (str):
+ Optional. The portion of the resource name with
+ the format:
+                projects/123/locations/us-central1/metadataStores/{metadata_store_id}/executions/{execution}
+ If not provided, the MetadataStore's ID will be set to "default".
+ project (str):
+ Optional. Project used to create this Execution. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location used to create this Execution. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials used to create this Execution. Overrides
+ credentials set in aiplatform.init.
+
+ Returns:
+ Execution: Instantiated representation of the managed Metadata Execution.
+
+ """
+ self = cls._empty_constructor(
+ project=project, location=location, credentials=credentials
+ )
+ super(base.VertexAiResourceNounWithFutureManager, self).__init__()
+
+ resource = Execution._create_resource(
+ client=self.api_client,
+ parent=metadata_store._MetadataStore._format_resource_name(
+ project=self.project,
+ location=self.location,
+ metadata_store=metadata_store_id,
+ ),
+ schema_title=schema_title,
+ resource_id=resource_id,
+ metadata=metadata,
+ description=description,
+ display_name=display_name,
+ schema_version=schema_version,
+ state=state,
+ )
+ self._gca_resource = resource
+
+ return self
+
+ @classmethod
+ def create_from_base_execution_schema(
+ cls,
+ *,
+ base_execution_schema: "base_execution.BaseExecutionSchema",
+ metadata_store_id: Optional[str] = "default",
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> "Execution":
+ """
+ Creates a new Metadata Execution.
+
+ Args:
+ base_execution_schema (BaseExecutionSchema):
+                Required. An instance of the BaseExecutionSchema class that can be
+ provided instead of providing schema specific parameters.
+ metadata_store_id (str):
+ Optional. The portion of the resource name with
+ the format:
+                projects/123/locations/us-central1/metadataStores/{metadata_store_id}/executions/{execution}
+ If not provided, the MetadataStore's ID will be set to "default".
+ project (str):
+ Optional. Project used to create this Execution. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location used to create this Execution. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials used to create this Execution. Overrides
+ credentials set in aiplatform.init.
+
+ Returns:
+ Execution: Instantiated representation of the managed Metadata Execution.
+
+ """
+ resource = Execution.create(
+ state=base_execution_schema.state,
+ schema_title=base_execution_schema.schema_title,
+ resource_id=base_execution_schema.execution_id,
+ display_name=base_execution_schema.display_name,
+ schema_version=base_execution_schema.schema_version,
+ metadata=base_execution_schema.metadata,
+ description=base_execution_schema.description,
+ metadata_store_id=metadata_store_id,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+ return resource
+
+ def __enter__(self):
+ if self.state is not gca_execution.Execution.State.RUNNING:
+ self.update(state=gca_execution.Execution.State.RUNNING)
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ state = (
+ gca_execution.Execution.State.FAILED
+ if exc_type
+ else gca_execution.Execution.State.COMPLETE
+ )
+ self.update(state=state)
+
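
Taken together, `__enter__`/`__exit__` let an Execution manage its own state, as in this sketch; `my_artifact` is a placeholder Artifact created elsewhere.

```python
# Sketch: Execution as a context manager. State becomes COMPLETE on a clean
# exit and FAILED if the block raises. `my_artifact` is a placeholder.
with Execution.create(schema_title="system.Run", display_name="my-run") as run:
    run.assign_input_artifacts([my_artifact])
    # ... do the actual work here ...
print(run.state)  # Execution.State.COMPLETE
```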
+ def assign_input_artifacts(
+ self, artifacts: List[Union[artifact.Artifact, models.Model]]
+ ):
+        """Assigns Artifacts as inputs to this Execution.
+
+ Args:
+ artifacts (List[Union[artifact.Artifact, models.Model]]):
+ Required. Artifacts to assign as input.
+ """
+ self._add_artifact(artifacts=artifacts, input=True)
+
+ def assign_output_artifacts(
+ self, artifacts: List[Union[artifact.Artifact, models.Model]]
+ ):
+        """Assigns Artifacts as outputs to this Execution.
- def add_artifact(
- self, artifact_resource_name: str, input: bool,
+ Args:
+ artifacts (List[Union[artifact.Artifact, models.Model]]):
+                Required. Artifacts to assign as output.
+ """
+ self._add_artifact(artifacts=artifacts, input=False)
+
+ def _add_artifact(
+ self,
+ artifacts: List[Union[artifact.Artifact, models.Model]],
+ input: bool,
):
"""Connect Artifact to a given Execution.
Args:
- artifact_resource_name (str):
+            artifacts (List[Union[artifact.Artifact, models.Model]]):
Required. The full resource name of the Artifact to connect to the Execution through an Event.
input (bool)
Required. Whether Artifact is an input event to the Execution or not.
"""
- event = Event(
- artifact=artifact_resource_name,
- type_=Event.Type.INPUT if input else Event.Type.OUTPUT,
- )
+ artifact_resource_names = []
+ for a in artifacts:
+ if isinstance(a, artifact.Artifact):
+ artifact_resource_names.append(a.resource_name)
+ else:
+ artifact_resource_names.append(
+ artifact._VertexResourceArtifactResolver.resolve_or_create_resource_artifact(
+ a
+ ).resource_name
+ )
+
+ events = [
+ gca_event.Event(
+ artifact=artifact_resource_name,
+ type_=gca_event.Event.Type.INPUT
+ if input
+ else gca_event.Event.Type.OUTPUT,
+ )
+ for artifact_resource_name in artifact_resource_names
+ ]
self.api_client.add_execution_events(
- execution=self.resource_name, events=[event],
+ execution=self.resource_name,
+ events=events,
)
- def query_input_and_output_artifacts(self) -> Sequence[_Artifact]:
- """query the input and output artifacts connected to the execution.
+ def _get_artifacts(
+ self, event_type: gca_event.Event.Type
+ ) -> List[artifact.Artifact]:
+ """Get Executions input or output Artifacts.
+ Args:
+ event_type (gca_event.Event.Type):
+ Required. The Event type, input or output.
Returns:
- A Sequence of _Artifacts
+ List of Artifacts.
"""
+ subgraph = self.api_client.query_execution_inputs_and_outputs(
+ execution=self.resource_name
+ )
- try:
- artifacts = self.api_client.query_execution_inputs_and_outputs(
- execution=self.resource_name
- ).artifacts
- except exceptions.NotFound:
- return []
+ artifact_map = {
+ artifact_metadata.name: artifact_metadata
+ for artifact_metadata in subgraph.artifacts
+ }
- return [
- _Artifact(
- resource=artifact,
+ gca_artifacts = [
+ artifact_map[event.artifact]
+ for event in subgraph.events
+ if event.type_ == event_type
+ ]
+
+ artifacts = []
+ for gca_artifact in gca_artifacts:
+ this_artifact = artifact.Artifact._empty_constructor(
project=self.project,
location=self.location,
credentials=self.credentials,
)
- for artifact in artifacts
- ]
+ this_artifact._gca_resource = gca_artifact
+ artifacts.append(this_artifact)
+
+ return artifacts
+
+ def get_input_artifacts(self) -> List[artifact.Artifact]:
+ """Get the input Artifacts of this Execution.
+
+ Returns:
+ List of input Artifacts.
+ """
+ return self._get_artifacts(event_type=gca_event.Event.Type.INPUT)
+
+ def get_output_artifacts(self) -> List[artifact.Artifact]:
+ """Get the output Artifacts of this Execution.
+
+ Returns:
+ List of output Artifacts.
+ """
+ return self._get_artifacts(event_type=gca_event.Event.Type.OUTPUT)
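
Reading lineage back is symmetric, continuing the sketch above: the getters filter the execution's lineage subgraph by event type.

```python
# Sketch: retrieve the Artifacts wired to an Execution via INPUT/OUTPUT events.
inputs = run.get_input_artifacts()
outputs = run.get_output_artifacts()
print([a.uri for a in inputs], [a.uri for a in outputs])
```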
@classmethod
def _create_resource(
cls,
client: utils.MetadataClientWithOverride,
parent: str,
- resource_id: str,
schema_title: str,
+ state: gca_execution.Execution.State = gca_execution.Execution.State.RUNNING,
+ resource_id: Optional[str] = None,
display_name: Optional[str] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict] = None,
- ) -> proto.Message:
+ ) -> gca_execution.Execution:
+ """
+ Creates a new Metadata Execution.
+
+ Args:
+ client (utils.MetadataClientWithOverride):
+ Required. Instantiated Metadata Service Client.
+ parent (str):
+                Required. MetadataStore parent in which to create this Execution.
+ schema_title (str):
+ Required. schema_title identifies the schema title used by the Execution.
+ state (gca_execution.Execution.State):
+ Optional. State of this Execution. Defaults to RUNNING.
+ resource_id (str):
+ Optional. The {execution} portion of the resource name with the
+ format:
+ ``projects/{project}/locations/{location}/metadataStores/{metadatastore}/executions/{execution}``
+ If not provided, the Execution's ID will be a UUID generated
+ by the service. Must be 4-128 characters in length. Valid
+ characters are ``/[a-z][0-9]-/``. Must be unique across all
+ Executions in the parent MetadataStore. (Otherwise the
+ request will fail with ALREADY_EXISTS, or PERMISSION_DENIED
+ if the caller can't view the preexisting Execution.)
+ display_name (str):
+ Optional. The user-defined name of the Execution.
+ schema_version (str):
+ Optional. schema_version specifies the version used by the Execution.
+ If not set, defaults to use the latest version.
+ description (str):
+ Optional. Describes the purpose of the Execution to be created.
+ metadata (Dict):
+ Optional. Contains the metadata information that will be stored in the Execution.
+
+ Returns:
+ Execution: Instantiated representation of the managed Metadata Execution.
+
+ """
gapic_execution = gca_execution.Execution(
schema_title=schema_title,
schema_version=schema_version,
display_name=display_name,
description=description,
metadata=metadata if metadata else {},
+ state=state,
)
return client.create_execution(
- parent=parent, execution=gapic_execution, execution_id=resource_id,
+ parent=parent,
+ execution=gapic_execution,
+ execution_id=resource_id,
)
@classmethod
@@ -120,12 +428,17 @@ def _list_resources(
Optional. filter string to restrict the list result
"""
- list_request = ListExecutionsRequest(parent=parent, filter=filter,)
+ list_request = gca_metadata_service.ListExecutionsRequest(
+ parent=parent,
+ filter=filter,
+ )
return client.list_executions(request=list_request)
@classmethod
def _update_resource(
- cls, client: utils.MetadataClientWithOverride, resource: proto.Message,
+ cls,
+ client: utils.MetadataClientWithOverride,
+ resource: proto.Message,
) -> proto.Message:
"""Update Executions with given input.
@@ -137,3 +450,30 @@ def _update_resource(
"""
return client.update_execution(execution=resource)
+
+ def update(
+ self,
+ state: Optional[gca_execution.Execution.State] = None,
+ description: Optional[str] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+ ):
+ """Update this Execution.
+
+ Args:
+ state (gca_execution.Execution.State):
+ Optional. State of this Execution.
+ description (str):
+ Optional. Describes the purpose of the Execution to be created.
+            metadata (Dict[str, Any]):
+ Optional. Contains the metadata information that will be stored in the Execution.
+ """
+
+ gca_resource = deepcopy(self._gca_resource)
+ if state:
+ gca_resource.state = state
+ if description:
+ gca_resource.description = description
+ self._nested_update_metadata(gca_resource=gca_resource, metadata=metadata)
+ self._gca_resource = self._update_resource(
+ self.api_client, resource=gca_resource
+ )
diff --git a/google/cloud/aiplatform/metadata/experiment_resources.py b/google/cloud/aiplatform/metadata/experiment_resources.py
new file mode 100644
index 0000000000..c79de84aa5
--- /dev/null
+++ b/google/cloud/aiplatform/metadata/experiment_resources.py
@@ -0,0 +1,721 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import abc
+from dataclasses import dataclass
+import logging
+from typing import Dict, List, NamedTuple, Optional, Union, Tuple, Type
+
+from google.auth import credentials as auth_credentials
+
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform.metadata import artifact
+from google.cloud.aiplatform.metadata import constants
+from google.cloud.aiplatform.metadata import context
+from google.cloud.aiplatform.metadata import execution
+from google.cloud.aiplatform.metadata import metadata
+from google.cloud.aiplatform.metadata import metadata_store
+from google.cloud.aiplatform.metadata import resource
+from google.cloud.aiplatform.metadata import utils as metadata_utils
+from google.cloud.aiplatform.tensorboard import tensorboard_resource
+
+_LOGGER = base.Logger(__name__)
+
+
+@dataclass
+class _ExperimentRow:
+ """Class for representing a run row in an Experiments Dataframe.
+
+ Attributes:
+ params (Dict[str, Union[float, int, str]]): Optional. The parameters of this run.
+ metrics (Dict[str, Union[float, int, str]]): Optional. The metrics of this run.
+ time_series_metrics (Dict[str, float]): Optional. The latest time series metrics of this run.
+ experiment_run_type (Optional[str]): Optional. The type of this run.
+ name (str): Optional. The name of this run.
+ state (str): Optional. The state of this run.
+ """
+
+ params: Optional[Dict[str, Union[float, int, str]]] = None
+ metrics: Optional[Dict[str, Union[float, int, str]]] = None
+ time_series_metrics: Optional[Dict[str, float]] = None
+ experiment_run_type: Optional[str] = None
+ name: Optional[str] = None
+ state: Optional[str] = None
+
+ def to_dict(self) -> Dict[str, Union[float, int, str]]:
+ """Converts this experiment row into a dictionary.
+
+ Returns:
+ Row as a dictionary.
+ """
+ result = {
+ "run_type": self.experiment_run_type,
+ "run_name": self.name,
+ "state": self.state,
+ }
+ for prefix, field in [
+ (constants._PARAM_PREFIX, self.params),
+ (constants._METRIC_PREFIX, self.metrics),
+ (constants._TIME_SERIES_METRIC_PREFIX, self.time_series_metrics),
+ ]:
+ if field:
+ result.update(
+ {f"{prefix}.{key}": value for key, value in field.items()}
+ )
+ return result
+
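
For example, a populated row flattens into prefixed column keys like so (a sketch using the prefixes from `constants`):

```python
# Sketch: how _ExperimentRow.to_dict flattens into DataFrame columns.
row = _ExperimentRow(
    params={"lr": 0.01},
    metrics={"accuracy": 0.9},
    time_series_metrics={"loss": 0.1},
    experiment_run_type="system.ExperimentRun",
    name="run-1",
    state="COMPLETE",
)
assert row.to_dict() == {
    "run_type": "system.ExperimentRun",
    "run_name": "run-1",
    "state": "COMPLETE",
    "param.lr": 0.01,
    "metric.accuracy": 0.9,
    "time_series_metric.loss": 0.1,
}
```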
+
+class Experiment:
+ """Represents a Vertex AI Experiment resource."""
+
+ def __init__(
+ self,
+ experiment_name: str,
+ *,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ):
+        """Retrieves an existing Vertex AI Experiment given a name or resource name.
+
+ ```
+ my_experiment = aiplatform.Experiment('my-experiment')
+ ```
+
+ Args:
+ experiment_name (str): Required. The name or resource name of this experiment.
+
+ Resource name is of the format: projects/123/locations/us-central1/experiments/my-experiment
+ project (str):
+ Optional. Project where this experiment is located. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location where this experiment is located. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials used to retrieve this experiment. Overrides
+ credentials set in aiplatform.init.
+ """
+
+ metadata_args = dict(
+ resource_name=experiment_name,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+ with _SetLoggerLevel(resource):
+ experiment_context = context._Context(**metadata_args)
+ self._validate_experiment_context(experiment_context)
+
+ self._metadata_context = experiment_context
+
+ @staticmethod
+ def _validate_experiment_context(experiment_context: context._Context):
+ """Validates this context is an experiment context.
+
+ Args:
+ experiment_context (context._Context): Metadata context.
+ Raises:
+ ValueError: If Metadata context is not an experiment context or a TensorboardExperiment.
+ """
+ if experiment_context.schema_title != constants.SYSTEM_EXPERIMENT:
+ raise ValueError(
+ f"Experiment name {experiment_context.name} is of type "
+ f"({experiment_context.schema_title}) in this MetadataStore. "
+                f"It must be of type {constants.SYSTEM_EXPERIMENT}."
+ )
+ if Experiment._is_tensorboard_experiment(experiment_context):
+ raise ValueError(
+ f"Experiment name {experiment_context.name} is a TensorboardExperiment context "
+ f"and cannot be used as a Vertex AI Experiment."
+ )
+
+ @staticmethod
+ def _is_tensorboard_experiment(context: context._Context) -> bool:
+ """Returns True if Experiment is a Tensorboard Experiment created by CustomJob."""
+ return constants.TENSORBOARD_CUSTOM_JOB_EXPERIMENT_FIELD in context.metadata
+
+ @property
+ def name(self) -> str:
+ """The name of this experiment."""
+ return self._metadata_context.name
+
+ @classmethod
+ def create(
+ cls,
+ experiment_name: str,
+ *,
+ description: Optional[str] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> "Experiment":
+ """Creates a new experiment in Vertex AI Experiments.
+
+ ```
+ my_experiment = aiplatform.Experiment.create('my-experiment', description='my description')
+ ```
+
+ Args:
+ experiment_name (str): Required. The name of this experiment.
+ description (str): Optional. Describes this experiment's purpose.
+ project (str):
+ Optional. Project where this experiment will be created. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location where this experiment will be created. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials used to create this experiment. Overrides
+ credentials set in aiplatform.init.
+ Returns:
+ The newly created experiment.
+ """
+
+ metadata_store._MetadataStore.ensure_default_metadata_store_exists(
+ project=project, location=location, credentials=credentials
+ )
+
+ with _SetLoggerLevel(resource):
+ experiment_context = context._Context._create(
+ resource_id=experiment_name,
+ display_name=experiment_name,
+ description=description,
+ schema_title=constants.SYSTEM_EXPERIMENT,
+ schema_version=metadata._get_experiment_schema_version(),
+ metadata=constants.EXPERIMENT_METADATA,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+        self = cls.__new__(cls)
+ self._metadata_context = experiment_context
+
+ return self
+
+ @classmethod
+ def get_or_create(
+ cls,
+ experiment_name: str,
+ *,
+ description: Optional[str] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> "Experiment":
+ """Gets experiment if one exists with this experiment_name in Vertex AI Experiments.
+
+ Otherwise creates this experiment.
+
+ ```
+ my_experiment = aiplatform.Experiment.get_or_create('my-experiment', description='my description')
+ ```
+
+ Args:
+ experiment_name (str): Required. The name of this experiment.
+ description (str): Optional. Describes this experiment's purpose.
+ project (str):
+ Optional. Project where this experiment will be retrieved from or created. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location where this experiment will be retrieved from or created. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials used to retrieve or create this experiment. Overrides
+ credentials set in aiplatform.init.
+ Returns:
+ Vertex AI experiment.
+ """
+
+ metadata_store._MetadataStore.ensure_default_metadata_store_exists(
+ project=project, location=location, credentials=credentials
+ )
+
+ with _SetLoggerLevel(resource):
+ experiment_context = context._Context.get_or_create(
+ resource_id=experiment_name,
+ display_name=experiment_name,
+ description=description,
+ schema_title=constants.SYSTEM_EXPERIMENT,
+ schema_version=metadata._get_experiment_schema_version(),
+ metadata=constants.EXPERIMENT_METADATA,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+ cls._validate_experiment_context(experiment_context)
+
+ if description and description != experiment_context.description:
+ experiment_context.update(description=description)
+
+ self = cls.__new__(cls)
+ self._metadata_context = experiment_context
+
+ return self
+
+ @classmethod
+ def list(
+ cls,
+ *,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> List["Experiment"]:
+ """List all Vertex AI Experiments in the given project.
+
+ ```
+ my_experiments = aiplatform.Experiment.list()
+ ```
+
+ Args:
+ project (str):
+ Optional. Project to list these experiments from. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location to list these experiments from. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to list these experiments. Overrides
+ credentials set in aiplatform.init.
+ Returns:
+ List of Vertex AI experiments.
+ """
+
+ filter_str = metadata_utils._make_filter_string(
+ schema_title=constants.SYSTEM_EXPERIMENT
+ )
+
+ with _SetLoggerLevel(resource):
+ experiment_contexts = context._Context.list(
+ filter=filter_str,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+ experiments = []
+ for experiment_context in experiment_contexts:
+ # Filters Tensorboard Experiments
+ if not cls._is_tensorboard_experiment(experiment_context):
+ experiment = cls.__new__(cls)
+ experiment._metadata_context = experiment_context
+ experiments.append(experiment)
+ return experiments
+
+ @property
+ def resource_name(self) -> str:
+ """The Metadata context resource name of this experiment."""
+ return self._metadata_context.resource_name
+
+ def delete(self, *, delete_backing_tensorboard_runs: bool = False):
+        """Deletes this experiment and all the experiment runs under this experiment.
+
+ Does not delete Pipeline runs, Artifacts, or Executions associated to this experiment
+ or experiment runs in this experiment.
+
+ ```
+ my_experiment = aiplatform.Experiment('my-experiment')
+ my_experiment.delete(delete_backing_tensorboard_runs=True)
+ ```
+
+ Args:
+ delete_backing_tensorboard_runs (bool):
+                Optional. If True, also deletes the Tensorboard Runs associated to the experiment
+                runs under this experiment that were used to store time series metrics.
+ """
+
+ experiment_runs = _SUPPORTED_LOGGABLE_RESOURCES[context._Context][
+ constants.SYSTEM_EXPERIMENT_RUN
+ ].list(experiment=self)
+ for experiment_run in experiment_runs:
+ experiment_run.delete(
+ delete_backing_tensorboard_run=delete_backing_tensorboard_runs
+ )
+ self._metadata_context.delete()
+
+ def get_data_frame(self) -> "pd.DataFrame": # noqa: F821
+        """Get parameters, metrics, and time series metrics of all runs in this experiment as a Dataframe.
+
+ ```
+ my_experiment = aiplatform.Experiment('my-experiment')
+ df = my_experiment.get_data_frame()
+ ```
+
+ Returns:
+ pd.DataFrame: Pandas Dataframe of Experiment Runs.
+
+ Raises:
+ ImportError: If pandas is not installed.
+ """
+ try:
+ import pandas as pd
+ except ImportError:
+ raise ImportError(
+ "Pandas is not installed and is required to get dataframe as the return format. "
+ 'Please install the SDK using "pip install google-cloud-aiplatform[metadata]"'
+ )
+
+ service_request_args = dict(
+ project=self._metadata_context.project,
+ location=self._metadata_context.location,
+ credentials=self._metadata_context.credentials,
+ )
+
+ filter_str = metadata_utils._make_filter_string(
+ schema_title=sorted(
+ list(_SUPPORTED_LOGGABLE_RESOURCES[context._Context].keys())
+ ),
+ parent_contexts=[self._metadata_context.resource_name],
+ )
+ contexts = context._Context.list(filter_str, **service_request_args)
+
+ filter_str = metadata_utils._make_filter_string(
+ schema_title=list(
+ _SUPPORTED_LOGGABLE_RESOURCES[execution.Execution].keys()
+ ),
+ in_context=[self._metadata_context.resource_name],
+ )
+
+ executions = execution.Execution.list(filter_str, **service_request_args)
+
+ rows = []
+ for metadata_context in contexts:
+ row_dict = (
+ _SUPPORTED_LOGGABLE_RESOURCES[context._Context][
+ metadata_context.schema_title
+ ]
+ ._query_experiment_row(metadata_context)
+ .to_dict()
+ )
+ row_dict.update({"experiment_name": self.name})
+ rows.append(row_dict)
+
+ # backward compatibility
+ for metadata_execution in executions:
+ row_dict = (
+ _SUPPORTED_LOGGABLE_RESOURCES[execution.Execution][
+ metadata_execution.schema_title
+ ]
+ ._query_experiment_row(metadata_execution)
+ .to_dict()
+ )
+ row_dict.update({"experiment_name": self.name})
+ rows.append(row_dict)
+
+ df = pd.DataFrame(rows)
+
+ column_name_sort_map = {
+ "experiment_name": -1,
+ "run_name": 1,
+ "run_type": 2,
+ "state": 3,
+ }
+
+ def column_sort_key(key: str) -> int:
+ """Helper method to reorder columns."""
+ order = column_name_sort_map.get(key)
+ if order:
+ return order
+ elif key.startswith("param"):
+ return 5
+ elif key.startswith("metric"):
+ return 6
+ else:
+ return 7
+
+ columns = df.columns
+ columns = sorted(columns, key=column_sort_key)
+ df = df.reindex(columns, axis=1)
+
+ return df
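
The effect of `column_sort_key` is that identity columns lead, followed by params, then metrics, then everything else (including time series metrics). A sketch, with placeholder experiment and column names:

```python
# Sketch: expected column ordering of get_data_frame (placeholder values).
df = aiplatform.Experiment("my-experiment").get_data_frame()
list(df.columns)
# ['experiment_name', 'run_name', 'run_type', 'state',
#  'param.lr', 'metric.accuracy', 'time_series_metric.loss']
```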
+
+ def _lookup_backing_tensorboard(self) -> Optional[tensorboard_resource.Tensorboard]:
+ """Returns backing tensorboard if one is set.
+
+ Returns:
+ Tensorboard resource if one exists.
+ """
+ tensorboard_resource_name = self._metadata_context.metadata.get(
+ constants._BACKING_TENSORBOARD_RESOURCE_KEY
+ )
+
+ if not tensorboard_resource_name:
+ with _SetLoggerLevel(resource):
+ self._metadata_context.sync_resource()
+ tensorboard_resource_name = self._metadata_context.metadata.get(
+ constants._BACKING_TENSORBOARD_RESOURCE_KEY
+ )
+
+ if tensorboard_resource_name:
+ return tensorboard_resource.Tensorboard(
+ tensorboard_resource_name,
+ credentials=self._metadata_context.credentials,
+ )
+
+ def get_backing_tensorboard_resource(
+ self,
+ ) -> Optional[tensorboard_resource.Tensorboard]:
+        """Get the backing tensorboard for this experiment if one exists.
+
+ ```
+ my_experiment = aiplatform.Experiment('my-experiment')
+ tb = my_experiment.get_backing_tensorboard_resource()
+ ```
+
+ Returns:
+ Backing Tensorboard resource for this experiment if one exists.
+ """
+ return self._lookup_backing_tensorboard()
+
+ def assign_backing_tensorboard(
+ self, tensorboard: Union[tensorboard_resource.Tensorboard, str]
+ ):
+ """Assigns tensorboard as backing tensorboard to support time series metrics logging.
+
+ ```
+ tb = aiplatform.Tensorboard('tensorboard-resource-id')
+ my_experiment = aiplatform.Experiment('my-experiment')
+ my_experiment.assign_backing_tensorboard(tb)
+ ```
+
+ Args:
+ tensorboard (Union[aiplatform.Tensorboard, str]):
+ Required. Tensorboard resource or resource name to associate to this experiment.
+
+ Raises:
+ ValueError: If this experiment already has a previously set backing tensorboard resource.
+ ValueError: If Tensorboard is not in same project and location as this experiment.
+ """
+
+ backing_tensorboard = self._lookup_backing_tensorboard()
+ if backing_tensorboard:
+ tensorboard_resource_name = (
+ tensorboard
+ if isinstance(tensorboard, str)
+ else tensorboard.resource_name
+ )
+ if tensorboard_resource_name != backing_tensorboard.resource_name:
+ raise ValueError(
+                    f"Experiment {self._metadata_context.name} is already associated "
+ f"to tensorboard resource {backing_tensorboard.resource_name}"
+ )
+
+ if isinstance(tensorboard, str):
+ tensorboard = tensorboard_resource.Tensorboard(
+ tensorboard,
+ project=self._metadata_context.project,
+ location=self._metadata_context.location,
+ credentials=self._metadata_context.credentials,
+ )
+
+ if tensorboard.project not in self._metadata_context._project_tuple:
+ raise ValueError(
+ f"Tensorboard is in project {tensorboard.project} but must be in project {self._metadata_context.project}"
+ )
+ if tensorboard.location != self._metadata_context.location:
+ raise ValueError(
+ f"Tensorboard is in location {tensorboard.location} but must be in location {self._metadata_context.location}"
+ )
+
+ self._metadata_context.update(
+ metadata={
+ constants._BACKING_TENSORBOARD_RESOURCE_KEY: tensorboard.resource_name
+ }
+ )
+
+ def _log_experiment_loggable(self, experiment_loggable: "_ExperimentLoggable"):
+        """Associates a Vertex resource that can be logged to an Experiment as a run of this experiment.
+
+ Args:
+ experiment_loggable (_ExperimentLoggable):
+ A Vertex Resource that can be logged to an Experiment directly.
+ """
+ context = experiment_loggable._get_context()
+ self._metadata_context.add_context_children([context])
+
+
+class _SetLoggerLevel:
+    """Context manager that sets a module's logger to WARNING on entry and back to INFO on exit."""
+
+ def __init__(self, module):
+ self._module = module
+
+ def __enter__(self):
+ logging.getLogger(self._module.__name__).setLevel(logging.WARNING)
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ logging.getLogger(self._module.__name__).setLevel(logging.INFO)
+
+
+class _VertexResourceWithMetadata(NamedTuple):
+ """Represents a resource coupled with it's metadata representation"""
+
+ resource: base.VertexAiResourceNoun
+ metadata: Union[artifact.Artifact, execution.Execution, context._Context]
+
+
+class _ExperimentLoggableSchema(NamedTuple):
+ """Used with _ExperimentLoggable to capture Metadata representation information about resoure.
+
+ For example:
+ _ExperimentLoggableSchema(title='system.PipelineRun', type=context._Context)
+
+ Defines the schema and metadata type to lookup PipelineJobs.
+ """
+
+ title: str
+ type: Union[Type[context._Context], Type[execution.Execution]] = context._Context
+
+
+class _ExperimentLoggable(abc.ABC):
+ """Abstract base class to define a Vertex Resource as loggable against an Experiment.
+
+ For example:
+ class PipelineJob(..., experiment_loggable_schemas=
+ (_ExperimentLoggableSchema(title='system.PipelineRun'), )
+
+ """
+
+ def __init_subclass__(
+ cls, *, experiment_loggable_schemas: Tuple[_ExperimentLoggableSchema], **kwargs
+ ):
+ """Register the metadata_schema for the subclass so Experiment can use it to retrieve the associated types.
+
+ usage:
+
+ class PipelineJob(..., experiment_loggable_schemas=
+ (_ExperimentLoggableSchema(title='system.PipelineRun'), )
+
+ Args:
+ experiment_loggable_schemas:
+ Tuple of the schema_title and type pairs that represent this resource. Note that a single item in the
+ tuple will be most common. Currently only experiment run has multiple representation for backwards
+ compatibility. Almost all schemas should be Contexts and Execution is currently only supported
+ for backwards compatibility of experiment runs.
+
+ """
+ super().__init_subclass__(**kwargs)
+
+ # register the type when module is loaded
+ for schema in experiment_loggable_schemas:
+ _SUPPORTED_LOGGABLE_RESOURCES[schema.type][schema.title] = cls
+
+ @abc.abstractmethod
+ def _get_context(self) -> context._Context:
+ """Should return the metadata context that represents this resource.
+
+ The subclass should enforce this context exists.
+
+ Returns:
+ Context that represents this resource.
+ """
+ pass
+
+ @classmethod
+ @abc.abstractmethod
+ def _query_experiment_row(
+ cls, node: Union[context._Context, execution.Execution]
+ ) -> _ExperimentRow:
+ """Should return parameters and metrics for this resource as a run row.
+
+ Args:
+ node: The metadata node that represents this resource.
+ Returns:
+ A populated run row for this resource.
+ """
+ pass
+
+ def _validate_experiment(self, experiment: Union[str, Experiment]):
+ """Validates experiment is accessible. Can be used by subclass to throw before creating the intended resource.
+
+ Args:
+ experiment (Union[str, Experiment]): The experiment that this resource will be associated to.
+
+ Raises:
+ RuntimeError: If service raises any exception when trying to access this experiment.
+ ValueError: If resource project or location do not match experiment project or location.
+ """
+
+ if isinstance(experiment, str):
+ try:
+ experiment = Experiment.get_or_create(
+ experiment,
+ project=self.project,
+ location=self.location,
+ credentials=self.credentials,
+ )
+ except Exception as e:
+ raise RuntimeError(
+ f"Experiment {experiment} could not be found or created. {self.__class__.__name__} not created"
+ ) from e
+
+ if self.project not in experiment._metadata_context._project_tuple:
+ raise ValueError(
+ f"{self.__class__.__name__} project {self.project} does not match experiment "
+ f"{experiment.name} project {experiment.project}"
+ )
+
+ if experiment._metadata_context.location != self.location:
+ raise ValueError(
+ f"{self.__class__.__name__} location {self.location} does not match experiment "
+ f"{experiment.name} location {experiment.location}"
+ )
+
+ def _associate_to_experiment(self, experiment: Union[str, Experiment]):
+ """Associates this resource to the provided Experiment.
+
+ Args:
+ experiment (Union[str, Experiment]): Required. Experiment name or experiment instance.
+
+ Raises:
+ RuntimeError: If Metadata service cannot associate resource to Experiment.
+ """
+ experiment_name = experiment if isinstance(experiment, str) else experiment.name
+ _LOGGER.info(
+ "Associating %s to Experiment: %s" % (self.resource_name, experiment_name)
+ )
+
+ try:
+ if isinstance(experiment, str):
+ experiment = Experiment.get_or_create(
+ experiment,
+ project=self.project,
+ location=self.location,
+ credentials=self.credentials,
+ )
+ experiment._log_experiment_loggable(self)
+ except Exception as e:
+ raise RuntimeError(
+ f"{self.resource_name} could not be associated with Experiment {experiment.name}"
+ ) from e
+
+
+# maps context names to their resources classes
+# used by the Experiment implementation to filter for representations in the metadata store
+# populated at module import time from class that inherit _ExperimentLoggable
+# example mapping:
+# {Metadata Type} -> {schema title} -> {vertex sdk class}
+# Context -> 'system.PipelineRun' -> aiplatform.PipelineJob
+# Context -> 'system.ExperimentRun' -> aiplatform.ExperimentRun
+# Execution -> 'system.Run' -> aiplatform.ExperimentRun
+_SUPPORTED_LOGGABLE_RESOURCES: Dict[
+ Union[Type[context._Context], Type[execution.Execution]],
+ Dict[str, _ExperimentLoggable],
+] = {execution.Execution: dict(), context._Context: dict()}
diff --git a/google/cloud/aiplatform/metadata/experiment_run_resource.py b/google/cloud/aiplatform/metadata/experiment_run_resource.py
new file mode 100644
index 0000000000..b10a9d5dcb
--- /dev/null
+++ b/google/cloud/aiplatform/metadata/experiment_run_resource.py
@@ -0,0 +1,1197 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections.abc
+import concurrent.futures
+import functools
+from typing import Callable, Dict, List, Optional, Set, Union, Any
+
+from google.api_core import exceptions
+from google.auth import credentials as auth_credentials
+from google.protobuf import timestamp_pb2
+
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform import pipeline_jobs
+from google.cloud.aiplatform.compat.types import artifact as gca_artifact
+from google.cloud.aiplatform.compat.types import execution as gca_execution
+from google.cloud.aiplatform.compat.types import (
+ tensorboard_time_series as gca_tensorboard_time_series,
+)
+from google.cloud.aiplatform.metadata import artifact
+from google.cloud.aiplatform.metadata import constants
+from google.cloud.aiplatform.metadata import context
+from google.cloud.aiplatform.metadata import execution
+from google.cloud.aiplatform.metadata import experiment_resources
+from google.cloud.aiplatform.metadata import metadata
+from google.cloud.aiplatform.metadata import resource
+from google.cloud.aiplatform.metadata import utils as metadata_utils
+from google.cloud.aiplatform.tensorboard import tensorboard_resource
+from google.cloud.aiplatform.utils import rest_utils
+
+
+_LOGGER = base.Logger(__name__)
+
+
+def _format_experiment_run_resource_id(experiment_name: str, run_name: str) -> str:
+ """Formats the the experiment run resource id as a concatenation of experiment name and run name.
+
+ Args:
+ experiment_name (str): Name of the experiment, which is also its resource id.
+ run_name (str): Name of the run.
+ Returns:
+ The resource id to be used with this run.
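+
+ For example (illustrative): experiment 'my-experiment' and run 'run-1'
+ yield the resource id 'my-experiment-run-1'.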
+ """
+ return f"{experiment_name}-{run_name}"
+
+
+def _v1_not_supported(method: Callable) -> Callable:
+ """Helpers wrapper for backward compatibility. Raises when using an API not support for legacy runs."""
+
+ @functools.wraps(method)
+ def wrapper(self, *args, **kwargs):
+ if isinstance(self._metadata_node, execution.Execution):
+ raise NotImplementedError(
+ f"{self._run_name} is an Execution run created during Vertex Experiment Preview and does not support"
+ f" {method.__name__}. Please create a new Experiment run to use this method."
+ )
+ else:
+ return method(self, *args, **kwargs)
+
+ return wrapper
+
+
+class ExperimentRun(
+ experiment_resources._ExperimentLoggable,
+ experiment_loggable_schemas=(
+ experiment_resources._ExperimentLoggableSchema(
+ title=constants.SYSTEM_EXPERIMENT_RUN, type=context._Context
+ ),
+ # backwards compatibility with Preview Experiment runs
+ experiment_resources._ExperimentLoggableSchema(
+ title=constants.SYSTEM_RUN, type=execution.Execution
+ ),
+ ),
+):
+ """A Vertex AI Experiment run"""
+
+ def __init__(
+ self,
+ run_name: str,
+ experiment: Union[experiment_resources.Experiment, str],
+ *,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ):
+ """
+
+ ```
+ my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
+ ```
+
+ Args:
+ run_name (str): Required. The name of this run.
+ experiment (Union[experiment_resources.Experiment, str]):
+ Required. The name or instance of this experiment.
+ project (str):
+ Optional. Project where this experiment run is located. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location where this experiment run is located. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials used to retrieve this experiment run. Overrides
+ credentials set in aiplatform.init.
+ """
+
+ self._experiment = self._get_experiment(
+ experiment=experiment,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+ self._run_name = run_name
+
+ run_id = _format_experiment_run_resource_id(
+ experiment_name=self._experiment.name, run_name=run_name
+ )
+
+ metadata_args = dict(
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+ def _get_context() -> context._Context:
+ with experiment_resources._SetLoggerLevel(resource):
+ run_context = context._Context(
+ **{**metadata_args, "resource_name": run_id}
+ )
+ if run_context.schema_title != constants.SYSTEM_EXPERIMENT_RUN:
+ raise ValueError(
+ f"Run {run_name} must be of type {constants.SYSTEM_EXPERIMENT_RUN}"
+ f" but is of type {run_context.schema_title}"
+ )
+ return run_context
+
+ try:
+ self._metadata_node = _get_context()
+ except exceptions.NotFound as context_not_found:
+ try:
+ # backward compatibility
+ self._v1_resolve_experiment_run(
+ {
+ **metadata_args,
+ "execution_name": run_id,
+ }
+ )
+ except exceptions.NotFound:
+ raise context_not_found
+ else:
+ self._backing_tensorboard_run = self._lookup_tensorboard_run_artifact()
+
+ # Initially set to None. Will first be updated from the resource, then tracked locally.
+ self._largest_step: Optional[int] = None
+
+ def _v1_resolve_experiment_run(self, metadata_args: Dict[str, Any]):
+ """Resolves preview Experiment.
+
+ Args:
+ metadata_args (Dict[str, Any]): Arguments to pass to the Execution constructor.
+ """
+
+ def _get_execution():
+ with experiment_resources._SetLoggerLevel(resource):
+ run_execution = execution.Execution(**metadata_args)
+ if run_execution.schema_title != constants.SYSTEM_RUN:
+ # note this will raise the context not found exception in the constructor
+ raise exceptions.NotFound("Experiment run not found.")
+ return run_execution
+
+ self._metadata_node = _get_execution()
+ self._metadata_metric_artifact = self._v1_get_metric_artifact()
+
+ def _v1_get_metric_artifact(self) -> artifact.Artifact:
+ """Resolves metric artifact for backward compatibility.
+
+ Returns:
+ Instance of Artifact that represents this run's metric artifact.
+ """
+ metadata_args = dict(
+ artifact_name=self._v1_format_artifact_name(self._metadata_node.name),
+ project=self.project,
+ location=self.location,
+ credentials=self.credentials,
+ )
+
+ with experiment_resources._SetLoggerLevel(resource):
+ metric_artifact = artifact.Artifact(**metadata_args)
+
+ if metric_artifact.schema_title != constants.SYSTEM_METRICS:
+ # note this will raise the context not found exception in the constructor
+ raise exceptions.NotFound("Experiment run not found.")
+
+ return metric_artifact
+
+ @staticmethod
+ def _v1_format_artifact_name(run_id: str) -> str:
+ """Formats resource id of legacy metric artifact for this run."""
+ return f"{run_id}-metrics"
+
+ def _get_context(self) -> context._Context:
+ """Returns this metadata context that represents this run.
+
+ Returns:
+ Context instance of this run.
+ """
+ return self._metadata_node
+
+ @property
+ def resource_id(self) -> str:
+ """The resource ID of this experiment run's Metadata context.
+
+ The resource ID is the final part of the resource name:
+ ``projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{resource ID}``
+ """
+ return self._metadata_node.name
+
+ @property
+ def name(self) -> str:
+ """This run's name used to identify this run within it's Experiment."""
+ return self._run_name
+
+ @property
+ def resource_name(self) -> str:
+ """This run's Metadata context resource name.
+
+ In the format: ``projects/{project}/locations/{location}/metadataStores/{metadatastore}/contexts/{context}``
+ """
+ return self._metadata_node.resource_name
+
+ @property
+ def project(self) -> str:
+ """The project that this experiment run is located in."""
+ return self._metadata_node.project
+
+ @property
+ def location(self) -> str:
+ """The location that this experiment is located in."""
+ return self._metadata_node.location
+
+ @property
+ def credentials(self) -> auth_credentials.Credentials:
+ """The credentials used to access this experiment run."""
+ return self._metadata_node.credentials
+
+ @property
+ def state(self) -> gca_execution.Execution.State:
+ """The state of this run."""
+ if self._is_legacy_experiment_run():
+ return self._metadata_node.state
+ else:
+ return getattr(
+ gca_execution.Execution.State,
+ self._metadata_node.metadata[constants._STATE_KEY],
+ )
+
+ @staticmethod
+ def _get_experiment(
+ experiment: Optional[Union[experiment_resources.Experiment, str]] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> experiment_resources.Experiment:
+ """Helper method ot get the experiment by name(str) or instance.
+
+ Args:
+ experiment(str):
+ Optional. The name of this experiment. Defaults to experiment set in aiplatform.init if not provided.
+ project (str):
+ Optional. Project where this experiment is located. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location where this experiment is located. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials used to retrieve this experiment. Overrides
+ credentials set in aiplatform.init.
+ Returns:
+ The resolved Experiment instance.
+ Raises:
+ ValueError: If experiment is None and no experiment has been set using aiplatform.init.
+ """
+
+ experiment = experiment or initializer.global_config.experiment
+
+ if not experiment:
+ raise ValueError(
+ "experiment must be provided or experiment should be set using aiplatform.init"
+ )
+
+ if not isinstance(experiment, experiment_resources.Experiment):
+ experiment = experiment_resources.Experiment(
+ experiment_name=experiment,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+ return experiment
+
+ def _is_backing_tensorboard_run_artifact(self, artifact: artifact.Artifact) -> bool:
+ """Helper method to confirm tensorboard run metadata artifact is this run's tensorboard artifact.
+
+ Args:
+ artifact (artifact.Artifact): Required. Instance of metadata Artifact.
+ Returns:
+ bool whether the provided artifact is this run's TensorboardRun's artifact.
+ """
+ return all(
+ [
+ artifact.metadata.get(constants._VERTEX_EXPERIMENT_TRACKING_LABEL),
+ artifact.name == self._tensorboard_run_id(self._metadata_node.name),
+ artifact.schema_title
+ == constants._TENSORBOARD_RUN_REFERENCE_ARTIFACT.schema_title,
+ ]
+ )
+
+ def _is_legacy_experiment_run(self) -> bool:
+ """Helper method that return True if this is a legacy experiment run."""
+ return isinstance(self._metadata_node, execution.Execution)
+
+ def update_state(self, state: gca_execution.Execution.State):
+ """Update the state of this experiment run.
+
+ ```
+ my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
+ my_run.update_state(state=aiplatform.gapic.Execution.State.COMPLETE)
+ ```
+
+ Args:
+ state (aiplatform.gapic.Execution.State): State of this run.
+ """
+ if self._is_legacy_experiment_run():
+ self._metadata_node.update(state=state)
+ else:
+ self._metadata_node.update(metadata={constants._STATE_KEY: state.name})
+
+ def _lookup_tensorboard_run_artifact(
+ self,
+ ) -> Optional[experiment_resources._VertexResourceWithMetadata]:
+ """Helpers method to resolve this run's TensorboardRun Artifact if it exists.
+
+ Returns:
+ Tuple of the Tensorboard Run Artifact and TensorboardRun if it exists.
+ """
+ with experiment_resources._SetLoggerLevel(resource):
+ try:
+ tensorboard_run_artifact = artifact.Artifact(
+ artifact_name=self._tensorboard_run_id(self._metadata_node.name),
+ project=self._metadata_node.project,
+ location=self._metadata_node.location,
+ credentials=self._metadata_node.credentials,
+ )
+ except exceptions.NotFound:
+ tensorboard_run_artifact = None
+
+ if tensorboard_run_artifact and self._is_backing_tensorboard_run_artifact(
+ tensorboard_run_artifact
+ ):
+ return experiment_resources._VertexResourceWithMetadata(
+ resource=tensorboard_resource.TensorboardRun(
+ tensorboard_run_artifact.metadata[
+ constants.GCP_ARTIFACT_RESOURCE_NAME_KEY
+ ]
+ ),
+ metadata=tensorboard_run_artifact,
+ )
+
+ @classmethod
+ def list(
+ cls,
+ *,
+ experiment: Optional[Union[experiment_resources.Experiment, str]] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> List["ExperimentRun"]:
+ """List the experiment runs for a given aiplatform.Experiment.
+
+ ```
+ my_runs = aiplatform.ExperimentRun.list(experiment='my-experiment')
+ ```
+
+ Args:
+ experiment (Union[aiplatform.Experiment, str]):
+ Optional. The experiment name or instance to list experiment runs from. If not provided,
+ will use the experiment set in aiplatform.init.
+ project (str):
+ Optional. Project where this experiment is located. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location where this experiment is located. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials used to retrieve this experiment. Overrides
+ credentials set in aiplatform.init.
+ Returns:
+ List of experiment runs.
+ """
+
+ experiment = cls._get_experiment(
+ experiment=experiment,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+ metadata_args = dict(
+ project=experiment._metadata_context.project,
+ location=experiment._metadata_context.location,
+ credentials=experiment._metadata_context.credentials,
+ )
+
+ filter_str = metadata_utils._make_filter_string(
+ schema_title=constants.SYSTEM_EXPERIMENT_RUN,
+ parent_contexts=[experiment.resource_name],
+ )
+
+ run_contexts = context._Context.list(filter=filter_str, **metadata_args)
+
+ filter_str = metadata_utils._make_filter_string(
+ schema_title=constants.SYSTEM_RUN, in_context=[experiment.resource_name]
+ )
+
+ run_executions = execution.Execution.list(filter=filter_str, **metadata_args)
+
+ def _initialize_experiment_run(context: context._Context) -> ExperimentRun:
+ this_experiment_run = cls.__new__(cls)
+ this_experiment_run._experiment = experiment
+ this_experiment_run._run_name = context.display_name
+ this_experiment_run._metadata_node = context
+
+ with experiment_resources._SetLoggerLevel(resource):
+ tb_run = this_experiment_run._lookup_tensorboard_run_artifact()
+ if tb_run:
+ this_experiment_run._backing_tensorboard_run = tb_run
+ else:
+ this_experiment_run._backing_tensorboard_run = None
+
+ this_experiment_run._largest_step = None
+
+ return this_experiment_run
+
+ def _initialize_v1_experiment_run(
+ execution: execution.Execution,
+ ) -> ExperimentRun:
+ this_experiment_run = cls.__new__(cls)
+ this_experiment_run._experiment = experiment
+ this_experiment_run._run_name = execution.display_name
+ this_experiment_run._metadata_node = execution
+ this_experiment_run._metadata_metric_artifact = (
+ this_experiment_run._v1_get_metric_artifact()
+ )
+
+ return this_experiment_run
+
+ if run_contexts or run_executions:
+ with concurrent.futures.ThreadPoolExecutor(
+ max_workers=max([len(run_contexts), len(run_executions)])
+ ) as executor:
+ submissions = [
+ executor.submit(_initialize_experiment_run, context)
+ for context in run_contexts
+ ]
+ experiment_runs = [submission.result() for submission in submissions]
+
+ submissions = [
+ executor.submit(_initialize_v1_experiment_run, execution)
+ for execution in run_executions
+ ]
+
+ for submission in submissions:
+ experiment_runs.append(submission.result())
+
+ return experiment_runs
+ else:
+ return []
+
+ @classmethod
+ def _query_experiment_row(
+ cls, node: Union[context._Context, execution.Execution]
+ ) -> experiment_resources._ExperimentRow:
+ """Retrieves the runs metric and parameters into an experiment run row.
+
+ Args:
+ node (Union[context._Context, execution.Execution]):
+ Required. Metadata node instance that represents this run.
+ Returns:
+ Experiment run row that represents this run.
+ """
+ this_experiment_run = cls.__new__(cls)
+ this_experiment_run._metadata_node = node
+
+ row = experiment_resources._ExperimentRow(
+ experiment_run_type=node.schema_title,
+ name=node.display_name,
+ )
+
+ if isinstance(node, context._Context):
+ this_experiment_run._backing_tensorboard_run = (
+ this_experiment_run._lookup_tensorboard_run_artifact()
+ )
+ row.params = node.metadata[constants._PARAM_KEY]
+ row.metrics = node.metadata[constants._METRIC_KEY]
+ row.time_series_metrics = (
+ this_experiment_run._get_latest_time_series_metric_columns()
+ )
+ row.state = node.metadata[constants._STATE_KEY]
+ else:
+ this_experiment_run._metadata_metric_artifact = (
+ this_experiment_run._v1_get_metric_artifact()
+ )
+ row.params = node.metadata
+ row.metrics = this_experiment_run._metadata_metric_artifact.metadata
+ row.state = node.state.name
+ return row
+
+ def _get_logged_pipeline_runs(self) -> List[context._Context]:
+ """Returns Pipeline Run contexts logged to this Experiment Run.
+
+ Returns:
+ List of Pipeline system.PipelineRun contexts.
+ """
+
+ service_request_args = dict(
+ project=self._metadata_node.project,
+ location=self._metadata_node.location,
+ credentials=self._metadata_node.credentials,
+ )
+
+ filter_str = metadata_utils._make_filter_string(
+ schema_title=constants.SYSTEM_PIPELINE_RUN,
+ parent_contexts=[self._metadata_node.resource_name],
+ )
+
+ return context._Context.list(filter=filter_str, **service_request_args)
+
+ def _get_latest_time_series_metric_columns(self) -> Dict[str, Union[float, int]]:
+ """Determines the latest step for each time series metric.
+
+ Returns:
+ Dictionary mapping time series metric key to the latest step of that metric.
+ """
+ if self._backing_tensorboard_run:
+ time_series_metrics = (
+ self._backing_tensorboard_run.resource.read_time_series_data()
+ )
+
+ return {
+ display_name: data.values[-1].scalar.value
+ for display_name, data in time_series_metrics.items()
+ if data.value_type
+ == gca_tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR
+ }
+ return {}
+
+ def _log_pipeline_job(self, pipeline_job: pipeline_jobs.PipelineJob):
+ """Associate this PipelineJob's Context to the current ExperimentRun Context as a child context.
+
+ Args:
+ pipeline_job (pipeline_jobs.PipelineJob):
+ Required. The PipelineJob to associate.
+ """
+
+ pipeline_job_context = pipeline_job._get_context()
+ self._metadata_node.add_context_children([pipeline_job_context])
+
+ @_v1_not_supported
+ def log(
+ self,
+ *,
+ pipeline_job: Optional[pipeline_jobs.PipelineJob] = None,
+ ):
+ """Log a Vertex Resource to this experiment run.
+
+ ```
+ my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
+ my_job = aiplatform.PipelineJob(...)
+ my_job.submit()
+ my_run.log(pipeline_job=my_job)
+ ```
+
+ Args:
+ pipeline_job (aiplatform.PipelineJob): Optional. A Vertex PipelineJob.
+ """
+ if pipeline_job:
+ self._log_pipeline_job(pipeline_job=pipeline_job)
+
+ @staticmethod
+ def _validate_run_id(run_id: str):
+ """Validates the run id
+
+ Args:
+ run_id(str): Required. The run id to validate.
+ Raises:
+ ValueError: If the run id is too long.
+ """
+
+ if len(run_id) > 128:
+ raise ValueError(
+ f"Length of Experiment ID and Run ID cannot be greater than 128. "
+ f"{run_id} is of length {len(run_id)}"
+ )
+
+ @classmethod
+ def create(
+ cls,
+ run_name: str,
+ *,
+ experiment: Optional[Union[experiment_resources.Experiment, str]] = None,
+ tensorboard: Optional[Union[tensorboard_resource.Tensorboard, str]] = None,
+ state: gca_execution.Execution.State = gca_execution.Execution.State.RUNNING,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> "ExperimentRun":
+ """Creates a new experiment run in Vertex AI Experiments.
+
+ ```
+ my_run = aiplatform.ExperimentRun.create('my-run', experiment='my-experiment')
+ ```
+
+ Args:
+ run_name (str): Required. The name of this run.
+ experiment (Union[aiplatform.Experiment, str]):
+ Optional. The name or instance of the experiment to create this run under.
+ If not provided, will default to the experiment set in `aiplatform.init`.
+ tensorboard (Union[aiplatform.Tensorboard, str]):
+ Optional. The resource name or instance of a Vertex Tensorboard to use as the backing
+ Tensorboard for time series metric logging. If not provided, will default to
+ the backing tensorboard of the parent experiment if set. Must be in the same project and location
+ as this experiment run.
+ state (aiplatform.gapic.Execution.State):
+ Optional. The state of this run. Defaults to RUNNING.
+ project (str):
+ Optional. Project where this experiment will be created. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location where this experiment will be created. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials used to create this experiment. Overrides
+ credentials set in aiplatform.init.
+ Returns:
+ The newly created experiment run.
+ """
+
+ experiment = cls._get_experiment(experiment)
+
+ run_id = _format_experiment_run_resource_id(
+ experiment_name=experiment.name, run_name=run_name
+ )
+
+ cls._validate_run_id(run_id)
+
+ def _create_context():
+ with experiment_resources._SetLoggerLevel(resource):
+ return context._Context._create(
+ resource_id=run_id,
+ display_name=run_name,
+ schema_title=constants.SYSTEM_EXPERIMENT_RUN,
+ schema_version=constants.SCHEMA_VERSIONS[
+ constants.SYSTEM_EXPERIMENT_RUN
+ ],
+ metadata={
+ constants._PARAM_KEY: {},
+ constants._METRIC_KEY: {},
+ constants._STATE_KEY: state.name,
+ },
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+ metadata_context = _create_context()
+
+ if metadata_context is None:
+ raise RuntimeError(
+ f"Experiment Run with name {run_name} in {experiment.name} already exists."
+ )
+
+ experiment_run = cls.__new__(cls)
+ experiment_run._experiment = experiment
+ experiment_run._run_name = metadata_context.display_name
+ experiment_run._metadata_node = metadata_context
+ experiment_run._backing_tensorboard_run = None
+ experiment_run._largest_step = None
+
+ if tensorboard:
+ cls._assign_backing_tensorboard(
+ self=experiment_run, tensorboard=tensorboard
+ )
+ else:
+ cls._assign_to_experiment_backing_tensorboard(self=experiment_run)
+
+ experiment_run._associate_to_experiment(experiment)
+ return experiment_run
+
+ def _assign_to_experiment_backing_tensorboard(self):
+ """Assigns parent Experiment backing tensorboard resource to this Experiment Run."""
+ backing_tensorboard_resource = (
+ self._experiment.get_backing_tensorboard_resource()
+ )
+
+ if backing_tensorboard_resource:
+ self.assign_backing_tensorboard(tensorboard=backing_tensorboard_resource)
+
+ @staticmethod
+ def _format_tensorboard_experiment_display_name(experiment_name: str) -> str:
+ """Formats Tensorboard experiment name that backs this run.
+ Args:
+ experiment_name (str): Required. The name of the experiment.
+ Returns:
+ Formatted Tensorboard Experiment name
+ """
+ # post fix helps distinguish from the Vertex Experiment in console
+ return f"{experiment_name} Backing Tensorboard Experiment"
+
+ def _assign_backing_tensorboard(
+ self, tensorboard: Union[tensorboard_resource.Tensorboard, str]
+ ):
+ """Assign tensorboard as the backing tensorboard to this run.
+
+ Args:
+ tensorboard (Union[tensorboard_resource.Tensorboard, str]):
+ Required. Tensorboard instance or resource name.
+ """
+ if isinstance(tensorboard, str):
+ tensorboard = tensorboard_resource.Tensorboard(
+ tensorboard, credentials=self._metadata_node.credentials
+ )
+
+ tensorboard_resource_name_parts = tensorboard._parse_resource_name(
+ tensorboard.resource_name
+ )
+ tensorboard_experiment_resource_name = (
+ tensorboard_resource.TensorboardExperiment._format_resource_name(
+ experiment=self._experiment.name, **tensorboard_resource_name_parts
+ )
+ )
+ try:
+ tensorboard_experiment = tensorboard_resource.TensorboardExperiment(
+ tensorboard_experiment_resource_name,
+ credentials=tensorboard.credentials,
+ )
+ except exceptions.NotFound:
+ with experiment_resources._SetLoggerLevel(tensorboard_resource):
+ tensorboard_experiment = (
+ tensorboard_resource.TensorboardExperiment.create(
+ tensorboard_experiment_id=self._experiment.name,
+ display_name=self._format_tensorboard_experiment_display_name(
+ self._experiment.name
+ ),
+ tensorboard_name=tensorboard.resource_name,
+ credentials=tensorboard.credentials,
+ )
+ )
+
+ tensorboard_experiment_name_parts = tensorboard_experiment._parse_resource_name(
+ tensorboard_experiment.resource_name
+ )
+ tensorboard_run_resource_name = (
+ tensorboard_resource.TensorboardRun._format_resource_name(
+ run=self._run_name, **tensorboard_experiment_name_parts
+ )
+ )
+ try:
+ tensorboard_run = tensorboard_resource.TensorboardRun(
+ tensorboard_run_resource_name
+ )
+ except exceptions.NotFound:
+ with experiment_resources._SetLoggerLevel(tensorboard_resource):
+ tensorboard_run = tensorboard_resource.TensorboardRun.create(
+ tensorboard_run_id=self._run_name,
+ tensorboard_experiment_name=tensorboard_experiment.resource_name,
+ credentials=tensorboard.credentials,
+ )
+
+ gcp_resource_url = rest_utils.make_gcp_resource_rest_url(tensorboard_run)
+
+ with experiment_resources._SetLoggerLevel(resource):
+ tensorboard_run_metadata_artifact = artifact.Artifact._create(
+ uri=gcp_resource_url,
+ resource_id=self._tensorboard_run_id(self._metadata_node.name),
+ metadata={
+ "resourceName": tensorboard_run.resource_name,
+ constants._VERTEX_EXPERIMENT_TRACKING_LABEL: True,
+ },
+ schema_title=constants._TENSORBOARD_RUN_REFERENCE_ARTIFACT.schema_title,
+ schema_version=constants._TENSORBOARD_RUN_REFERENCE_ARTIFACT.schema_version,
+ state=gca_artifact.Artifact.State.LIVE,
+ )
+
+ self._metadata_node.add_artifacts_and_executions(
+ artifact_resource_names=[tensorboard_run_metadata_artifact.resource_name]
+ )
+
+ self._backing_tensorboard_run = (
+ experiment_resources._VertexResourceWithMetadata(
+ resource=tensorboard_run, metadata=tensorboard_run_metadata_artifact
+ )
+ )
+
+ @staticmethod
+ def _tensorboard_run_id(run_id: str) -> str:
+ """Helper method to format the tensorboard run artifact resource id for a run.
+
+ Args:
+ run_id: The resource id of the experiment run.
+
+ Returns:
+ Resource id for the associated tensorboard run artifact.
+ """
+ return f"{run_id}-tb-run"
+
+ @_v1_not_supported
+ def assign_backing_tensorboard(
+ self, tensorboard: Union[tensorboard_resource.Tensorboard, str]
+ ):
+ """Assigns tensorboard as backing tensorboard to support timeseries metrics logging for this run.
+
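+ A minimal sketch (assumes a Tensorboard in the same project and location as this run):
+
+ ```
+ tb = aiplatform.Tensorboard('tensorboard-resource-id')
+ my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
+ my_run.assign_backing_tensorboard(tb)
+ ```
+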
+ Args:
+ tensorboard (Union[aiplatform.Tensorboard, str]):
+ Required. Tensorboard instance or resource name.
+ """
+
+ backing_tensorboard = self._lookup_tensorboard_run_artifact()
+ if backing_tensorboard:
+ raise ValueError(
+ f"Experiment run {self._run_name} already associated to tensorboard resource {backing_tensorboard.resource.resource_name}"
+ )
+
+ self._assign_backing_tensorboard(tensorboard=tensorboard)
+
+ def _get_latest_time_series_step(self) -> int:
+ """Gets latest time series step of all time series from Tensorboard resource.
+
+ Returns:
+ Latest step of all time series metrics.
+ """
+ data = self._backing_tensorboard_run.resource.read_time_series_data()
+ return max(
+ (ts.values[-1].step if ts.values else 0 for ts in data.values()), default=0
+ )
+
+ @_v1_not_supported
+ def log_time_series_metrics(
+ self,
+ metrics: Dict[str, float],
+ step: Optional[int] = None,
+ wall_time: Optional[timestamp_pb2.Timestamp] = None,
+ ):
+ """Logs time series metrics to backing TensorboardRun of this Experiment Run.
+
+ ```
+ run.log_time_series_metrics({'accuracy': 0.9}, step=10)
+ ```
+
+ Args:
+ metrics (Dict[str, Union[str, float]]):
+ Required. Dictionary of where keys are metric names and values are metric values.
+ step (int):
+ Optional. Step index of this data point within the run.
+
+ If not provided, the latest
+ step amongst all time series metrics already logged will be used.
+ wall_time (timestamp_pb2.Timestamp):
+ Optional. Wall clock timestamp when this data point is
+ generated by the end user.
+
+ If not provided, this will be generated based on the value from time.time()
+ Raises:
+ RuntimeError: If current experiment run doesn't have a backing Tensorboard resource.
+ """
+
+ if not self._backing_tensorboard_run:
+ self._assign_to_experiment_backing_tensorboard()
+ if not self._backing_tensorboard_run:
+ raise RuntimeError(
+ "Please set this experiment run with backing tensorboard resource to use log_time_series_metrics."
+ )
+
+ self._soft_create_time_series(metric_keys=set(metrics.keys()))
+
+ if not step:
+ step = self._largest_step or self._get_latest_time_series_step()
+ step += 1
+ self._largest_step = step
+
+ self._backing_tensorboard_run.resource.write_tensorboard_scalar_data(
+ time_series_data=metrics, step=step, wall_time=wall_time
+ )
+
+ def _soft_create_time_series(self, metric_keys: Set[str]):
+ """Creates TensorboardTimeSeries for the metric keys if one currently does not exist.
+
+ Args:
+ metric_keys (Set[str]): Keys of the metrics.
+ """
+
+ if any(
+ key
+ not in self._backing_tensorboard_run.resource._time_series_display_name_to_id_mapping
+ for key in metric_keys
+ ):
+ self._backing_tensorboard_run.resource._sync_time_series_display_name_to_id_mapping()
+
+ for key in metric_keys:
+ if (
+ key
+ not in self._backing_tensorboard_run.resource._time_series_display_name_to_id_mapping
+ ):
+ with experiment_resources._SetLoggerLevel(tensorboard_resource):
+ self._backing_tensorboard_run.resource.create_tensorboard_time_series(
+ display_name=key
+ )
+
+ def log_params(self, params: Dict[str, Union[float, int, str]]):
+ """Log single or multiple parameters with specified key value pairs.
+
+ Parameters with the same key will be overwritten.
+
+ ```
+ my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
+ my_run.log_params({'learning_rate': 0.1, 'dropout_rate': 0.2})
+ ```
+
+ Args:
+ params (Dict[str, Union[float, int, str]]):
+ Required. Parameter key/value pairs.
+
+ Raises:
+ TypeError: If key is not str or value is not float, int, or str.
+ """
+ # query the latest run execution resource before logging.
+ for key, value in params.items():
+ if not isinstance(key, str):
+ raise TypeError(
+ f"{key} is of type {type(key).__name__} must of type str"
+ )
+ if not isinstance(value, (float, int, str)):
+ raise TypeError(
+ f"Value for key {key} is of type {type(value).__name__} but must be one of float, int, str"
+ )
+
+ if self._is_legacy_experiment_run():
+ self._metadata_node.update(metadata=params)
+ else:
+ self._metadata_node.update(metadata={constants._PARAM_KEY: params})
+
+ def log_metrics(self, metrics: Dict[str, Union[float, int, str]]):
+ """Log single or multiple Metrics with specified key and value pairs.
+
+ Metrics with the same key will be overwritten.
+
+ ```
+ my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
+ my_run.log_metrics({'accuracy': 0.9, 'recall': 0.8})
+ ```
+
+ Args:
+ metrics (Dict[str, Union[float, int]]):
+ Required. Metrics key/value pairs.
+ Raises:
+ TypeError: If keys are not str or values are not float, int, or str.
+ """
+ for key, value in metrics.items():
+ if not isinstance(key, str):
+ raise TypeError(
+ f"{key} is of type {type(key).__name__} must of type str"
+ )
+ if not isinstance(value, (float, int, str)):
+ raise TypeError(
+ f"Value for key {key} is of type {type(value).__name__} but must be one of float, int, str"
+ )
+
+ if self._is_legacy_experiment_run():
+ self._metadata_metric_artifact.update(metadata=metrics)
+ else:
+ # TODO: query the latest metrics artifact resource before logging.
+ self._metadata_node.update(metadata={constants._METRIC_KEY: metrics})
+
+ @_v1_not_supported
+ def get_time_series_data_frame(self) -> "pd.DataFrame": # noqa: F821
+ """Returns all time series in this Run as a DataFrame.
+
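+ For example (names are illustrative):
+
+ ```
+ my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
+ df = my_run.get_time_series_data_frame()
+ ```
+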
+ Returns:
+ pd.DataFrame: Time series metrics in this Run as a Dataframe.
+ """
+ try:
+ import pandas as pd
+ except ImportError:
+ raise ImportError(
+ "Pandas is not installed and is required to get dataframe as the return format. "
+ 'Please install the SDK using "pip install google-cloud-aiplatform[metadata]"'
+ )
+
+ if not self._backing_tensorboard_run:
+ return pd.DataFrame({})
+ data = self._backing_tensorboard_run.resource.read_time_series_data()
+
+ if not data:
+ return pd.DataFrame({})
+
+ return (
+ pd.DataFrame(
+ {
+ name: entry.scalar.value,
+ "step": entry.step,
+ "wall_time": entry.wall_time,
+ }
+ for name, ts in data.items()
+ for entry in ts.values
+ )
+ .groupby(["step", "wall_time"])
+ .first()
+ .reset_index()
+ )
+
+ @_v1_not_supported
+ def get_logged_pipeline_jobs(self) -> List[pipeline_jobs.PipelineJob]:
+ """Get all PipelineJobs associated to this experiment run.
+
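+ For example (names are illustrative):
+
+ ```
+ my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
+ jobs = my_run.get_logged_pipeline_jobs()
+ ```
+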
+ Returns:
+ List of PipelineJobs associated this run.
+ """
+
+ pipeline_job_contexts = self._get_logged_pipeline_runs()
+
+ return [
+ pipeline_jobs.PipelineJob.get(
+ c.display_name,
+ project=c.project,
+ location=c.location,
+ credentials=c.credentials,
+ )
+ for c in pipeline_job_contexts
+ ]
+
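+ # Context manager support (illustrative): the run ends automatically on exit
+ # and is marked FAILED if an exception escaped the block:
+ #
+ # with aiplatform.start_run('my-run') as my_run:
+ # my_run.log_params({'learning_rate': 0.1})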
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ state = (
+ gca_execution.Execution.State.FAILED
+ if exc_type
+ else gca_execution.Execution.State.COMPLETE
+ )
+
+ if metadata._experiment_tracker.experiment_run is self:
+ metadata._experiment_tracker.end_run(state=state)
+ else:
+ self.end_run(state=state)
+
+ def end_run(
+ self,
+ *,
+ state: gca_execution.Execution.State = gca_execution.Execution.State.COMPLETE,
+ ):
+ """Ends this experiment run and sets state to COMPLETE.
+
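+ For example (names are illustrative):
+
+ ```
+ my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
+ my_run.end_run(state=aiplatform.gapic.Execution.State.FAILED)
+ ```
+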
+ Args:
+ state (aiplatform.gapic.Execution.State):
+ Optional. Override the state at the end of run. Defaults to COMPLETE.
+ """
+ self.update_state(state)
+
+ def delete(self, *, delete_backing_tensorboard_run: bool = False):
+ """Deletes this experiment run.
+
+ Does not delete the executions, artifacts, or resources logged to this run.
+
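+ For example (names are illustrative):
+
+ ```
+ my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
+ my_run.delete(delete_backing_tensorboard_run=True)
+ ```
+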
+ Args:
+ delete_backing_tensorboard_run (bool):
+ Optional. Whether to delete the backing tensorboard run that stores time series metrics for this run.
+ """
+ if delete_backing_tensorboard_run:
+ if not self._is_legacy_experiment_run():
+ if not self._backing_tensorboard_run:
+ self._backing_tensorboard_run = (
+ self._lookup_tensorboard_run_artifact()
+ )
+ if self._backing_tensorboard_run:
+ self._backing_tensorboard_run.resource.delete()
+ self._backing_tensorboard_run.metadata.delete()
+ else:
+ _LOGGER.warn(
+ f"Experiment run {self.name} does not have a backing tensorboard run."
+ " Skipping deletion."
+ )
+ else:
+ _LOGGER.warn(
+ f"Experiment run {self.name} does not have a backing tensorboard run."
+ " Skipping deletion."
+ )
+
+ self._metadata_node.delete()
+
+ if self._is_legacy_experiment_run():
+ self._metadata_metric_artifact.delete()
+
+ @_v1_not_supported
+ def get_artifacts(self) -> List[artifact.Artifact]:
+ """Get the list of artifacts associated to this run.
+
+ Returns:
+ List of artifacts associated to this run.
+ """
+ return self._metadata_node.get_artifacts()
+
+ @_v1_not_supported
+ def get_executions(self) -> List[execution.Execution]:
+ """Get the List of Executions associated to this run
+
+ Returns:
+ List of executions associated to this run.
+ """
+ return self._metadata_node.get_executions()
+
+ def get_params(self) -> Dict[str, Union[int, float, str]]:
+ """Get the parameters logged to this run.
+
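+ For example (names are illustrative):
+
+ ```
+ my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
+ params = my_run.get_params()
+ ```
+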
+ Returns:
+ Parameters logged to this experiment run.
+ """
+ if self._is_legacy_experiment_run():
+ return self._metadata_node.metadata
+ else:
+ return self._metadata_node.metadata[constants._PARAM_KEY]
+
+ def get_metrics(self) -> Dict[str, Union[float, int, str]]:
+ """Get the summary metrics logged to this run.
+
+ Returns:
+ Summary metrics logged to this experiment run.
+ """
+ if self._is_legacy_experiment_run():
+ return self._metadata_metric_artifact.metadata
+ else:
+ return self._metadata_node.metadata[constants._METRIC_KEY]
+
+ @_v1_not_supported
+ def associate_execution(self, execution: execution.Execution):
+ """Associate an execution to this experiment run.
+
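+ A minimal sketch (assumes my_execution is an existing aiplatform.Execution):
+
+ ```
+ my_run = aiplatform.ExperimentRun('my-run', experiment='my-experiment')
+ my_run.associate_execution(my_execution)
+ ```
+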
+ Args:
+ execution (aiplatform.Execution): Execution to associate to this run.
+ """
+ self._metadata_node.add_artifacts_and_executions(
+ execution_resource_names=[execution.resource_name]
+ )
+
+ def _association_wrapper(self, f: Callable[..., Any]) -> Callable[..., Any]:
+ """Wraps methods and automatically associates all passed in Artifacts or Executions to this ExperimentRun.
+
+ This is used to wrap artifact passing methods of Executions so they get associated to this run.
+ """
+
+ @functools.wraps(f)
+ def wrapper(*args, **kwargs):
+ artifacts = []
+ executions = []
+ for value in [*args, *kwargs.values()]:
+ value = value if isinstance(value, collections.abc.Iterable) else [value]
+ for item in value:
+ if isinstance(item, execution.Execution):
+ executions.append(item)
+ elif isinstance(item, artifact.Artifact):
+ artifacts.append(item)
+ elif artifact._VertexResourceArtifactResolver.supports_metadata(
+ item
+ ):
+ artifacts.append(
+ artifact._VertexResourceArtifactResolver.resolve_or_create_resource_artifact(
+ item
+ )
+ )
+
+ if artifacts or executions:
+ self._metadata_node.add_artifacts_and_executions(
+ artifact_resource_names=[a.resource_name for a in artifacts],
+ execution_resource_names=[e.resource_name for e in executions],
+ )
+
+ result = f(*args, **kwargs)
+ return result
+
+ return wrapper
diff --git a/google/cloud/aiplatform/metadata/metadata.py b/google/cloud/aiplatform/metadata/metadata.py
index 919eff8619..f321a622b3 100644
--- a/google/cloud/aiplatform/metadata/metadata.py
+++ b/google/cloud/aiplatform/metadata/metadata.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2021 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,363 +15,634 @@
# limitations under the License.
#
-from typing import Dict, Union, Optional
+from typing import Dict, Union, Optional, Any
+
+from google.api_core import exceptions
+from google.auth import credentials as auth_credentials
+from google.protobuf import timestamp_pb2
+
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import gapic
+from google.cloud.aiplatform import pipeline_jobs
+from google.cloud.aiplatform.compat.types import execution as gca_execution
from google.cloud.aiplatform.metadata import constants
-from google.cloud.aiplatform.metadata.artifact import _Artifact
-from google.cloud.aiplatform.metadata.context import _Context
-from google.cloud.aiplatform.metadata.execution import _Execution
-from google.cloud.aiplatform.metadata.metadata_store import _MetadataStore
+from google.cloud.aiplatform.metadata import context
+from google.cloud.aiplatform.metadata import execution
+from google.cloud.aiplatform.metadata import experiment_resources
+from google.cloud.aiplatform.metadata import experiment_run_resource
+from google.cloud.aiplatform.tensorboard import tensorboard_resource
+
+_LOGGER = base.Logger(__name__)
+
+
+def _get_experiment_schema_version() -> str:
+ """Helper method to get experiment schema version
+ Returns:
+ str: schema version of the currently set experiment tracking version
+ """
+ return constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT]
-class _MetadataService:
+
+# Legacy Experiment tracking
+# Maintaining creation APIs for backwards compatibility testing
+class _LegacyExperimentService:
"""Contains the exposed APIs to interact with the Managed Metadata Service."""
+ @staticmethod
+ def get_pipeline_df(pipeline: str) -> "pd.DataFrame": # noqa: F821
+ """Returns a Pandas DataFrame of the parameters and metrics associated with one pipeline.
+
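+ A minimal sketch (assuming this is exposed as aiplatform.get_pipeline_df):
+
+ ```
+ df = aiplatform.get_pipeline_df(pipeline='my-pipeline')
+ ```
+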
+ Args:
+ pipeline: Name of the Pipeline to filter results.
+
+ Returns:
+ Pandas Dataframe of Pipeline with metrics and parameters.
+ """
+
+ source = "pipeline"
+ pipeline_resource_name = (
+ _LegacyExperimentService._get_experiment_or_pipeline_resource_name(
+ name=pipeline, source=source, expected_schema=constants.SYSTEM_PIPELINE
+ )
+ )
+
+ return _LegacyExperimentService._query_runs_to_data_frame(
+ context_id=pipeline,
+ context_resource_name=pipeline_resource_name,
+ source=source,
+ )
+
+ @staticmethod
+ def _get_experiment_or_pipeline_resource_name(
+ name: str, source: str, expected_schema: str
+ ) -> str:
+ """Get the full resource name of the Context representing an Experiment or Pipeline.
+
+ Args:
+ name (str):
+ Name of the Experiment or Pipeline.
+ source (str):
+ Identifies whether this is an Experiment or a Pipeline.
+ expected_schema (str):
+ expected_schema identifies the expected schema used for Experiment or Pipeline.
+
+ Returns:
+ The full resource name of the Experiment or Pipeline Context.
+
+ Raises:
+ NotFound: If the experiment or pipeline does not exist.
+ ValueError: If the named Context is not of the expected schema.
+ """
+
+ this_context = context._Context(resource_name=name)
+
+ if this_context.schema_title != expected_schema:
+ raise ValueError(
+ f"Please provide a valid {source} name. {name} is not a {source}."
+ )
+ return this_context.resource_name
+
+ @staticmethod
+ def _query_runs_to_data_frame(
+ context_id: str, context_resource_name: str, source: str
+ ) -> "pd.DataFrame": # noqa: F821
+ """Get metrics and parameters associated with a given Context into a Dataframe.
+
+ Args:
+ context_id (str):
+ Name of the Experiment or Pipeline.
+ context_resource_name (str):
+ Full resource name of the Context associated with an Experiment or Pipeline.
+ source (str):
+ Identifies whether this is an Experiment or a Pipeline.
+
+ Returns:
+ A Pandas DataFrame of the runs' parameters and metrics under the given Context.
+ """
+
+ try:
+ import pandas as pd
+ except ImportError:
+ raise ImportError(
+ "Pandas is not installed and is required to get dataframe as the return format. "
+ 'Please install the SDK using "pip install google-cloud-aiplatform[metadata]"'
+ )
+
+ filter = f'schema_title="{constants.SYSTEM_RUN}" AND in_context("{context_resource_name}")'
+ run_executions = execution.Execution.list(filter=filter)
+
+ context_summary = []
+ for run_execution in run_executions:
+ run_dict = {
+ f"{source}_name": context_id,
+ "run_name": run_execution.display_name,
+ }
+ run_dict.update(
+ _LegacyExperimentService._execution_to_column_named_metadata(
+ "param", run_execution.metadata
+ )
+ )
+
+ for metric_artifact in run_execution.get_output_artifacts():
+ run_dict.update(
+ _LegacyExperimentService._execution_to_column_named_metadata(
+ "metric", metric_artifact.metadata
+ )
+ )
+
+ context_summary.append(run_dict)
+
+ return pd.DataFrame(context_summary)
+
+ @staticmethod
+ def _execution_to_column_named_metadata(
+ metadata_type: str, metadata: Dict, filter_prefix: Optional[str] = None
+ ) -> Dict[str, Union[int, float, str]]:
+ """Returns a dict of the Execution/Artifact metadata with column names.
+
+ Args:
+ metadata_type: The type of this execution properties (param, metric).
+ metadata: Either an Execution or Artifact metadata field.
+ filter_prefix:
+ Remove this prefix from the key of metadata field. Mainly used for removing
+ "input:" from PipelineJob parameter keys
+
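+ For example (illustrative): metadata_type='param', metadata={'input:lr': 0.1}
+ and filter_prefix='input:' yield {'param.lr': 0.1}.
+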
+ Returns:
+ Dict of custom properties with keys mapped to column names
+ """
+ column_key_to_value = {}
+ for key, value in metadata.items():
+ if filter_prefix and key.startswith(filter_prefix):
+ key = key[len(filter_prefix) :]
+ column_key_to_value[".".join([metadata_type, key])] = value
+
+ return column_key_to_value
+
+
+class _ExperimentTracker:
+ """Tracks Experiments and Experiment Runs wil high level APIs"""
+
def __init__(self):
- self._experiment = None
- self._run = None
- self._metrics = None
+ self._experiment: Optional[experiment_resources.Experiment] = None
+ self._experiment_run: Optional[experiment_run_resource.ExperimentRun] = None
def reset(self):
- """Reset all _MetadataService fields to None"""
+ """Resets this experiment tracker, clearing the current experiment and run."""
self._experiment = None
- self._run = None
- self._metrics = None
+ self._experiment_run = None
@property
def experiment_name(self) -> Optional[str]:
- """Return the experiment name of the _MetadataService, if experiment is not set, return None"""
+ """Return the currently set experiment name, if experiment is not set, return None"""
if self._experiment:
- return self._experiment.display_name
+ return self._experiment.name
return None
@property
- def run_name(self) -> Optional[str]:
- """Return the run name of the _MetadataService, if run is not set, return None"""
- if self._run:
- return self._run.display_name
- return None
+ def experiment(self) -> Optional[experiment_resources.Experiment]:
+ "Returns the currently set Experiment."
+ return self._experiment
- def set_experiment(self, experiment: str, description: Optional[str] = None):
- """Setup a experiment to current session.
+ @property
+ def experiment_run(self) -> Optional[experiment_run_resource.ExperimentRun]:
+ """Returns the currently set experiment run."""
+ return self._experiment_run
+
+ def set_experiment(
+ self,
+ experiment: str,
+ *,
+ description: Optional[str] = None,
+ backing_tensorboard: Optional[
+ Union[str, tensorboard_resource.Tensorboard]
+ ] = None,
+ ):
+ """Set the experiment. Will retrieve the Experiment if it exists or create one with the provided name.
Args:
experiment (str):
- Required. Name of the experiment to assign current session with.
+ Required. Name of the experiment to set.
description (str):
Optional. Description of an experiment.
+ backing_tensorboard (Union[str, aiplatform.Tensorboard]):
+ Optional. If provided, assigns tensorboard as backing tensorboard to support time series metrics
+ logging.
"""
+ self.reset()
- _MetadataStore.get_or_create()
- context = _Context.get_or_create(
- resource_id=experiment,
- display_name=experiment,
- description=description,
- schema_title=constants.SYSTEM_EXPERIMENT,
- schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_EXPERIMENT],
- metadata=constants.EXPERIMENT_METADATA,
+ experiment = experiment_resources.Experiment.get_or_create(
+ experiment_name=experiment, description=description
)
- if context.schema_title != constants.SYSTEM_EXPERIMENT:
- raise ValueError(
- f"Experiment name {experiment} has been used to create other type of resources "
- f"({context.schema_title}) in this MetadataStore, please choose a different experiment name."
- )
- if description and context.description != description:
- context.update(metadata=context.metadata, description=description)
+ if backing_tensorboard:
+ experiment.assign_backing_tensorboard(tensorboard=backing_tensorboard)
+
+ self._experiment = experiment
+
+ def start_run(
+ self,
+ run: str,
+ *,
+ tensorboard: Union[tensorboard_resource.Tensorboard, str, None] = None,
+ resume=False,
+ ) -> experiment_run_resource.ExperimentRun:
+ """Start a run to current session.
+
+ ```
+ aiplatform.init(experiment='my-experiment')
+ aiplatform.start_run('my-run')
+ aiplatform.log_params({'learning_rate':0.1})
+ ```
+
+ Use as context manager. Run will be ended on context exit:
+ ```
+ aiplatform.init(experiment='my-experiment')
+ with aiplatform.start_run('my-run') as my_run:
+ my_run.log_params({'learning_rate':0.1})
+ ```
+
+ Resume a previously started run:
+ ```
+ aiplatform.init(experiment='my-experiment')
+ with aiplatform.start_run('my-run', resume=True) as my_run:
+ my_run.log_params({'learning_rate':0.1})
+ ```
- self._experiment = context
-
- def start_run(self, run: str):
- """Setup a run to current session.
Args:
- run (str):
+ run(str):
Required. Name of the run to assign current session with.
- Raise:
- ValueError if experiment is not set. Or if run execution or metrics artifact
- is already created but with a different schema.
+ tensorboard (Union[str, tensorboard_resource.Tensorboard]):
+ Optional. Backing Tensorboard Resource to enable and store time series metrics
+ logged to this Experiment Run using `log_time_series_metrics`.
+
+ If not provided, will use the default backing tensorboard of the currently
+ set experiment.
+ resume (bool):
+ Whether to resume this run. If False, a new run will be created.
+ Raises:
+ ValueError:
+ if experiment is not set. Or if run execution or metrics artifact is already created
+ but with a different schema.
"""
if not self._experiment:
raise ValueError(
"No experiment set for this run. Make sure to call aiplatform.init(experiment='my-experiment') "
- "before trying to start_run. "
+ "before invoking start_run. "
)
- run_execution_id = f"{self._experiment.name}-{run}"
- run_execution = _Execution.get_or_create(
- resource_id=run_execution_id,
- display_name=run,
- schema_title=constants.SYSTEM_RUN,
- schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN],
- )
- if run_execution.schema_title != constants.SYSTEM_RUN:
- raise ValueError(
- f"Run name {run} has been used to create other type of resources ({run_execution.schema_title}) "
- "in this MetadataStore, please choose a different run name."
+
+ if self._experiment_run:
+ self.end_run()
+
+ if resume:
+ self._experiment_run = experiment_run_resource.ExperimentRun(
+ run_name=run, experiment=self._experiment
)
- self._experiment.add_artifacts_and_executions(
- execution_resource_names=[run_execution.resource_name]
- )
+ if tensorboard:
+ self._experiment_run.assign_backing_tensorboard(tensorboard=tensorboard)
- metrics_artifact_id = f"{self._experiment.name}-{run}-metrics"
- metrics_artifact = _Artifact.get_or_create(
- resource_id=metrics_artifact_id,
- display_name=metrics_artifact_id,
- schema_title=constants.SYSTEM_METRICS,
- schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_METRICS],
- )
- if metrics_artifact.schema_title != constants.SYSTEM_METRICS:
- raise ValueError(
- f"Run name {run} has been used to create other type of resources ({metrics_artifact.schema_title}) "
- "in this MetadataStore, please choose a different run name."
+ self._experiment_run.update_state(state=gapic.Execution.State.RUNNING)
+
+ else:
+ self._experiment_run = experiment_run_resource.ExperimentRun.create(
+ run_name=run, experiment=self._experiment, tensorboard=tensorboard
)
- run_execution.add_artifact(
- artifact_resource_name=metrics_artifact.resource_name, input=False
- )
- self._run = run_execution
- self._metrics = metrics_artifact
+ return self._experiment_run
+
+ def end_run(self, state: gapic.Execution.State = gapic.Execution.State.COMPLETE):
+ """Ends the the current experiment run.
+
+ ```
+ aiplatform.start_run('my-run')
+ ...
+ aiplatform.end_run()
+ ```
+
+ Args:
+ state (aiplatform.gapic.Execution.State):
+ Optional. The state to set when ending the run. Defaults to COMPLETE.
+ """
+ self._validate_experiment_and_run(method_name="end_run")
+ try:
+ self._experiment_run.end_run(state=state)
+ except exceptions.NotFound:
+ _LOGGER.warn(
+ f"Experiment run {self._experiment_run.name} was not found."
+ "It may have been deleted"
+ )
+ finally:
+ self._experiment_run = None
def log_params(self, params: Dict[str, Union[float, int, str]]):
"""Log single or multiple parameters with specified key and value pairs.
+ Parameters with the same key will be overwritten.
+
+ ```
+ aiplatform.start_run('my-run')
+ aiplatform.log_params({'learning_rate': 0.1, 'dropout_rate': 0.2})
+ ```
+
Args:
- params (Dict):
+ params (Dict[str, Union[float, int, str]]):
Required. Parameter key/value pairs.
"""
self._validate_experiment_and_run(method_name="log_params")
# query the latest run execution resource before logging.
- execution = _Execution.get_or_create(
- resource_id=self._run.name,
- schema_title=constants.SYSTEM_RUN,
- schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_RUN],
- )
- execution.update(metadata=params)
+ self._experiment_run.log_params(params=params)
- def log_metrics(self, metrics: Dict[str, Union[float, int]]):
+ def log_metrics(self, metrics: Dict[str, Union[float, int, str]]):
"""Log single or multiple Metrics with specified key and value pairs.
+ Metrics with the same key will be overwritten.
+
+ ```
+ aiplatform.init(experiment='my-experiment')
+ aiplatform.start_run('my-run')
+ aiplatform.log_metrics({'accuracy': 0.9, 'recall': 0.8})
+ ```
+
Args:
- metrics (Dict):
- Required. Metrics key/value pairs. Only flot and int are supported format for value.
- Raises:
- TypeError if value contains unsupported types.
- ValueError if Experiment or Run is not set.
+ metrics (Dict[str, Union[float, int, str]]):
+ Required. Metrics key/value pairs.
"""
self._validate_experiment_and_run(method_name="log_metrics")
- self._validate_metrics_value_type(metrics)
# query the latest metrics artifact resource before logging.
- artifact = _Artifact.get_or_create(
- resource_id=self._metrics.name,
- schema_title=constants.SYSTEM_METRICS,
- schema_version=constants.SCHEMA_VERSIONS[constants.SYSTEM_METRICS],
- )
- artifact.update(metadata=metrics)
+ self._experiment_run.log_metrics(metrics=metrics)
- def get_experiment_df(
- self, experiment: Optional[str] = None
- ) -> "pd.DataFrame": # noqa: F821
- """Returns a Pandas DataFrame of the parameters and metrics associated with one experiment.
-
- Example:
-
- aiplatform.init(experiment='exp-1')
- aiplatform.start_run(run='run-1')
- aiplatform.log_params({'learning_rate': 0.1})
- aiplatform.log_metrics({'accuracy': 0.9})
+ def _validate_experiment_and_run(self, method_name: str):
+ """Validates Experiment and Run are set and raises informative error message.
- aiplatform.start_run(run='run-2')
- aiplatform.log_params({'learning_rate': 0.2})
- aiplatform.log_metrics({'accuracy': 0.95})
+ Args:
+ method_name: The name of the method to raise from.
- Will result in the following DataFrame
- ___________________________________________________________________________
- | experiment_name | run_name | param.learning_rate | metric.accuracy |
- ---------------------------------------------------------------------------
- | exp-1 | run-1 | 0.1 | 0.9 |
- | exp-1 | run-2 | 0.2 | 0.95 |
- ---------------------------------------------------------------------------
+ Raises:
+ ValueError: If Experiment or Run are not set.
+ """
- Args:
- experiment (str):
- Name of the Experiment to filter results. If not set, return results of current active experiment.
+ if not self._experiment:
+ raise ValueError(
+ f"No experiment set. Make sure to call aiplatform.init(experiment='my-experiment') "
+ f"before trying to {method_name}. "
+ )
+ if not self._experiment_run:
+ raise ValueError(
+ f"No run set. Make sure to call aiplatform.start_run('my-run') before trying to {method_name}. "
+ )
- Returns:
- Pandas Dataframe of Experiment with metrics and parameters.
+ def get_experiment_df(
+ self, experiment: Optional[str] = None
+ ) -> "pd.DataFrame": # noqa: F821
+ """Returns a Pandas DataFrame of the parameters and metrics associated with one experiment.
- Raise:
- NotFound exception if experiment does not exist.
- ValueError if given experiment is not associated with a wrong schema.
- """
+ Example:
- if not experiment:
- experiment = self._experiment.name
+ aiplatform.init(experiment='exp-1')
+ aiplatform.start_run(run='run-1')
+ aiplatform.log_params({'learning_rate': 0.1})
+ aiplatform.log_metrics({'accuracy': 0.9})
- source = "experiment"
- experiment_resource_name = self._get_experiment_or_pipeline_resource_name(
- name=experiment, source=source, expected_schema=constants.SYSTEM_EXPERIMENT,
- )
+ aiplatform.start_run(run='run-2')
+ aiplatform.log_params({'learning_rate': 0.2})
+ aiplatform.log_metrics({'accuracy': 0.95})
- return self._query_runs_to_data_frame(
- context_id=experiment,
- context_resource_name=experiment_resource_name,
- source=source,
- )
+ aiplatform.get_experiment_df()
- def get_pipeline_df(self, pipeline: str) -> "pd.DataFrame": # noqa: F821
- """Returns a Pandas DataFrame of the parameters and metrics associated with one pipeline.
+ Will result in the following DataFrame
+ ___________________________________________________________________________
+ | experiment_name | run_name | param.learning_rate | metric.accuracy |
+ ---------------------------------------------------------------------------
+ | exp-1 | run-1 | 0.1 | 0.9 |
+ | exp-1 | run-2 | 0.2 | 0.95 |
+ ---------------------------------------------------------------------------
Args:
- pipeline: Name of the Pipeline to filter results.
+ experiment (str):
+ Name of the Experiment to filter results. If not set, return results of current active experiment.
Returns:
- Pandas Dataframe of Pipeline with metrics and parameters.
+ Pandas Dataframe of Experiment with metrics and parameters.
Raise:
NotFound exception if experiment does not exist.
ValueError if given experiment is not associated with a wrong schema.
"""
- source = "pipeline"
- pipeline_resource_name = self._get_experiment_or_pipeline_resource_name(
- name=pipeline, source=source, expected_schema=constants.SYSTEM_PIPELINE
- )
-
- return self._query_runs_to_data_frame(
- context_id=pipeline,
- context_resource_name=pipeline_resource_name,
- source=source,
- )
-
- def _validate_experiment_and_run(self, method_name: str):
- if not self._experiment:
- raise ValueError(
- f"No experiment set. Make sure to call aiplatform.init(experiment='my-experiment') "
- f"before trying to {method_name}. "
- )
- if not self._run:
- raise ValueError(
- f"No run set. Make sure to call aiplatform.start_run('my-run') before trying to {method_name}. "
- )
-
- @staticmethod
- def _validate_metrics_value_type(metrics: Dict[str, Union[float, int]]):
- """Verify that metrics value are with supported types.
+ if not experiment:
+ experiment = self._experiment
+ else:
+ experiment = experiment_resources.Experiment(experiment)
+
+ return experiment.get_data_frame()
+
+ def log(
+ self,
+ *,
+ pipeline_job: Optional[pipeline_jobs.PipelineJob] = None,
+ ):
+ """Log Vertex AI Resources to the current experiment run.
+
+ ```
+ aiplatform.start_run('my-run')
+ my_job = aiplatform.PipelineJob(...)
+ my_job.submit()
+ aiplatform.log(pipeline_job=my_job)
+ ```
Args:
- metrics (Dict):
- Required. Metrics key/value pairs. Only flot and int are supported format for value.
- Raises:
- TypeError if value contains unsupported types.
+ pipeline_job (pipeline_jobs.PipelineJob):
+ Optional. Vertex PipelineJob to associate to this Experiment Run.
"""
+ self._validate_experiment_and_run(method_name="log")
+ self._experiment_run.log(pipeline_job=pipeline_job)
- for key, value in metrics.items():
- if isinstance(value, int) or isinstance(value, float):
- continue
- raise TypeError(
- f"metrics contain unsupported value types. key: {key}; value: {value}; type: {type(value)}"
- )
+ def log_time_series_metrics(
+ self,
+ metrics: Dict[str, float],
+ step: Optional[int] = None,
+ wall_time: Optional[timestamp_pb2.Timestamp] = None,
+ ):
+ """Logs time series metrics to to this Experiment Run.
- @staticmethod
- def _get_experiment_or_pipeline_resource_name(
- name: str, source: str, expected_schema: str
- ) -> str:
- """Get the full resource name of the Context representing an Experiment or Pipeline.
+ Requires the experiment or experiment run to have a backing Vertex Tensorboard resource.
- Args:
- name (str):
- Name of the Experiment or Pipeline.
- source (str):
- Identify whether the this is an Experiment or a Pipeline.
- expected_schema (str):
- expected_schema identifies the expected schema used for Experiment or Pipeline.
+ ```
+ my_tensorboard = aiplatform.Tensorboard(...)
+ aiplatform.init(experiment='my-experiment', experiment_tensorboard=my_tensorboard)
+ aiplatform.start_run('my-run')
- Returns:
- The full resource name of the Experiment or Pipeline Context.
+ # increments steps as logged
+ for i in range(10):
+ aiplatform.log_time_series_metrics({'loss': loss})
- Raise:
- NotFound exception if experiment or pipeline does not exist.
- """
+ # explicitly log steps
+ for i in range(10):
+ aiplatform.log_time_series_metrics({'loss': loss}, step=i)
+ ```
- context = _Context(resource_name=name)
+ Args:
+ metrics (Dict[str, float]):
+ Required. Dictionary where keys are metric names and values are metric values.
+ step (int):
+ Optional. Step index of this data point within the run.
- if context.schema_title != expected_schema:
- raise ValueError(
- f"Please provide a valid {source} name. {name} is not a {source}."
- )
- return context.resource_name
+ If not provided, the latest
+ step amongst all time series metrics already logged will be used.
+ wall_time (timestamp_pb2.Timestamp):
+ Optional. Wall clock timestamp when this data point is
+ generated by the end user.
- def _query_runs_to_data_frame(
- self, context_id: str, context_resource_name: str, source: str
- ) -> "pd.DataFrame": # noqa: F821
- """Get metrics and parameters associated with a given Context into a Dataframe.
+ If not provided, this will be generated based on the value from time.time()
+
+ Raises:
+ RuntimeError: If current experiment run doesn't have a backing Tensorboard resource.
+ """
+ self._validate_experiment_and_run(method_name="log_time_series_metrics")
+ self._experiment_run.log_time_series_metrics(
+ metrics=metrics, step=step, wall_time=wall_time
+ )
+ def start_execution(
+ self,
+ *,
+ schema_title: Optional[str] = None,
+ display_name: Optional[str] = None,
+ resource_id: Optional[str] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+ schema_version: Optional[str] = None,
+ description: Optional[str] = None,
+ resume: bool = False,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> execution.Execution:
+ """
+ Creates and starts a new Metadata Execution or resumes a previously created Execution.
+
+ To start a new execution:
+
+ ```
+ with aiplatform.start_execution(schema_title='system.ContainerExecution', display_name='trainer') as exc:
+ exc.assign_input_artifacts([my_artifact])
+ model = aiplatform.Artifact.create(uri='gs://my-uri', schema_title='system.Model')
+ exc.assign_output_artifacts([model])
+ ```
+
+ To continue a previously created execution:
+ ```
+ with aiplatform.start_execution(resource_id='my-exc', resume=True) as exc:
+ ...
+ ```
Args:
- context_id (str):
- Name of the Experiment or Pipeline.
- context_resource_name (str):
- Full resource name of the Context associated with an Experiment or Pipeline.
- source (str):
- Identify whether the this is an Experiment or a Pipeline.
+ schema_title (str):
+ Optional. schema_title identifies the schema title used by the Execution. Required if starting
+ a new Execution.
+ resource_id (str):
+ Optional. The portion of the Execution name with
+ the following format. This is globally unique in a metadataStore:
+ projects/123/locations/us-central1/metadataStores//executions/.
+ display_name (str):
+ Optional. The user-defined name of the Execution.
+ schema_version (str):
+ Optional. schema_version specifies the version used by the Execution.
+ If not set, defaults to use the latest version.
+ metadata (Dict):
+ Optional. Contains the metadata information that will be stored in the Execution.
+ description (str):
+ Optional. Describes the purpose of the Execution to be created.
+ project (str):
+ Optional. Project used to create this Execution. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location used to create this Execution. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials used to create this Execution. Overrides
+ credentials set in aiplatform.init.
Returns:
- The full resource name of the Experiment or Pipeline Context.
- """
+ Execution: Instantiated representation of the managed Metadata Execution.
- filter = f'schema_title="{constants.SYSTEM_RUN}" AND in_context("{context_resource_name}")'
- run_executions = _Execution.list(filter=filter)
+ Raises:
+ ValueError: If experiment run is set and project or location do not match experiment run.
+ ValueError: If resume set to `True` and resource_id is not provided.
+ ValueError: If creating a new execution and schema_title is not provided.
+ """
- context_summary = []
- for run_execution in run_executions:
- run_dict = {
- f"{source}_name": context_id,
- "run_name": run_execution.display_name,
- }
- run_dict.update(
- self._execution_to_column_named_metadata(
- "param", run_execution.metadata
+ if (
+ self._experiment_run
+ and not self._experiment_run._is_legacy_experiment_run()
+ ):
+ if project and project != self._experiment_run.project:
+ raise ValueError(
+ f"Currently set Experiment run project {self._experiment_run.project} must"
+ f"match provided project {project}"
)
- )
-
- for metric_artifact in run_execution.query_input_and_output_artifacts():
- run_dict.update(
- self._execution_to_column_named_metadata(
- "metric", metric_artifact.metadata
- )
+ if location and location != self._experiment_run.location:
+ raise ValueError(
+ f"Currently set Experiment run location {self._experiment_run.location} must"
+ f"match provided location {project}"
)
- context_summary.append(run_dict)
+ if resume:
+ if not resource_id:
+ raise ValueError("resource_id is required when resume=True")
- try:
- import pandas as pd
- except ImportError:
- raise ImportError(
- "Pandas is not installed and is required to get dataframe as the return format. "
- 'Please install the SDK using "pip install python-aiplatform[full]"'
+ run_execution = execution.Execution(
+ execution_name=resource_id,
+ project=project,
+ location=location,
+ credentials=credentials,
)
- return pd.DataFrame(context_summary)
+ # TODO(handle updates if resuming)
- @staticmethod
- def _execution_to_column_named_metadata(
- metadata_type: str, metadata: Dict,
- ) -> Dict[str, Union[int, float, str]]:
- """Returns a dict of the Execution/Artifact metadata with column names.
+ run_execution.update(state=gca_execution.Execution.State.RUNNING)
+ else:
+ if not schema_title:
+ raise ValueError(
+ "schema_title must be provided when starting a new Execution"
+ )
- Args:
- metadata_type: The type of this execution properties (param, metric).
- metadata: Either an Execution or Artifact metadata field.
+ run_execution = execution.Execution.create(
+ display_name=display_name,
+ schema_title=schema_title,
+ schema_version=schema_version,
+ metadata=metadata,
+ description=description,
+ resource_id=resource_id,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
- Returns:
- Dict of custom properties with keys mapped to column names
- """
+ if self.experiment_run:
+ if self.experiment_run._is_legacy_experiment_run():
+ _LOGGER.warn(
+ f"{self.experiment_run._run_name} is an Experiment run created in Vertex Experiment Preview"
+ " and does not support tracking Executions."
+ " Please create a new Experiment run to track executions against an Experiment run."
+ )
+ else:
+ self.experiment_run.associate_execution(run_execution)
+ run_execution.assign_input_artifacts = (
+ self.experiment_run._association_wrapper(
+ run_execution.assign_input_artifacts
+ )
+ )
+ run_execution.assign_output_artifacts = (
+ self.experiment_run._association_wrapper(
+ run_execution.assign_output_artifacts
+ )
+ )
- return {
- ".".join([metadata_type, key]): value for key, value in metadata.items()
- }
+ return run_execution
-metadata_service = _MetadataService()
+_experiment_tracker = _ExperimentTracker()
diff --git a/google/cloud/aiplatform/metadata/metadata_store.py b/google/cloud/aiplatform/metadata/metadata_store.py
index 494d31aca4..2f0c8e2955 100644
--- a/google/cloud/aiplatform/metadata/metadata_store.py
+++ b/google/cloud/aiplatform/metadata/metadata_store.py
@@ -24,7 +24,7 @@
from google.cloud.aiplatform import base, initializer
from google.cloud.aiplatform import compat
from google.cloud.aiplatform import utils
-from google.cloud.aiplatform_v1beta1.types import metadata_store as gca_metadata_store
+from google.cloud.aiplatform.compat.types import metadata_store as gca_metadata_store
class _MetadataStore(base.VertexAiResourceNounWithFutureManager):
@@ -35,6 +35,8 @@ class _MetadataStore(base.VertexAiResourceNounWithFutureManager):
_resource_noun = "metadataStores"
_getter_method = "get_metadata_store"
_delete_method = "delete_metadata_store"
+ _parse_resource_name_method = "parse_metadata_store_path"
+ _format_resource_name_method = "metadata_store_path"
def __init__(
self,
@@ -64,7 +66,9 @@ def __init__(
"""
super().__init__(
- project=project, location=location, credentials=credentials,
+ project=project,
+ location=location,
+ credentials=credentials,
)
self._gca_resource = self._get_gca_resource(resource_name=metadata_store_name)
@@ -77,7 +81,7 @@ def get_or_create(
credentials: Optional[auth_credentials.Credentials] = None,
encryption_spec_key_name: Optional[str] = None,
) -> "_MetadataStore":
- """"Retrieves or Creates (if it does not exist) a Metadata Store.
+ """ "Retrieves or Creates (if it does not exist) a Metadata Store.
Args:
metadata_store_id (str):
@@ -176,7 +180,7 @@ def _create(
gapic_metadata_store = gca_metadata_store.MetadataStore(
encryption_spec=initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name,
- select_version=compat.V1BETA1,
+ select_version=compat.DEFAULT_VERSION,
)
)
@@ -205,7 +209,7 @@ def _get(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
- ) -> "Optional[_MetadataStore]":
+ ) -> Optional["_MetadataStore"]:
"""Returns a MetadataStore resource.
Args:
@@ -238,3 +242,43 @@ def _get(
)
except exceptions.NotFound:
logging.info(f"MetadataStore {metadata_store_name} not found.")
+
+ @classmethod
+ def ensure_default_metadata_store_exists(
+ cls,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ encryption_key_spec_name: Optional[str] = None,
+ ):
+ """Helpers method to ensure the `default` MetadataStore exists in this project and location.
+
+ Args:
+ project (str):
+ Optional. Project to retrieve resource from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve resource from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to upload this model. Overrides
+ credentials set in aiplatform.init.
+ encryption_key_spec_name (str):
+ Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the metadata store. Has the
+ form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+ The key needs to be in the same region as where the compute
+ resource is created.
+
+ If set, this MetadataStore and all sub-resources of this MetadataStore will be secured by this key.
+
+ Overrides encryption_spec_key_name set in aiplatform.init.
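+
+ Usage (a minimal sketch; project and location values are illustrative):
+
+ _MetadataStore.ensure_default_metadata_store_exists(
+ project='my-project', location='us-central1')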
+ """
+
+ cls.get_or_create(
+ project=project,
+ location=location,
+ credentials=credentials,
+ encryption_spec_key_name=encryption_key_spec_name,
+ )
diff --git a/google/cloud/aiplatform/metadata/resource.py b/google/cloud/aiplatform/metadata/resource.py
index 85ac419d40..89c145dcbe 100644
--- a/google/cloud/aiplatform/metadata/resource.py
+++ b/google/cloud/aiplatform/metadata/resource.py
@@ -16,34 +16,37 @@
#
import abc
-import logging
+import collections
import re
from copy import deepcopy
-from typing import Optional, Dict, Union, Sequence
+from typing import Dict, Optional, Union, Any, List
import proto
from google.api_core import exceptions
from google.auth import credentials as auth_credentials
-from google.protobuf import json_format
from google.cloud.aiplatform import base, initializer
+from google.cloud.aiplatform import metadata
from google.cloud.aiplatform import utils
-from google.cloud.aiplatform_v1beta1 import Artifact as GapicArtifact
-from google.cloud.aiplatform_v1beta1 import Context as GapicContext
-from google.cloud.aiplatform_v1beta1 import Execution as GapicExecution
+from google.cloud.aiplatform.compat.types import artifact as gca_artifact
+from google.cloud.aiplatform.compat.types import context as gca_context
+from google.cloud.aiplatform.compat.types import execution as gca_execution
+
+_LOGGER = base.Logger(__name__)
class _Resource(base.VertexAiResourceNounWithFutureManager, abc.ABC):
"""Metadata Resource for Vertex AI"""
client_class = utils.MetadataClientWithOverride
- _is_client_prediction_client = False
_delete_method = None
def __init__(
self,
resource_name: Optional[str] = None,
- resource: Optional[Union[GapicContext, GapicArtifact, GapicExecution]] = None,
+ resource: Optional[
+ Union[gca_context.Context, gca_artifact.Artifact, gca_execution.Execution]
+ ] = None,
metadata_store_id: str = "default",
project: Optional[str] = None,
location: Optional[str] = None,
@@ -57,7 +60,7 @@ def __init__(
Example: "projects/123/locations/us-central1/metadataStores/default//my-resource".
or "my-resource" when project and location are initialized or passed. if ``resource`` is provided, this
should not be set.
- resource (Union[GapicContext, GapicArtifact, GapicExecution]):
+ resource (Union[gca_context.Context, gca_artifact.Artifact, gca_execution.Execution]):
The proto.Message that contains the full information of the resource. If both set, this field overrides
``resource_name`` field.
metadata_store_id (str):
@@ -70,35 +73,38 @@ def __init__(
Optional location to retrieve the resource from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
- Custom credentials to use to upload this model. Overrides
+ Custom credentials to use to retrieve this resource. Overrides
credentials set in aiplatform.init.
"""
super().__init__(
- project=project, location=location, credentials=credentials,
+ project=project,
+ location=location,
+ credentials=credentials,
)
if resource:
self._gca_resource = resource
- return
-
- full_resource_name = resource_name
- # Construct the full_resource_name if input resource_name is the resource_id
- if "/" not in resource_name:
+ else:
full_resource_name = utils.full_resource_name(
resource_name=resource_name,
- resource_noun=f"metadataStores/{metadata_store_id}/{self._resource_noun}",
+ resource_noun=self._resource_noun,
+ parse_resource_name_method=self._parse_resource_name,
+ format_resource_name_method=self._format_resource_name,
+ parent_resource_name_fields={
+ metadata.metadata_store._MetadataStore._resource_noun: metadata_store_id
+ },
project=self.project,
location=self.location,
)
- self._gca_resource = getattr(self.api_client, self._getter_method)(
- name=full_resource_name
- )
+ self._gca_resource = getattr(self.api_client, self._getter_method)(
+ name=full_resource_name, retry=base._DEFAULT_RETRY
+ )
@property
def metadata(self) -> Dict:
- return json_format.MessageToDict(self._gca_resource._pb)["metadata"]
+ return self.to_dict()["metadata"]
@property
def schema_title(self) -> str:
@@ -168,7 +174,7 @@ def get_or_create(
credentials=credentials,
)
if not resource:
- logging.info(f"Creating Resource {resource_id}")
+ _LOGGER.info(f"Creating Resource {resource_id}")
resource = cls._create(
resource_id=resource_id,
schema_title=schema_title,
@@ -183,9 +189,45 @@ def get_or_create(
)
return resource
+ def sync_resource(self):
+ """Syncs local resource with the resource in metadata store."""
+ self._gca_resource = getattr(self.api_client, self._getter_method)(
+ name=self.resource_name, retry=base._DEFAULT_RETRY
+ )
+
+ @staticmethod
+ def _nested_update_metadata(
+ gca_resource: Union[
+ gca_context.Context, gca_execution.Execution, gca_artifact.Artifact
+ ],
+ metadata: Optional[Dict[str, Any]] = None,
+ ):
+ """Helper method to update gca_resource in place.
+
+ Performs a one-level deep nested update on the metadata field.
+
+ Args:
+ gca_resource (Union[gca_context.Context, gca_execution.Execution, gca_artifact.Artifact]):
+ Required. Metadata Protobuf resource. This proto's metadata will be
+ updated in place.
+ metadata (Dict[str, Any]):
+ Optional. Metadata dictionary to merge into gca_resource.metadata.
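+
+ For example, merging {'config': {'lr': 0.1}} into an existing
+ metadata value of {'config': {'epochs': 5}} yields
+ {'config': {'lr': 0.1, 'epochs': 5}}; non-mapping values are
+ overwritten rather than merged.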
+ """
+
+ if metadata:
+ if gca_resource.metadata:
+ for key, value in metadata.items():
+ # Note: This only supports nested dictionaries one level deep
+ if isinstance(value, collections.abc.Mapping):
+ gca_resource.metadata[key].update(value)
+ else:
+ gca_resource.metadata[key] = value
+ else:
+ gca_resource.metadata = metadata
+
def update(
self,
- metadata: Dict,
+ metadata: Optional[Dict] = None,
description: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
):
@@ -193,27 +235,26 @@ def update(
Args:
metadata (Dict):
- Required. metadata contains the updated metadata information.
+ Optional. metadata contains the updated metadata information.
description (str):
Optional. Description describes the resource to be updated.
credentials (auth_credentials.Credentials):
Custom credentials to use to update this resource. Overrides
credentials set in aiplatform.init.
-
"""
gca_resource = deepcopy(self._gca_resource)
- if gca_resource.metadata:
- gca_resource.metadata.update(metadata)
- else:
- gca_resource.metadata = metadata
+ if metadata:
+ self._nested_update_metadata(gca_resource=gca_resource, metadata=metadata)
if description:
gca_resource.description = description
api_client = self._instantiate_client(credentials=credentials)
+ # TODO: if etag is not valid sync and retry
update_gca_resource = self._update_resource(
- client=api_client, resource=gca_resource,
+ client=api_client,
+ resource=gca_resource,
)
self._gca_resource = update_gca_resource
@@ -225,7 +266,7 @@ def list(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
- ) -> Sequence["_Resource"]:
+ ) -> List["_Resource"]:
"""List Metadata resources that match the list filter in target metadataStore.
Args:
@@ -252,8 +293,6 @@ def list(
a list of managed Metadata resource.
"""
- api_client = cls._instantiate_client(location=location, credentials=credentials)
-
parent = (
initializer.global_config.common_location_path(
project=project, location=location
@@ -261,25 +300,13 @@ def list(
+ f"/metadataStores/{metadata_store_id}"
)
- try:
- resources = cls._list_resources(
- client=api_client, parent=parent, filter=filter,
- )
- except exceptions.NotFound:
- logging.info(
- f"No matching resources in metadataStore: {metadata_store_id} with filter: {filter}"
- )
- return []
-
- return [
- cls(
- resource=resource,
- project=project,
- location=location,
- credentials=credentials,
- )
- for resource in resources
- ]
+ return super().list(
+ filter=filter,
+ project=project,
+ location=location,
+ credentials=credentials,
+ parent=parent,
+ )
@classmethod
def _create(
@@ -294,7 +321,7 @@ def _create(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
- ):
+ ) -> Optional["_Resource"]:
"""Creates a new Metadata resource.
Args:
@@ -354,16 +381,19 @@ def _create(
metadata=metadata,
)
except exceptions.AlreadyExists:
- logging.info(f"Resource '{resource_id}' already exist")
+ _LOGGER.info(f"Resource '{resource_id}' already exist")
return
- return cls(
- resource=resource,
+ self = cls._empty_constructor(
project=project,
location=location,
credentials=credentials,
)
+ self._gca_resource = resource
+
+ return self
+
@classmethod
def _get(
cls,
@@ -403,14 +433,14 @@ def _get(
try:
return cls(
- resource_name=resource_name,
+ resource_name,
metadata_store_id=metadata_store_id,
project=project,
location=location,
credentials=credentials,
)
except exceptions.NotFound:
- logging.info(f"Resource {resource_name} not found.")
+ _LOGGER.info(f"Resource {resource_name} not found.")
@classmethod
@abc.abstractmethod
@@ -431,7 +461,9 @@ def _create_resource(
@classmethod
@abc.abstractmethod
def _update_resource(
- cls, client: utils.MetadataClientWithOverride, resource: proto.Message,
+ cls,
+ client: utils.MetadataClientWithOverride,
+ resource: proto.Message,
) -> proto.Message:
"""Update resource method."""
pass
@@ -450,7 +482,7 @@ def _extract_metadata_store_id(resource_name, resource_noun) -> str:
metadata_store_id (str):
The metadata store id for the particular resource name.
Raises:
- ValueError if it does not exist.
+ ValueError: If it does not exist.
"""
pattern = re.compile(
r"^projects\/(?P[\w-]+)\/locations\/(?P[\w-]+)\/metadataStores\/(?P[\w-]+)\/"
diff --git a/google/cloud/aiplatform/metadata/schema/base_artifact.py b/google/cloud/aiplatform/metadata/schema/base_artifact.py
new file mode 100644
index 0000000000..c89d989edd
--- /dev/null
+++ b/google/cloud/aiplatform/metadata/schema/base_artifact.py
@@ -0,0 +1,126 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import abc
+
+from typing import Optional, Dict
+
+from google.auth import credentials as auth_credentials
+
+from google.cloud.aiplatform.compat.types import artifact as gca_artifact
+from google.cloud.aiplatform.metadata import artifact
+from google.cloud.aiplatform.metadata import constants
+
+
+class BaseArtifactSchema(metaclass=abc.ABCMeta):
+ """Base class for Metadata Artifact types."""
+
+ @property
+ @classmethod
+ @abc.abstractmethod
+ def schema_title(cls) -> str:
+ """Identifies the Vertex Metadta schema title used by the resource."""
+ pass
+
+ def __init__(
+ self,
+ *,
+ artifact_id: Optional[str] = None,
+ uri: Optional[str] = None,
+ display_name: Optional[str] = None,
+ schema_version: Optional[str] = None,
+ description: Optional[str] = None,
+ metadata: Optional[Dict] = None,
+ state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
+ ):
+
+ """Initializes the Artifact with the given name, URI and metadata.
+
+ This is the base class for defining various artifact types, which can be
+ passed to google.Artifact to create a corresponding resource.
+ Artifacts carry a `metadata` field, which is a dictionary for storing
+ metadata related to this artifact. Subclasses of BaseArtifactSchema can enforce
+ various structure and field requirements for the metadata field.
+
+ Args:
+ artifact_id (str):
+ Optional. The portion of the Artifact name with
+ the following format, this is globally unique in a metadataStore:
+ projects/123/locations/us-central1/metadataStores//artifacts/.
+ uri (str):
+ Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
+ artifact file.
+ display_name (str):
+ Optional. The user-defined name of the Artifact.
+ schema_version (str):
+ Optional. schema_version specifies the version used by the Artifact.
+ If not set, defaults to use the latest version.
+ description (str):
+ Optional. Describes the purpose of the Artifact to be created.
+ metadata (Dict):
+ Optional. Contains the metadata information that will be stored in the Artifact.
+ state (google.cloud.gapic.types.Artifact.State):
+ Optional. The state of this Artifact. This is a
+ property of the Artifact, and does not imply or
+ capture any ongoing process. This property is
+ managed by clients (such as Vertex AI
+ Pipelines), and the system does not prescribe or
+ check the validity of state transitions.
+ """
+ self.artifact_id = artifact_id
+ self.uri = uri
+ self.display_name = display_name
+ self.schema_version = schema_version or constants._DEFAULT_SCHEMA_VERSION
+ self.description = description
+ self.metadata = metadata
+ self.state = state
+
+ def create(
+ self,
+ *,
+ metadata_store_id: Optional[str] = "default",
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> "artifact.Artifact":
+ """Creates a new Metadata Artifact.
+
+ Args:
+ metadata_store_id (str):
+ Optional. The portion of the resource name with
+ the format:
+ projects/123/locations/us-central1/metadataStores//artifacts/
+ If not provided, the MetadataStore's ID will be set to "default".
+ project (str):
+ Optional. Project used to create this Artifact. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location used to create this Artifact. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials used to create this Artifact. Overrides
+ credentials set in aiplatform.init.
+ Returns:
+ Artifact: Instantiated representation of the managed Metadata Artifact.
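+
+ Usage (a minimal sketch; assumes a concrete subclass such as the
+ system.Model schema defined later in this change):
+
+ model_schema = Model(uri='gs://my-bucket/my-model')
+ my_artifact = model_schema.create()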
+ """
+ return artifact.Artifact.create_from_base_artifact_schema(
+ base_artifact_schema=self,
+ metadata_store_id=metadata_store_id,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
diff --git a/google/cloud/aiplatform/metadata/schema/base_execution.py b/google/cloud/aiplatform/metadata/schema/base_execution.py
new file mode 100644
index 0000000000..811b7d9791
--- /dev/null
+++ b/google/cloud/aiplatform/metadata/schema/base_execution.py
@@ -0,0 +1,114 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import abc
+
+from typing import Optional, Dict
+
+from google.auth import credentials as auth_credentials
+
+from google.cloud.aiplatform.compat.types import execution as gca_execution
+from google.cloud.aiplatform.metadata import constants
+from google.cloud.aiplatform.metadata import execution
+
+
+class BaseExecutionSchema(metaclass=abc.ABCMeta):
+ """Base class for Metadata Execution schema."""
+
+ @property
+ @classmethod
+ @abc.abstractmethod
+ def schema_title(cls) -> str:
+ """Identifies the Vertex Metadta schema title used by the resource."""
+ pass
+
+ def __init__(
+ self,
+ *,
+ state: Optional[
+ gca_execution.Execution.State
+ ] = gca_execution.Execution.State.RUNNING,
+ execution_id: Optional[str] = None,
+ display_name: Optional[str] = None,
+ schema_version: Optional[str] = None,
+ metadata: Optional[Dict] = None,
+ description: Optional[str] = None,
+ ):
+
+ """Initializes the Execution with the given name, URI and metadata.
+
+ Args:
+ state (gca_execution.Execution.State):
+ Optional. State of this Execution. Defaults to RUNNING.
+ execution_id (str):
+ Optional. The portion of the Execution name with
+ the following format, this is globally unique in a metadataStore:
+ projects/123/locations/us-central1/metadataStores//executions/.
+ display_name (str):
+ Optional. The user-defined name of the Execution.
+ schema_version (str):
+ Optional. schema_version specifies the version used by the Execution.
+ If not set, defaults to use the latest version.
+ metadata (Dict):
+ Optional. Contains the metadata information that will be stored in the Execution.
+ description (str):
+ Optional. Describes the purpose of the Execution to be created.
+ """
+ self.state = state
+ self.execution_id = execution_id
+ self.display_name = display_name
+ self.schema_version = schema_version or constants._DEFAULT_SCHEMA_VERSION
+ self.metadata = metadata
+ self.description = description
+
+ def create(
+ self,
+ *,
+ metadata_store_id: Optional[str] = "default",
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> "execution.Execution":
+ """Creates a new Metadata Execution.
+
+ Args:
+ metadata_store_id (str):
+ Optional. The portion of the resource name with
+ the format:
+ projects/123/locations/us-central1/metadataStores//executions/
+ If not provided, the MetadataStore's ID will be set to "default".
+ project (str):
+ Optional. Project used to create this Execution. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location used to create this Execution. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials used to create this Execution. Overrides
+ credentials set in aiplatform.init.
+ Returns:
+ Execution: Instantiated representation of the managed Metadata Execution.
+
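+ Usage (a minimal sketch; assumes the ContainerExecution subclass
+ defined later in this change):
+
+ exec_schema = ContainerExecution(display_name='trainer')
+ my_execution = exec_schema.create()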
+ """
+ self.execution = execution.Execution.create_from_base_execution_schema(
+ base_execution_schema=self,
+ metadata_store_id=metadata_store_id,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+ return self.execution
diff --git a/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py b/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py
new file mode 100644
index 0000000000..99e0fb0ba6
--- /dev/null
+++ b/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py
@@ -0,0 +1,270 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from typing import Optional, Dict
+
+from google.cloud.aiplatform.compat.types import artifact as gca_artifact
+from google.cloud.aiplatform.metadata.schema import base_artifact
+from google.cloud.aiplatform.metadata.schema import utils
+
+# The artifact property key for the resource_name
+_ARTIFACT_PROPERTY_KEY_RESOURCE_NAME = "resourceName"
+
+
+class VertexDataset(base_artifact.BaseArtifactSchema):
+ """An artifact representing a Vertex Dataset."""
+
+ schema_title = "google.VertexDataset"
+
+ def __init__(
+ self,
+ *,
+ vertex_dataset_name: str,
+ artifact_id: Optional[str] = None,
+ display_name: Optional[str] = None,
+ schema_version: Optional[str] = None,
+ description: Optional[str] = None,
+ metadata: Optional[Dict] = None,
+ state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
+ ):
+ """Args:
+ vertex_dataset_name (str):
+ The name of the Dataset resource, in a form of
+ projects/{project}/locations/{location}/datasets/{dataset}. For
+ more details, see
+ https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.datasets/get
+ This is used to generate the resource uri as follows:
+ https://{service-endpoint}/v1/{dataset_name},
+ where {service-endpoint} is one of the supported service endpoints at
+ https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints
+ artifact_id (str):
+ Optional. The portion of the Artifact name with
+ the format. This is globally unique in a metadataStore:
+ projects/123/locations/us-central1/metadataStores//artifacts/.
+ display_name (str):
+ Optional. The user-defined name of the Artifact.
+ schema_version (str):
+ Optional. schema_version specifies the version used by the Artifact.
+ If not set, defaults to use the latest version.
+ description (str):
+ Optional. Describes the purpose of the Artifact to be created.
+ metadata (Dict):
+ Optional. Contains the metadata information that will be stored in the Artifact.
+ state (google.cloud.gapic.types.Artifact.State):
+ Optional. The state of this Artifact. This is a
+ property of the Artifact, and does not imply or
+ capture any ongoing process. This property is
+ managed by clients (such as Vertex AI
+ Pipelines), and the system does not prescribe or
+ check the validity of state transitions.
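+
+ Example (the dataset resource name below is illustrative):
+
+ dataset_artifact = VertexDataset(
+ vertex_dataset_name='projects/my-project/locations/us-central1/datasets/123')
+ dataset_artifact.create()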
+ """
+ extended_metadata = copy.deepcopy(metadata) if metadata else {}
+ extended_metadata[_ARTIFACT_PROPERTY_KEY_RESOURCE_NAME] = vertex_dataset_name
+
+ super(VertexDataset, self).__init__(
+ uri=utils.create_uri_from_resource_name(resource_name=vertex_dataset_name),
+ artifact_id=artifact_id,
+ display_name=display_name,
+ schema_version=schema_version,
+ description=description,
+ metadata=extended_metadata,
+ state=state,
+ )
+
+
+class VertexModel(base_artifact.BaseArtifactSchema):
+ """An artifact representing a Vertex Model."""
+
+ schema_title = "google.VertexModel"
+
+ def __init__(
+ self,
+ *,
+ vertex_model_name: str,
+ artifact_id: Optional[str] = None,
+ display_name: Optional[str] = None,
+ schema_version: Optional[str] = None,
+ description: Optional[str] = None,
+ metadata: Optional[Dict] = None,
+ state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
+ ):
+ """Args:
+ vertex_model_name (str):
+ The name of the Model resource, in a form of
+ projects/{project}/locations/{location}/models/{model}. For
+ more details, see
+ https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.models/get
+ This is used to generate the resource uri as follows:
+ https://{service-endpoint}/v1/{vertex_model_name},
+ where {service-endpoint} is one of the supported service endpoints at
+ https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints
+ artifact_id (str):
+ Optional. The portion of the Artifact name with
+ the format. This is globally unique in a metadataStore:
+ projects/123/locations/us-central1/metadataStores//artifacts/.
+ display_name (str):
+ Optional. The user-defined name of the Artifact.
+ schema_version (str):
+ Optional. schema_version specifies the version used by the Artifact.
+ If not set, defaults to use the latest version.
+ description (str):
+ Optional. Describes the purpose of the Artifact to be created.
+ metadata (Dict):
+ Optional. Contains the metadata information that will be stored in the Artifact.
+ state (google.cloud.gapic.types.Artifact.State):
+ Optional. The state of this Artifact. This is a
+ property of the Artifact, and does not imply or
+ capture any ongoing process. This property is
+ managed by clients (such as Vertex AI
+ Pipelines), and the system does not prescribe or
+ check the validity of state transitions.
+ """
+ extended_metadata = copy.deepcopy(metadata) if metadata else {}
+ extended_metadata[_ARTIFACT_PROPERTY_KEY_RESOURCE_NAME] = vertex_model_name
+
+ super(VertexModel, self).__init__(
+ uri=utils.create_uri_from_resource_name(resource_name=vertex_model_name),
+ artifact_id=artifact_id,
+ display_name=display_name,
+ schema_version=schema_version,
+ description=description,
+ metadata=extended_metadata,
+ state=state,
+ )
+
+
+class VertexEndpoint(base_artifact.BaseArtifactSchema):
+ """An artifact representing a Vertex Endpoint."""
+
+ schema_title = "google.VertexEndpoint"
+
+ def __init__(
+ self,
+ *,
+ vertex_endpoint_name: str,
+ artifact_id: Optional[str] = None,
+ display_name: Optional[str] = None,
+ schema_version: Optional[str] = None,
+ description: Optional[str] = None,
+ metadata: Optional[Dict] = None,
+ state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
+ ):
+ """Args:
+ vertex_endpoint_name (str):
+ The name of the Endpoint resource, in a form of
+ projects/{project}/locations/{location}/endpoints/{endpoint}. For
+ more details, see
+ https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/get
+ This is used to generate the resource uri as follows:
+ https://{service-endpoint}/v1/{vertex_endpoint_name},
+ where {service-endpoint} is one of the supported service endpoints at
+ https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints
+ artifact_id (str):
+ Optional. The portion of the Artifact name with
+ the format. This is globally unique in a metadataStore:
+ projects/123/locations/us-central1/metadataStores//artifacts/.
+ display_name (str):
+ Optional. The user-defined name of the Artifact.
+ schema_version (str):
+ Optional. schema_version specifies the version used by the Artifact.
+ If not set, defaults to use the latest version.
+ description (str):
+ Optional. Describes the purpose of the Artifact to be created.
+ metadata (Dict):
+ Optional. Contains the metadata information that will be stored in the Artifact.
+ state (google.cloud.gapic.types.Artifact.State):
+ Optional. The state of this Artifact. This is a
+ property of the Artifact, and does not imply or
+ capture any ongoing process. This property is
+ managed by clients (such as Vertex AI
+ Pipelines), and the system does not prescribe or
+ check the validity of state transitions.
+ """
+ extended_metadata = copy.deepcopy(metadata) if metadata else {}
+ extended_metadata[_ARTIFACT_PROPERTY_KEY_RESOURCE_NAME] = vertex_endpoint_name
+
+ super(VertexEndpoint, self).__init__(
+ uri=utils.create_uri_from_resource_name(resource_name=vertex_endpoint_name),
+ artifact_id=artifact_id,
+ display_name=display_name,
+ schema_version=schema_version,
+ description=description,
+ metadata=extended_metadata,
+ state=state,
+ )
+
+
+class UnmanagedContainerModel(base_artifact.BaseArtifactSchema):
+ """An artifact representing a Vertex Unmanaged Container Model."""
+
+ schema_title = "google.UnmanagedContainerModel"
+
+ def __init__(
+ self,
+ *,
+ predict_schema_ta: utils.PredictSchemata,
+ container_spec: utils.ContainerSpec,
+ artifact_id: Optional[str] = None,
+ uri: Optional[str] = None,
+ display_name: Optional[str] = None,
+ schema_version: Optional[str] = None,
+ description: Optional[str] = None,
+ metadata: Optional[Dict] = None,
+ state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
+ ):
+ """Args:
+ predict_schema_ta (PredictSchemata):
+ An instance of PredictSchemata which holds instance, parameter and prediction schema uris.
+ container_spec (ContainerSpec):
+ An instance of ContainerSpec which holds the container configuration for the model.
+ artifact_id (str):
+ Optional. The portion of the Artifact name with
+ the format. This is globally unique in a metadataStore:
+ projects/123/locations/us-central1/metadataStores//artifacts/.
+ uri (str):
+ Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
+ artifact file.
+ display_name (str):
+ Optional. The user-defined name of the Artifact.
+ schema_version (str):
+ Optional. schema_version specifies the version used by the Artifact.
+ If not set, defaults to use the latest version.
+ description (str):
+ Optional. Describes the purpose of the Artifact to be created.
+ metadata (Dict):
+ Optional. Contains the metadata information that will be stored in the Artifact.
+ state (google.cloud.gapic.types.Artifact.State):
+ Optional. The state of this Artifact. This is a
+ property of the Artifact, and does not imply or
+ capture any ongoing process. This property is
+ managed by clients (such as Vertex AI
+ Pipelines), and the system does not prescribe or
+ check the validity of state transitions.
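+
+ Example (a sketch; the PredictSchemata and ContainerSpec field names
+ shown are assumptions based on their descriptions above):
+
+ schemata = utils.PredictSchemata(
+ instance_schema_uri='gs://my-bucket/instance.yaml',
+ parameters_schema_uri='gs://my-bucket/params.yaml',
+ prediction_schema_uri='gs://my-bucket/prediction.yaml')
+ spec = utils.ContainerSpec(image_uri='gcr.io/my-project/my-image')
+ model = UnmanagedContainerModel(
+ predict_schema_ta=schemata, container_spec=spec)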
+ """
+ extended_metadata = copy.deepcopy(metadata) if metadata else {}
+ extended_metadata["predictSchemata"] = predict_schema_ta.to_dict()
+ extended_metadata["containerSpec"] = container_spec.to_dict()
+
+ super(UnmanagedContainerModel, self).__init__(
+ uri=uri,
+ artifact_id=artifact_id,
+ display_name=display_name,
+ schema_version=schema_version,
+ description=description,
+ metadata=extended_metadata,
+ state=state,
+ )
diff --git a/google/cloud/aiplatform/metadata/schema/system/artifact_schema.py b/google/cloud/aiplatform/metadata/schema/system/artifact_schema.py
new file mode 100644
index 0000000000..f3491a5573
--- /dev/null
+++ b/google/cloud/aiplatform/metadata/schema/system/artifact_schema.py
@@ -0,0 +1,265 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+from typing import Optional, Dict
+
+from google.cloud.aiplatform.compat.types import artifact as gca_artifact
+from google.cloud.aiplatform.metadata.schema import base_artifact
+
+
+class Model(base_artifact.BaseArtifactSchema):
+ """Artifact type for model."""
+
+ schema_title = "system.Model"
+
+ def __init__(
+ self,
+ *,
+ uri: Optional[str] = None,
+ artifact_id: Optional[str] = None,
+ display_name: Optional[str] = None,
+ schema_version: Optional[str] = None,
+ description: Optional[str] = None,
+ metadata: Optional[Dict] = None,
+ state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
+ ):
+ """Args:
+ uri (str):
+ Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
+ artifact file.
+ artifact_id (str):
+ Optional. The portion of the Artifact name with
+ the format. This is globally unique in a metadataStore:
+ projects/123/locations/us-central1/metadataStores//artifacts/.
+ display_name (str):
+ Optional. The user-defined name of the base.
+ schema_version (str):
+ Optional. schema_version specifies the version used by the base.
+ If not set, defaults to use the latest version.
+ description (str):
+ Optional. Describes the purpose of the Artifact to be created.
+ metadata (Dict):
+ Optional. Contains the metadata information that will be stored in the Artifact.
+ state (google.cloud.gapic.types.Artifact.State):
+ Optional. The state of this Artifact. This is a
+ property of the Artifact, and does not imply or
+ capture any ongoing process. This property is
+ managed by clients (such as Vertex AI
+ Pipelines), and the system does not prescribe or
+ check the validity of state transitions.
+ """
+ extended_metadata = copy.deepcopy(metadata) if metadata else {}
+ super(Model, self).__init__(
+ uri=uri,
+ artifact_id=artifact_id,
+ display_name=display_name,
+ schema_version=schema_version,
+ description=description,
+ metadata=extended_metadata,
+ state=state,
+ )
+
+
+class Artifact(base_artifact.BaseArtifactSchema):
+ """A generic artifact."""
+
+ schema_title = "system.Artifact"
+
+ def __init__(
+ self,
+ *,
+ uri: Optional[str] = None,
+ artifact_id: Optional[str] = None,
+ display_name: Optional[str] = None,
+ schema_version: Optional[str] = None,
+ description: Optional[str] = None,
+ metadata: Optional[Dict] = None,
+ state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
+ ):
+ """Args:
+ uri (str):
+ Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
+ artifact file.
+ artifact_id (str):
+ Optional. The portion of the Artifact name with
+ the format. This is globally unique in a metadataStore:
+ projects/123/locations/us-central1/metadataStores//artifacts/.
+ display_name (str):
+ Optional. The user-defined name of the base.
+ schema_version (str):
+ Optional. schema_version specifies the version used by the base.
+ If not set, defaults to use the latest version.
+ description (str):
+ Optional. Describes the purpose of the Artifact to be created.
+ metadata (Dict):
+ Optional. Contains the metadata information that will be stored in the Artifact.
+ state (google.cloud.gapic.types.Artifact.State):
+ Optional. The state of this Artifact. This is a
+ property of the Artifact, and does not imply or
+ capture any ongoing process. This property is
+ managed by clients (such as Vertex AI
+ Pipelines), and the system does not prescribe or
+ check the validity of state transitions.
+ """
+ extended_metadata = copy.deepcopy(metadata) if metadata else {}
+ super(Artifact, self).__init__(
+ uri=uri,
+ artifact_id=artifact_id,
+ display_name=display_name,
+ schema_version=schema_version,
+ description=description,
+ metadata=extended_metadata,
+ state=state,
+ )
+
+
+class Dataset(base_artifact.BaseArtifactSchema):
+ """An artifact representing a system Dataset."""
+
+ schema_title = "system.Dataset"
+
+ def __init__(
+ self,
+ *,
+ uri: Optional[str] = None,
+ artifact_id: Optional[str] = None,
+ display_name: Optional[str] = None,
+ schema_version: Optional[str] = None,
+ description: Optional[str] = None,
+ metadata: Optional[Dict] = None,
+ state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
+ ):
+ """Args:
+ uri (str):
+ Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
+ artifact file.
+ artifact_id (str):
+ Optional. The portion of the Artifact name with
+ the format. This is globally unique in a metadataStore:
+ projects/123/locations/us-central1/metadataStores//artifacts/.
+ display_name (str):
+ Optional. The user-defined name of the base.
+ schema_version (str):
+ Optional. schema_version specifies the version used by the base.
+ If not set, defaults to use the latest version.
+ description (str):
+ Optional. Describes the purpose of the Artifact to be created.
+ metadata (Dict):
+ Optional. Contains the metadata information that will be stored in the Artifact.
+ state (google.cloud.gapic.types.Artifact.State):
+ Optional. The state of this Artifact. This is a
+ property of the Artifact, and does not imply or
+ capture any ongoing process. This property is
+ managed by clients (such as Vertex AI
+ Pipelines), and the system does not prescribe or
+ check the validity of state transitions.
+ """
+ extended_metadata = copy.deepcopy(metadata) if metadata else {}
+ super(Dataset, self).__init__(
+ uri=uri,
+ artifact_id=artifact_id,
+ display_name=display_name,
+ schema_version=schema_version,
+ description=description,
+ metadata=extended_metadata,
+ state=state,
+ )
+
+
+class Metrics(base_artifact.BaseArtifactSchema):
+ """Artifact schema for scalar metrics."""
+
+ schema_title = "system.Metrics"
+
+ def __init__(
+ self,
+ *,
+ accuracy: Optional[float] = None,
+ precision: Optional[float] = None,
+ recall: Optional[float] = None,
+ f1score: Optional[float] = None,
+ mean_absolute_error: Optional[float] = None,
+ mean_squared_error: Optional[float] = None,
+ uri: Optional[str] = None,
+ artifact_id: Optional[str] = None,
+ display_name: Optional[str] = None,
+ schema_version: Optional[str] = None,
+ description: Optional[str] = None,
+ metadata: Optional[Dict] = None,
+ state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
+ ):
+ """Args:
+ accuracy (float):
+ Optional. The accuracy metric for the model.
+ precision (float):
+ Optional. The precision metric for the model.
+ recall (float):
+ Optional. The recall metric for the model.
+ f1score (float):
+ Optional. The F1 score for the model.
+ mean_absolute_error (float):
+ Optional. The mean absolute error for the model.
+ mean_squared_error (float):
+ Optional. The mean squared error for the model.
+ uri (str):
+ Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
+ artifact file.
+ artifact_id (str):
+ Optional. The <artifact_id> portion of the Artifact name, in the
+ following format; this is globally unique in a metadataStore:
+ projects/123/locations/us-central1/metadataStores/<metadata_store>/artifacts/<artifact_id>.
+ display_name (str):
+ Optional. The user-defined name of the Artifact.
+ schema_version (str):
+ Optional. schema_version specifies the version used by the Artifact.
+ If not set, it defaults to the latest version.
+ description (str):
+ Optional. Describes the purpose of the Artifact to be created.
+ metadata (Dict):
+ Optional. Contains the metadata information that will be stored in the Artifact.
+ state (gca_artifact.Artifact.State):
+ Optional. The state of this Artifact. This is a
+ property of the Artifact, and does not imply or
+ capture any ongoing process. This property is
+ managed by clients (such as Vertex AI
+ Pipelines), and the system does not prescribe or
+ check the validity of state transitions.
+ """
+ extended_metadata = copy.deepcopy(metadata) if metadata else {}
+ # Use explicit None checks so legitimate zero-valued metrics are kept.
+ if accuracy is not None:
+ extended_metadata["accuracy"] = accuracy
+ if precision is not None:
+ extended_metadata["precision"] = precision
+ if recall is not None:
+ extended_metadata["recall"] = recall
+ if f1score is not None:
+ extended_metadata["f1score"] = f1score
+ if mean_absolute_error is not None:
+ extended_metadata["mean_absolute_error"] = mean_absolute_error
+ if mean_squared_error is not None:
+ extended_metadata["mean_squared_error"] = mean_squared_error
+
+ super(Metrics, self).__init__(
+ uri=uri,
+ artifact_id=artifact_id,
+ display_name=display_name,
+ schema_version=schema_version,
+ description=description,
+ metadata=extended_metadata,
+ state=state,
+ )
diff --git a/google/cloud/aiplatform/metadata/schema/system/execution_schema.py b/google/cloud/aiplatform/metadata/schema/system/execution_schema.py
new file mode 100644
index 0000000000..68c96902cb
--- /dev/null
+++ b/google/cloud/aiplatform/metadata/schema/system/execution_schema.py
@@ -0,0 +1,157 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+from typing import Optional, Dict
+
+from google.cloud.aiplatform.compat.types import execution as gca_execution
+from google.cloud.aiplatform.metadata.schema import base_execution
+
+
+class ContainerExecution(base_execution.BaseExecutionSchema):
+ """Execution schema for a container execution."""
+
+ schema_title = "system.ContainerExecution"
+
+ def __init__(
+ self,
+ *,
+ state: Optional[
+ gca_execution.Execution.State
+ ] = gca_execution.Execution.State.RUNNING,
+ execution_id: Optional[str] = None,
+ display_name: Optional[str] = None,
+ schema_version: Optional[str] = None,
+ metadata: Optional[Dict] = None,
+ description: Optional[str] = None,
+ ):
+ """Args:
+ state (gca_execution.Execution.State):
+ Optional. State of this Execution. Defaults to RUNNING.
+ execution_id (str):
+ Optional. The <execution_id> portion of the Execution name, in the
+ following format; this is globally unique in a metadataStore:
+ projects/123/locations/us-central1/metadataStores/<metadata_store>/executions/<execution_id>.
+ display_name (str):
+ Optional. The user-defined name of the Execution.
+ schema_version (str):
+ Optional. schema_version specifies the version used by the Execution.
+ If not set, it defaults to the latest version.
+ metadata (Dict):
+ Optional. Contains the metadata information that will be stored in the Execution.
+ description (str):
+ Optional. Describes the purpose of the Execution to be created.
+ """
+ extended_metadata = copy.deepcopy(metadata) if metadata else {}
+ super(ContainerExecution, self).__init__(
+ execution_id=execution_id,
+ state=state,
+ display_name=display_name,
+ schema_version=schema_version,
+ description=description,
+ metadata=extended_metadata,
+ )
+
+
+class CustomJobExecution(base_execution.BaseExecutionSchema):
+ """Execution schema for a custom job execution."""
+
+ schema_title = "system.CustomJobExecution"
+
+ def __init__(
+ self,
+ *,
+ state: Optional[
+ gca_execution.Execution.State
+ ] = gca_execution.Execution.State.RUNNING,
+ execution_id: Optional[str] = None,
+ display_name: Optional[str] = None,
+ schema_version: Optional[str] = None,
+ metadata: Optional[Dict] = None,
+ description: Optional[str] = None,
+ ):
+ """Args:
+ state (gca_execution.Execution.State):
+ Optional. State of this Execution. Defaults to RUNNING.
+ execution_id (str):
+ Optional. The <execution_id> portion of the Execution name, in the
+ following format; this is globally unique in a metadataStore:
+ projects/123/locations/us-central1/metadataStores/<metadata_store>/executions/<execution_id>.
+ display_name (str):
+ Optional. The user-defined name of the Execution.
+ schema_version (str):
+ Optional. schema_version specifies the version used by the Execution.
+ If not set, it defaults to the latest version.
+ metadata (Dict):
+ Optional. Contains the metadata information that will be stored in the Execution.
+ description (str):
+ Optional. Describes the purpose of the Execution to be created.
+ """
+ extended_metadata = copy.deepcopy(metadata) if metadata else {}
+ super(CustomJobExecution, self).__init__(
+ execution_id=execution_id,
+ state=state,
+ display_name=display_name,
+ schema_version=schema_version,
+ description=description,
+ metadata=extended_metadata,
+ )
+
+
+class Run(base_execution.BaseExecutionSchema):
+ """Execution schema for root run execution."""
+
+ schema_title = "system.Run"
+
+ def __init__(
+ self,
+ *,
+ state: Optional[
+ gca_execution.Execution.State
+ ] = gca_execution.Execution.State.RUNNING,
+ execution_id: Optional[str] = None,
+ display_name: Optional[str] = None,
+ schema_version: Optional[str] = None,
+ metadata: Optional[Dict] = None,
+ description: Optional[str] = None,
+ ):
+ """Args:
+ state (gca_execution.Execution.State):
+ Optional. State of this Execution. Defaults to RUNNING.
+ execution_id (str):
+ Optional. The <execution_id> portion of the Execution name, in the
+ following format; this is globally unique in a metadataStore:
+ projects/123/locations/us-central1/metadataStores/<metadata_store>/executions/<execution_id>.
+ display_name (str):
+ Optional. The user-defined name of the Execution.
+ schema_version (str):
+ Optional. schema_version specifies the version used by the Execution.
+ If not set, it defaults to the latest version.
+ metadata (Dict):
+ Optional. Contains the metadata information that will be stored in the Execution.
+ description (str):
+ Optional. Describes the purpose of the Execution to be created.
+ """
+ extended_metadata = copy.deepcopy(metadata) if metadata else {}
+ super(Run, self).__init__(
+ execution_id=execution_id,
+ state=state,
+ display_name=display_name,
+ schema_version=schema_version,
+ description=description,
+ metadata=extended_metadata,
+ )
diff --git a/google/cloud/aiplatform/metadata/schema/utils.py b/google/cloud/aiplatform/metadata/schema/utils.py
new file mode 100644
index 0000000000..72577d9324
--- /dev/null
+++ b/google/cloud/aiplatform/metadata/schema/utils.py
@@ -0,0 +1,169 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import re
+
+from typing import Optional, Dict, List
+from dataclasses import dataclass
+
+
+@dataclass
+class PredictSchemata:
+ """A class holding instance, parameter and prediction schema uris.
+
+ Args:
+ instance_schema_uri (str):
+ Required. Points to a YAML file stored on Google Cloud Storage
+ describing the format of a single instance, which are used in
+ PredictRequest.instances, ExplainRequest.instances and
+ BatchPredictionJob.input_config. The schema is defined as an
+ OpenAPI 3.0.2 Schema Object.
+ parameters_schema_uri (str):
+ Required. Points to a YAML file stored on Google Cloud Storage
+ describing the parameters of prediction and explanation via
+ PredictRequest.parameters, ExplainRequest.parameters and
+ BatchPredictionJob.model_parameters. The schema is defined as an
+ OpenAPI 3.0.2 Schema Object.
+ prediction_schema_uri (str):
+ Required. Points to a YAML file stored on Google Cloud Storage
+ describing the format of a single prediction produced by this Model,
+ which are returned via PredictResponse.predictions,
+ ExplainResponse.explanations, and BatchPredictionJob.output_config.
+ The schema is defined as an OpenAPI 3.0.2 Schema Object.
+ """
+
+ instance_schema_uri: str
+ parameters_schema_uri: str
+ prediction_schema_uri: str
+
+ def to_dict(self):
+ """ML metadata schema dictionary representation of this DataClass"""
+ results = {}
+ results["instanceSchemaUri"] = self.instance_schema_uri
+ results["parametersSchemaUri"] = self.parameters_schema_uri
+ results["predictionSchemaUri"] = self.prediction_schema_uri
+
+ return results
+
+
+@dataclass
+class ContainerSpec:
+ """Container configuration for the model.
+ Args:
+ image_uri (str):
+ Required. URI of the Docker image to be used as the custom
+ container for serving predictions. This URI must identify an image
+ in Artifact Registry or Container Registry.
+ command (Sequence[str]):
+ Optional. Specifies the command that runs when the container
+ starts. This overrides the container's `ENTRYPOINT`.
+ args (Sequence[str]):
+ Optional. Specifies arguments for the command that runs when the
+ container starts. This overrides the container's `CMD`.
+ env (List[Dict[str, str]]):
+ Optional. List of environment variables to set in the container.
+ After the container starts running, code running in the container
+ can read these environment variables. Additionally, the command
+ and args fields can reference these variables. Later entries in
+ this list can also reference earlier entries. For example, the
+ following example sets the variable ``VAR_2`` to have the value
+ ``foo bar``:
+
+ .. code:: json
+
+ [ { "name": "VAR_1", "value": "foo" },
+ { "name": "VAR_2", "value": "$(VAR_1) bar" } ]
+
+ If you switch the order of the variables in the example, then the
+ expansion does not occur. This field corresponds to the ``env``
+ field of the Kubernetes Containers v1 core API.
+ ports (List[int]):
+ Optional. List of ports to expose from the container. Vertex AI
+ sends any prediction requests that it receives to the first port on
+ this list. Vertex AI also sends liveness and health checks to
+ this port.
+ predict_route (str):
+ Optional. HTTP path on the container to send prediction requests
+ to. Vertex AI forwards requests sent using
+ projects.locations.endpoints.predict to this path on the
+ container's IP address and port. Vertex AI then returns the
+ container's response in the API response. For example, if you set
+ this field to ``/foo``, then when Vertex AI receives a prediction
+ request, it forwards the request body in a POST request to the
+ ``/foo`` path on the port of your container specified by the first
+ value of this ``ModelContainerSpec``'s ports field. If you don't
+ specify this field, it defaults to the following value when you
+ deploy this Model to an Endpoint:
+ /v1/endpoints/ENDPOINT/deployedModels/DEPLOYED_MODEL:predict
+ The placeholders in this value are replaced as follows:
+ - ENDPOINT: The last segment (following ``endpoints/``) of the
+ ``Endpoint.name`` field of the Endpoint where this Model has
+ been deployed. (Vertex AI makes this value available to your
+ container code as the ``AIP_ENDPOINT_ID`` environment variable.)
+ health_route (str):
+ Optional. HTTP path on the container to send health checks to.
+ Vertex AI intermittently sends GET requests to this path on the
+ container's IP address and port to check that the container is
+ healthy. Read more about health checks.
+ """
+
+ image_uri: str
+ command: Optional[List[str]] = None
+ args: Optional[List[str]] = None
+ env: Optional[List[Dict[str, str]]] = None
+ ports: Optional[List[int]] = None
+ predict_route: Optional[str] = None
+ health_route: Optional[str] = None
+
+ def to_dict(self):
+ """ML metadata schema dictionary representation of this DataClass"""
+ results = {}
+ results["imageUri"] = self.image_uri
+ if self.command:
+ results["command"] = self.command
+ if self.args:
+ results["args"] = self.args
+ if self.env:
+ results["env"] = self.env
+ if self.ports:
+ results["ports"] = self.ports
+ if self.predict_route:
+ results["predictRoute"] = self.predict_route
+ if self.health_route:
+ results["healthRoute"] = self.health_route
+
+ return results
+
+
+def create_uri_from_resource_name(resource_name: str) -> str:
+ """Construct the service URI for a given resource_name.
+ Args:
+ resource_name (str):
+ The name of the Vertex resource, in a form of
+ projects/{project}/locations/{location}/{resource_type}/{resource_id}
+ Returns:
+ The resource URI in the form of:
+ https://{service-endpoint}/v1/{resource_name},
+ where {service-endpoint} is one of the supported service endpoints at
+ https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints
+ Raises:
+ ValueError: If resource_name does not match the specified format.
+ """
+ # TODO: support nested resource names such as models/123/evaluations/456
+ match_results = re.match(
+ r"^projects\/[A-Za-z0-9-]*\/locations\/([A-Za-z0-9-]*)\/[A-Za-z0-9-]*\/[A-Za-z0-9-]*$",
+ resource_name,
+ )
+ if not match_results:
+ raise ValueError(f"Invalid resource_name format for {resource_name}.")
+
+ location = match_results.group(1)
+ return f"https://{location}-aiplatform.googleapis.com/v1/{resource_name}"
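
Both helpers in this file are pure functions, so their behavior can be checked directly. A sketch with placeholder resource values:

```python
from google.cloud.aiplatform.metadata.schema import utils as schema_utils

# The captured location segment selects the regional service endpoint.
uri = schema_utils.create_uri_from_resource_name(
    "projects/my-project/locations/us-central1/models/my-model"
)
assert uri == (
    "https://us-central1-aiplatform.googleapis.com/v1/"
    "projects/my-project/locations/us-central1/models/my-model"
)

# to_dict() emits camelCase keys and skips unset optional fields.
spec = schema_utils.ContainerSpec(
    image_uri="gcr.io/my-project/serving:latest",
    predict_route="/predict",
)
assert spec.to_dict() == {
    "imageUri": "gcr.io/my-project/serving:latest",
    "predictRoute": "/predict",
}
```
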
diff --git a/google/cloud/aiplatform/metadata/utils.py b/google/cloud/aiplatform/metadata/utils.py
new file mode 100644
index 0000000000..57eb4c0691
--- /dev/null
+++ b/google/cloud/aiplatform/metadata/utils.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List, Optional, Union
+
+
+def _make_filter_string(
+ schema_title: Optional[Union[str, List[str]]] = None,
+ in_context: Optional[List[str]] = None,
+ parent_contexts: Optional[List[str]] = None,
+ uri: Optional[str] = None,
+) -> str:
+ """Helper method to format filter strings for Metadata querying.
+
+ No enforcement of correctness.
+
+ Args:
+ schema_title (Union[str, List[str]]): Optional. schema_titles to filter for.
+ in_context (List[str]):
+ Optional. Context resource names that the node should be in. Only for Artifacts/Executions.
+ parent_contexts (List[str]): Optional. Parent contexts the context should be in. Only for Contexts.
+ uri (str): Optional. uri to match for. Only for Artifacts.
+ Returns:
+ String that can be used for Metadata service filtering.
+ """
+ parts = []
+ if schema_title:
+ if isinstance(schema_title, str):
+ parts.append(f'schema_title="{schema_title}"')
+ else:
+ substring = " OR ".join(f'schema_title="{s}"' for s in schema_title)
+ parts.append(f"({substring})")
+ if in_context:
+ for context in in_context:
+ parts.append(f'in_context("{context}")')
+ if parent_contexts:
+ parent_context_str = ",".join([f'"{c}"' for c in parent_contexts])
+ parts.append(f"parent_contexts:{parent_context_str}")
+ if uri:
+ parts.append(f'uri="{uri}"')
+ return " AND ".join(parts)
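
Since `_make_filter_string` does no validation, its output is a straightforward join of the clauses it is given. For example, with placeholder values:

```python
from google.cloud.aiplatform.metadata import utils as metadata_utils

filter_str = metadata_utils._make_filter_string(
    schema_title=["system.Dataset", "system.Model"],
    uri="gs://my-bucket/data.csv",
)
# -> (schema_title="system.Dataset" OR schema_title="system.Model")
#    AND uri="gs://my-bucket/data.csv"
print(filter_str)
```
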
diff --git a/google/cloud/aiplatform/model_evaluation/__init__.py b/google/cloud/aiplatform/model_evaluation/__init__.py
new file mode 100644
index 0000000000..7dcbee2db5
--- /dev/null
+++ b/google/cloud/aiplatform/model_evaluation/__init__.py
@@ -0,0 +1,20 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from google.cloud.aiplatform.model_evaluation.model_evaluation import ModelEvaluation
+
+__all__ = ("ModelEvaluation",)
diff --git a/google/cloud/aiplatform/model_evaluation/model_evaluation.py b/google/cloud/aiplatform/model_evaluation/model_evaluation.py
new file mode 100644
index 0000000000..f8553b7644
--- /dev/null
+++ b/google/cloud/aiplatform/model_evaluation/model_evaluation.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from google.auth import credentials as auth_credentials
+
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import utils
+from google.cloud.aiplatform import models
+from google.protobuf import struct_pb2
+
+from typing import Optional
+
+
+class ModelEvaluation(base.VertexAiResourceNounWithFutureManager):
+
+ client_class = utils.ModelClientWithOverride
+ _resource_noun = "evaluations"
+ _delete_method = None
+ _getter_method = "get_model_evaluation"
+ _list_method = "list_model_evaluations"
+ _parse_resource_name_method = "parse_model_evaluation_path"
+ _format_resource_name_method = "model_evaluation_path"
+
+ @property
+ def metrics(self) -> Optional[struct_pb2.Value]:
+ """Gets the evaluation metrics from the Model Evaluation.
+ Returns:
+ The evaluation metrics from the Model Evaluation as a
+ ``struct_pb2.Value``, or None if the metrics for this evaluation are empty.
+ """
+ return self._gca_resource.metrics
+
+ def __init__(
+ self,
+ evaluation_name: str,
+ model_id: Optional[str] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ):
+ """Retrieves the ModelEvaluation resource and instantiates its representation.
+
+ Args:
+ evaluation_name (str):
+ Required. A fully-qualified model evaluation resource name or evaluation ID.
+ Example: "projects/123/locations/us-central1/models/456/evaluations/789" or
+ "789". If passing only the evaluation ID, model_id must be provided.
+ model_id (str):
+ Optional. The ID of the model to retrieve this evaluation from. If passing
+ only the evaluation ID as evaluation_name, model_id must be provided.
+ project (str):
+ Optional project to retrieve model evaluation from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional location to retrieve model evaluation from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve this model evaluation.
+ If not set, credentials set in aiplatform.init will be used.
+ """
+
+ super().__init__(
+ project=project,
+ location=location,
+ credentials=credentials,
+ resource_name=evaluation_name,
+ )
+
+ self._gca_resource = self._get_gca_resource(
+ resource_name=evaluation_name,
+ parent_resource_name_fields={models.Model._resource_noun: model_id}
+ if model_id
+ else None,
+ )
+
+ def delete(self):
+ raise NotImplementedError(
+ "Deleting a model evaluation has not been implemented yet."
+ )
diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py
index b93f569eaa..93243b5678 100644
--- a/google/cloud/aiplatform/models.py
+++ b/google/cloud/aiplatform/models.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,44 +14,57 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import pathlib
import proto
+import re
+import shutil
+import tempfile
from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
from google.api_core import operation
+from google.api_core import exceptions as api_exceptions
from google.auth import credentials as auth_credentials
+from google.cloud import aiplatform
from google.cloud.aiplatform import base
-from google.cloud.aiplatform import compat
from google.cloud.aiplatform import explain
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import jobs
from google.cloud.aiplatform import models
from google.cloud.aiplatform import utils
+from google.cloud.aiplatform.utils import gcs_utils
+from google.cloud.aiplatform import model_evaluation
from google.cloud.aiplatform.compat.services import endpoint_service_client
from google.cloud.aiplatform.compat.types import (
encryption_spec as gca_encryption_spec,
endpoint as gca_endpoint_compat,
- endpoint_v1 as gca_endpoint_v1,
- endpoint_v1beta1 as gca_endpoint_v1beta1,
- explanation_v1beta1 as gca_explanation_v1beta1,
+ explanation as gca_explanation_compat,
io as gca_io_compat,
machine_resources as gca_machine_resources_compat,
- machine_resources_v1beta1 as gca_machine_resources_v1beta1,
model as gca_model_compat,
model_service as gca_model_service_compat,
- model_v1beta1 as gca_model_v1beta1,
env_var as gca_env_var_compat,
- env_var_v1beta1 as gca_env_var_v1beta1,
)
-from google.protobuf import json_format
+from google.protobuf import field_mask_pb2, json_format
+_DEFAULT_MACHINE_TYPE = "n1-standard-2"
+_DEPLOYING_MODEL_TRAFFIC_SPLIT_KEY = "0"
_LOGGER = base.Logger(__name__)
+_SUPPORTED_MODEL_FILE_NAMES = [
+ "model.pkl",
+ "model.joblib",
+ "model.bst",
+ "saved_model.pb",
+ "saved_model.pbtxt",
+]
+
+
class Prediction(NamedTuple):
"""Prediction class envelopes returned Model predictions and the Model id.
@@ -70,17 +83,18 @@ class Prediction(NamedTuple):
predictions: Dict[str, List]
deployed_model_id: str
- explanations: Optional[Sequence[gca_explanation_v1beta1.Explanation]] = None
+ explanations: Optional[Sequence[gca_explanation_compat.Explanation]] = None
class Endpoint(base.VertexAiResourceNounWithFutureManager):
client_class = utils.EndpointClientWithOverride
- _is_client_prediction_client = False
_resource_noun = "endpoints"
_getter_method = "get_endpoint"
_list_method = "list_endpoints"
_delete_method = "delete_endpoint"
+ _parse_resource_name_method = "parse_endpoint_path"
+ _format_resource_name_method = "endpoint_path"
def __init__(
self,
@@ -113,30 +127,94 @@ def __init__(
credentials=credentials,
resource_name=endpoint_name,
)
- self._gca_resource = self._get_gca_resource(resource_name=endpoint_name)
+
+ endpoint_name = utils.full_resource_name(
+ resource_name=endpoint_name,
+ resource_noun="endpoints",
+ parse_resource_name_method=self._parse_resource_name,
+ format_resource_name_method=self._format_resource_name,
+ project=project,
+ location=location,
+ )
+
+ # Lazy load the Endpoint gca_resource until needed
+ self._gca_resource = gca_endpoint_compat.Endpoint(name=endpoint_name)
+
self._prediction_client = self._instantiate_prediction_client(
- location=location or initializer.global_config.location,
+ location=self.location,
credentials=credentials,
)
+ def _skipped_getter_call(self) -> bool:
+ """Check if GAPIC resource was populated by call to get/list API methods
+
+ Returns False if `_gca_resource` is None or fully populated. Returns True
+ if `_gca_resource` is partially populated
+ """
+ return self._gca_resource and not self._gca_resource.create_time
+
+ def _sync_gca_resource_if_skipped(self) -> None:
+ """Sync GAPIC service representation of Endpoint class resource only if
+ get_endpoint() was never called."""
+ if self._skipped_getter_call():
+ self._gca_resource = self._get_gca_resource(
+ resource_name=self._gca_resource.name
+ )
+
+ def _assert_gca_resource_is_available(self) -> None:
+ """Ensures Endpoint getter was called at least once before
+ asserting on gca_resource's availability."""
+ super()._assert_gca_resource_is_available()
+ self._sync_gca_resource_if_skipped()
+
+ @property
+ def traffic_split(self) -> Dict[str, int]:
+ """A map from a DeployedModel's ID to the percentage of this Endpoint's
+ traffic that should be forwarded to that DeployedModel.
+
+ If a DeployedModel's ID is not listed in this map, then it receives no traffic.
+
+ The traffic percentage values must add up to 100, or the map must be
+ empty if the Endpoint is not to accept any traffic at the moment.
+ """
+ self._sync_gca_resource()
+ return dict(self._gca_resource.traffic_split)
+
+ @property
+ def network(self) -> Optional[str]:
+ """The full name of the Google Compute Engine
+ [network](https://cloud.google.com/vpc/docs/vpc#networks) to which this
+ Endpoint should be peered.
+
+ Takes the format `projects/{project}/global/networks/{network}`. Where
+ {project} is a project number, as in `12345`, and {network} is a network name.
+
+ Private services access must already be configured for the network. If left
+ unspecified, the Endpoint is not peered with any network.
+ """
+ self._assert_gca_resource_is_available()
+ return getattr(self._gca_resource, "network", None)
+
@classmethod
def create(
cls,
- display_name: str,
+ display_name: Optional[str] = None,
description: Optional[str] = None,
- labels: Optional[Dict] = None,
+ labels: Optional[Dict[str, str]] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = (),
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
encryption_spec_key_name: Optional[str] = None,
sync=True,
+ create_request_timeout: Optional[float] = None,
+ endpoint_id: Optional[str] = None,
) -> "Endpoint":
"""Creates a new endpoint.
Args:
display_name (str):
- Required. The user-defined name of the Endpoint.
+ Optional. The user-defined name of the Endpoint.
The name can be up to 128 characters long and can consist
of any UTF-8 characters.
project (str):
@@ -147,7 +225,7 @@ def create(
set in aiplatform.init will be used.
description (str):
Optional. The description of the Endpoint.
- labels (Dict):
+ labels (Dict[str, str]):
Optional. The labels with user-defined metadata to
organize your Endpoints.
Label keys and values can be no longer than 64
@@ -178,6 +256,19 @@ def create(
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+ endpoint_id (str):
+ Optional. The ID to use for endpoint, which will become
+ the final component of the endpoint resource name. If
+ not provided, Vertex AI will generate a value for this
+ ID.
+
+ This value should be 1-10 characters, and valid
+ characters are /[0-9]/. When using HTTP/JSON, this field
+ is populated based on a query string argument, such as
+ ``?endpoint_id=12345``. This is the fallback for fields
+ that are not included in either the URI or the body.
Returns:
endpoint (endpoint.Endpoint):
Created endpoint.
@@ -185,7 +276,12 @@ def create(
api_client = cls._instantiate_client(location=location, credentials=credentials)
+ if not display_name:
+ display_name = cls._generate_display_name()
+
utils.validate_display_name(display_name)
+ if labels:
+ utils.validate_labels(labels)
project = project or initializer.global_config.project
location = location or initializer.global_config.location
@@ -203,6 +299,8 @@ def create(
encryption_spec_key_name=encryption_spec_key_name
),
sync=sync,
+ create_request_timeout=create_request_timeout,
+ endpoint_id=endpoint_id,
)
@classmethod
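
With these changes `Endpoint.create()` accepts a caller-chosen resource ID and a request timeout. A minimal sketch with placeholder values, assuming project and location come from `aiplatform.init()`:

```python
from google.cloud import aiplatform

endpoint = aiplatform.Endpoint.create(
    display_name="my-endpoint",    # now optional; auto-generated if omitted
    endpoint_id="1234567890",      # 1-10 characters, digits only
    create_request_timeout=600.0,  # seconds
)
```
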
@@ -214,11 +312,13 @@ def _create(
project: str,
location: str,
description: Optional[str] = None,
- labels: Optional[Dict] = None,
+ labels: Optional[Dict[str, str]] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = (),
credentials: Optional[auth_credentials.Credentials] = None,
encryption_spec: Optional[gca_encryption_spec.EncryptionSpec] = None,
sync=True,
+ create_request_timeout: Optional[float] = None,
+ endpoint_id: Optional[str] = None,
) -> "Endpoint":
"""Creates a new endpoint by calling the API client.
@@ -238,7 +338,7 @@ def _create(
set in aiplatform.init will be used.
description (str):
Optional. The description of the Endpoint.
- labels (Dict):
+ labels (Dict[str, str]):
Optional. The labels with user-defined metadata to
organize your Endpoints.
Label keys and values can be no longer than 64
@@ -262,6 +362,19 @@ def _create(
If set, this Dataset and all sub-resources of this Dataset will be secured by this key.
sync (bool):
Whether to create this endpoint synchronously.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+ endpoint_id (str):
+ Optional. The ID to use for endpoint, which will become
+ the final component of the endpoint resource name. If
+ not provided, Vertex AI will generate a value for this
+ ID.
+
+ This value should be 1-10 characters, and valid
+ characters are /[0-9]/. When using HTTP/JSON, this field
+ is populated based on a query string argument, such as
+ ``?endpoint_id=12345``. This is the fallback for fields
+ that are not included in either the URI or the body.
Returns:
endpoint (endpoint.Endpoint):
Created endpoint.
@@ -279,7 +392,11 @@ def _create(
)
operation_future = api_client.create_endpoint(
- parent=parent, endpoint=gapic_endpoint, metadata=metadata
+ parent=parent,
+ endpoint=gapic_endpoint,
+ endpoint_id=endpoint_id,
+ metadata=metadata,
+ timeout=create_request_timeout,
)
_LOGGER.log_create_with_lro(cls, operation_future)
@@ -288,16 +405,58 @@ def _create(
_LOGGER.log_create_complete(cls, created_endpoint, "endpoint")
- return cls(
- endpoint_name=created_endpoint.name,
+ return cls._construct_sdk_resource_from_gapic(
+ gapic_resource=created_endpoint,
project=project,
location=location,
credentials=credentials,
)
+ @classmethod
+ def _construct_sdk_resource_from_gapic(
+ cls,
+ gapic_resource: proto.Message,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> "Endpoint":
+ """Given a GAPIC Endpoint object, return the SDK representation.
+
+ Args:
+ gapic_resource (proto.Message):
+ A GAPIC representation of an Endpoint resource, usually
+ retrieved by a get_* or in a list_* API call.
+ project (str):
+ Optional. Project to construct Endpoint object from. If not set,
+ project set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to construct Endpoint object from. If not set,
+ location set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to construct Endpoint.
+ Overrides credentials set in aiplatform.init.
+
+ Returns:
+ Endpoint:
+ An initialized Endpoint resource.
+ """
+ endpoint = cls._empty_constructor(
+ project=project, location=location, credentials=credentials
+ )
+
+ endpoint._gca_resource = gapic_resource
+
+ endpoint._prediction_client = cls._instantiate_prediction_client(
+ location=endpoint.location,
+ credentials=credentials,
+ )
+
+ return endpoint
+
@staticmethod
def _allocate_traffic(
- traffic_split: Dict[str, int], traffic_percentage: int,
+ traffic_split: Dict[str, int],
+ traffic_percentage: int,
) -> Dict[str, int]:
"""Allocates desired traffic to new deployed model and scales traffic
of older deployed models.
@@ -327,13 +486,14 @@ def _allocate_traffic(
new_traffic_split[deployed_model] += 1
unallocated_traffic -= 1
- new_traffic_split["0"] = traffic_percentage
+ new_traffic_split[_DEPLOYING_MODEL_TRAFFIC_SPLIT_KEY] = traffic_percentage
return new_traffic_split
@staticmethod
def _unallocate_traffic(
- traffic_split: Dict[str, int], deployed_model_id: str,
+ traffic_split: Dict[str, int],
+ deployed_model_id: str,
) -> Dict[str, int]:
"""Sets deployed model id's traffic to 0 and scales the traffic of
other deployed models.
@@ -452,7 +612,6 @@ def _validate_deploy_args(
raise ValueError("Traffic percentage cannot be negative.")
elif traffic_split:
- # TODO(b/172678233) verify every referenced deployed model exists
if sum(traffic_split.values()) != 100:
raise ValueError(
"Sum of all traffic within traffic split needs to be 100."
@@ -483,6 +642,9 @@ def deploy(
explanation_parameters: Optional[explain.ExplanationParameters] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = (),
sync=True,
+ deploy_request_timeout: Optional[float] = None,
+ autoscaling_target_cpu_utilization: Optional[int] = None,
+ autoscaling_target_accelerator_duty_cycle: Optional[int] = None,
) -> None:
"""Deploys a Model to the Endpoint.
@@ -554,7 +716,17 @@ def deploy(
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ deploy_request_timeout (float):
+ Optional. The timeout for the deploy request in seconds.
+ autoscaling_target_cpu_utilization (int):
+ Optional. Target CPU Utilization to use for Autoscaling Replicas.
+ A default value of 60 will be used if not specified.
+ autoscaling_target_accelerator_duty_cycle (int):
+ Optional. Target Accelerator Duty Cycle.
+ Must also set accelerator_type and accelerator_count if specified.
+ A default value of 60 will be used if not specified.
"""
+ self._sync_gca_resource_if_skipped()
self._validate_deploy_args(
min_replica_count,
@@ -582,6 +754,9 @@ def deploy(
explanation_parameters=explanation_parameters,
metadata=metadata,
sync=sync,
+ deploy_request_timeout=deploy_request_timeout,
+ autoscaling_target_cpu_utilization=autoscaling_target_cpu_utilization,
+ autoscaling_target_accelerator_duty_cycle=autoscaling_target_accelerator_duty_cycle,
)
@base.optional_sync()
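
The new autoscaling parameters thread through `deploy()` into the machine spec built in `_deploy_call()` below. A hedged sketch with placeholder resource names; `machine_type` and the replica counts are pre-existing `deploy()` parameters:

```python
from google.cloud import aiplatform

model = aiplatform.Model("projects/123/locations/us-central1/models/456")
endpoint = aiplatform.Endpoint("projects/123/locations/us-central1/endpoints/789")

endpoint.deploy(
    model=model,
    machine_type="n1-standard-4",
    min_replica_count=1,
    max_replica_count=5,
    autoscaling_target_cpu_utilization=70,  # defaults to 60 if unset
    deploy_request_timeout=1800.0,
)
```
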
@@ -601,6 +776,9 @@ def _deploy(
explanation_parameters: Optional[explain.ExplanationParameters] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = (),
sync=True,
+ deploy_request_timeout: Optional[float] = None,
+ autoscaling_target_cpu_utilization: Optional[int] = None,
+ autoscaling_target_accelerator_duty_cycle: Optional[int] = None,
) -> None:
"""Deploys a Model to the Endpoint.
@@ -672,8 +850,17 @@ def _deploy(
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ deploy_request_timeout (float):
+ Optional. The timeout for the deploy request in seconds.
+ autoscaling_target_cpu_utilization (int):
+ Optional. Target CPU Utilization to use for Autoscaling Replicas.
+ A default value of 60 will be used if not specified.
+ autoscaling_target_accelerator_duty_cycle (int):
+ Optional. Target Accelerator Duty Cycle.
+ Must also set accelerator_type and accelerator_count if specified.
+ A default value of 60 will be used if not specified.
Raises:
- ValueError if there is not current traffic split and traffic percentage
+ ValueError: If there is no current traffic split and the traffic percentage
is not 0 or 100.
"""
_LOGGER.log_action_start_against_resource(
@@ -683,7 +870,7 @@ def _deploy(
self._deploy_call(
self.api_client,
self.resource_name,
- model.resource_name,
+ model,
self._gca_resource.traffic_split,
deployed_model_display_name=deployed_model_display_name,
traffic_percentage=traffic_percentage,
@@ -697,6 +884,9 @@ def _deploy(
explanation_metadata=explanation_metadata,
explanation_parameters=explanation_parameters,
metadata=metadata,
+ deploy_request_timeout=deploy_request_timeout,
+ autoscaling_target_cpu_utilization=autoscaling_target_cpu_utilization,
+ autoscaling_target_accelerator_duty_cycle=autoscaling_target_accelerator_duty_cycle,
)
_LOGGER.log_action_completed_against_resource("model", "deployed", self)
@@ -708,7 +898,7 @@ def _deploy_call(
cls,
api_client: endpoint_service_client.EndpointServiceClient,
endpoint_resource_name: str,
- model_resource_name: str,
+ model: "Model",
endpoint_resource_traffic_split: Optional[proto.MapField] = None,
deployed_model_display_name: Optional[str] = None,
traffic_percentage: Optional[int] = 0,
@@ -722,6 +912,9 @@ def _deploy_call(
explanation_metadata: Optional[explain.ExplanationMetadata] = None,
explanation_parameters: Optional[explain.ExplanationParameters] = None,
metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ deploy_request_timeout: Optional[float] = None,
+ autoscaling_target_cpu_utilization: Optional[int] = None,
+ autoscaling_target_accelerator_duty_cycle: Optional[int] = None,
):
"""Helper method to deploy model to endpoint.
@@ -730,8 +923,8 @@ def _deploy_call(
Required. endpoint_service_client.EndpointServiceClient to make call.
endpoint_resource_name (str):
Required. Endpoint resource name to deploy model to.
- model_resource_name (str):
- Required. Model resource name of Model to deploy.
+ model (aiplatform.Model):
+ Required. Model to be deployed.
endpoint_resource_traffic_split (proto.MapField):
Optional. Endpoint current resource traffic split.
deployed_model_display_name (str):
@@ -793,11 +986,21 @@ def _deploy_call(
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ deploy_request_timeout (float):
+ Optional. The timeout for the deploy request in seconds.
+ autoscaling_target_cpu_utilization (int):
+ Optional. Target CPU Utilization to use for Autoscaling Replicas.
+ A default value of 60 will be used if not specified.
+ autoscaling_target_accelerator_duty_cycle (int):
+ Optional. Target Accelerator Duty Cycle.
+ Must also set accelerator_type and accelerator_count if specified.
+ A default value of 60 will be used if not specified.
Raises:
ValueError: If there is no current traffic split and the traffic percentage
is not 0 or 100.
ValueError: If only `explanation_metadata` or `explanation_parameters`
is specified.
+ ValueError: If model does not support deployment.
"""
max_replica_count = max(min_replica_count, max_replica_count)
@@ -807,42 +1010,107 @@ def _deploy_call(
"Both `accelerator_type` and `accelerator_count` should be specified or None."
)
- gca_endpoint = gca_endpoint_compat
- gca_machine_resources = gca_machine_resources_compat
- if explanation_metadata and explanation_parameters:
- gca_endpoint = gca_endpoint_v1beta1
- gca_machine_resources = gca_machine_resources_v1beta1
+ if autoscaling_target_accelerator_duty_cycle is not None and (
+ not accelerator_type or not accelerator_count
+ ):
+ raise ValueError(
+ "Both `accelerator_type` and `accelerator_count` should be set "
+ "when specifying autoscaling_target_accelerator_duty_cycle`"
+ )
- deployed_model = gca_endpoint.DeployedModel(
- model=model_resource_name,
+ deployed_model = gca_endpoint_compat.DeployedModel(
+ model=model.resource_name,
display_name=deployed_model_display_name,
service_account=service_account,
)
- if machine_type:
- machine_spec = gca_machine_resources.MachineSpec(machine_type=machine_type)
+ supports_automatic_resources = (
+ aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
+ in model.supported_deployment_resources_types
+ )
+ supports_dedicated_resources = (
+ aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES
+ in model.supported_deployment_resources_types
+ )
+ provided_custom_machine_spec = (
+ machine_type
+ or accelerator_type
+ or accelerator_count
+ or autoscaling_target_accelerator_duty_cycle
+ or autoscaling_target_cpu_utilization
+ )
+
+ # If the model supports both automatic and dedicated deployment resources,
+ # decide based on the presence of machine spec customizations
+ use_dedicated_resources = supports_dedicated_resources and (
+ not supports_automatic_resources or provided_custom_machine_spec
+ )
+
+ if provided_custom_machine_spec and not use_dedicated_resources:
+ _LOGGER.info(
+ "Model does not support dedicated deployment resources. "
+ "The machine_type, accelerator_type and accelerator_count,"
+ "autoscaling_target_accelerator_duty_cycle,"
+ "autoscaling_target_cpu_utilization parameters are ignored."
+ )
+
+ if use_dedicated_resources and not machine_type:
+ machine_type = _DEFAULT_MACHINE_TYPE
+ _LOGGER.info(f"Using default machine_type: {machine_type}")
+
+ if use_dedicated_resources:
+
+ dedicated_resources = gca_machine_resources_compat.DedicatedResources(
+ min_replica_count=min_replica_count,
+ max_replica_count=max_replica_count,
+ )
+
+ machine_spec = gca_machine_resources_compat.MachineSpec(
+ machine_type=machine_type
+ )
+
+ if autoscaling_target_cpu_utilization:
+ autoscaling_metric_spec = gca_machine_resources_compat.AutoscalingMetricSpec(
+ metric_name="aiplatform.googleapis.com/prediction/online/cpu/utilization",
+ target=autoscaling_target_cpu_utilization,
+ )
+ dedicated_resources.autoscaling_metric_specs.extend(
+ [autoscaling_metric_spec]
+ )
if accelerator_type and accelerator_count:
utils.validate_accelerator_type(accelerator_type)
machine_spec.accelerator_type = accelerator_type
machine_spec.accelerator_count = accelerator_count
- deployed_model.dedicated_resources = gca_machine_resources.DedicatedResources(
- machine_spec=machine_spec,
- min_replica_count=min_replica_count,
- max_replica_count=max_replica_count,
- )
+ if autoscaling_target_accelerator_duty_cycle:
+ autoscaling_metric_spec = gca_machine_resources_compat.AutoscalingMetricSpec(
+ metric_name="aiplatform.googleapis.com/prediction/online/accelerator/duty_cycle",
+ target=autoscaling_target_accelerator_duty_cycle,
+ )
+ dedicated_resources.autoscaling_metric_specs.extend(
+ [autoscaling_metric_spec]
+ )
+
+ dedicated_resources.machine_spec = machine_spec
+ deployed_model.dedicated_resources = dedicated_resources
+ elif supports_automatic_resources:
+ deployed_model.automatic_resources = (
+ gca_machine_resources_compat.AutomaticResources(
+ min_replica_count=min_replica_count,
+ max_replica_count=max_replica_count,
+ )
+ )
else:
- deployed_model.automatic_resources = gca_machine_resources.AutomaticResources(
- min_replica_count=min_replica_count,
- max_replica_count=max_replica_count,
+ raise ValueError(
+ "Model does not support deployment. "
+ "See https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1#google.cloud.aiplatform.v1.Model.FIELDS.repeated.google.cloud.aiplatform.v1.Model.DeploymentResourcesType.google.cloud.aiplatform.v1.Model.supported_deployment_resources_types"
)
# Service will throw error if both metadata and parameters are not provided
if explanation_metadata and explanation_parameters:
- api_client = api_client.select_version(compat.V1BETA1)
- explanation_spec = gca_endpoint.explanation.ExplanationSpec()
+ explanation_spec = gca_endpoint_compat.explanation.ExplanationSpec()
explanation_spec.metadata = explanation_metadata
explanation_spec.parameters = explanation_parameters
deployed_model.explanation_spec = explanation_spec
@@ -869,6 +1137,7 @@ def _deploy_call(
deployed_model=deployed_model,
traffic_split=traffic_split,
metadata=metadata,
+ timeout=deploy_request_timeout,
)
_LOGGER.log_action_started_against_resource_with_lro(
@@ -886,34 +1155,49 @@ def undeploy(
) -> None:
"""Undeploys a deployed model.
- Proportionally adjusts the traffic_split among the remaining deployed
- models of the endpoint.
+ The model to be undeployed should have no traffic or user must provide
+ a new traffic_split with the remaining deployed models. Refer
+ to `Endpoint.traffic_split` for the current traffic split mapping.
Args:
deployed_model_id (str):
Required. The ID of the DeployedModel to be undeployed from the
Endpoint.
traffic_split (Dict[str, int]):
- Optional. A map from a DeployedModel's ID to the percentage of
+ Optional. A map of DeployedModel IDs to the percentage of
this Endpoint's traffic that should be forwarded to that DeployedModel.
- If a DeployedModel's ID is not listed in this map, then it receives
- no traffic. The traffic percentage values must add up to 100, or
- map must be empty if the Endpoint is to not accept any traffic at
- the moment. Key for model being deployed is "0". Should not be
- provided if traffic_percentage is provided.
+ Required if undeploying a model with non-zero traffic from an Endpoint
+ with multiple deployed models. The traffic percentage values must add
+ up to 100, or map must be empty if the Endpoint is to not accept any traffic
+ at the moment. If a DeployedModel's ID is not listed in this map, then it
+ receives no traffic.
metadata (Sequence[Tuple[str, str]]):
Optional. Strings which should be sent along with the request as
metadata.
"""
+ self._sync_gca_resource_if_skipped()
+
if traffic_split is not None:
if deployed_model_id in traffic_split and traffic_split[deployed_model_id]:
raise ValueError("Model being undeployed should have 0 traffic.")
if sum(traffic_split.values()) != 100:
- # TODO(b/172678233) verify every referenced deployed model exists
raise ValueError(
"Sum of all traffic within traffic split needs to be 100."
)
+ # Two or more models deployed to Endpoint and remaining traffic will be zero
+ elif (
+ len(self.traffic_split) > 1
+ and deployed_model_id in self._gca_resource.traffic_split
+ and self._gca_resource.traffic_split[deployed_model_id] == 100
+ ):
+ raise ValueError(
+ f"Undeploying deployed model '{deployed_model_id}' would leave the remaining "
+ "traffic split at 0%. Traffic split must add up to 100% when models are "
+ "deployed. Please undeploy the other models first or provide an updated "
+ "traffic_split."
+ )
+
self._undeploy(
deployed_model_id=deployed_model_id,
traffic_split=traffic_split,
@@ -950,6 +1234,7 @@ def _undeploy(
Optional. Strings which should be sent along with the request as
metadata.
"""
+ self._sync_gca_resource_if_skipped()
current_traffic_split = traffic_split or dict(self._gca_resource.traffic_split)
if deployed_model_id in current_traffic_split:
@@ -996,7 +1281,7 @@ def _instantiate_prediction_client(
the prediction client.
Returns:
prediction_client (prediction_service_client.PredictionServiceClient):
- Initalized prediction client with optional overrides.
+ Initialized prediction client with optional overrides.
"""
return initializer.global_config.create_client(
client_class=utils.PredictionClientWithOverride,
@@ -1005,7 +1290,110 @@ def _instantiate_prediction_client(
prediction_client=True,
)
- def predict(self, instances: List, parameters: Optional[Dict] = None) -> Prediction:
+ def update(
+ self,
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ traffic_split: Optional[Dict[str, int]] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ update_request_timeout: Optional[float] = None,
+ ) -> "Endpoint":
+ """Updates an endpoint.
+
+ Example usage:
+
+ my_endpoint = my_endpoint.update(
+ display_name='my-updated-endpoint',
+ description='my updated description',
+ labels={'key': 'value'},
+ traffic_split={
+ '123456': 20,
+ '234567': 80,
+ },
+ )
+
+ Args:
+ display_name (str):
+ Optional. The display name of the Endpoint.
+ The name can be up to 128 characters long and can consist of any UTF-8
+ characters.
+ description (str):
+ Optional. The description of the Endpoint.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to organize your Endpoints.
+ Label keys and values can be no longer than 64 characters
+ (Unicode codepoints), can only contain lowercase letters, numeric
+ characters, underscores and dashes. International characters are allowed.
+ See https://goo.gl/xmQnxf for more information and examples of labels.
+ traffic_split (Dict[str, int]):
+ Optional. A map from a DeployedModel's ID to the percentage of this Endpoint's
+ traffic that should be forwarded to that DeployedModel.
+ If a DeployedModel's ID is not listed in this map, then it receives no traffic.
+ The traffic percentage values must add up to 100, or the map must be
+ empty if the Endpoint is not to accept any traffic at the moment.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ update_request_timeout (float):
+ Optional. The timeout for the update request in seconds.
+
+ Returns:
+ Endpoint - Updated endpoint resource.
+
+ Raises:
+ ValueError: If `labels` is not the correct format.
+ """
+
+ self.wait()
+
+ current_endpoint_proto = self.gca_resource
+ copied_endpoint_proto = current_endpoint_proto.__class__(current_endpoint_proto)
+
+ update_mask: List[str] = []
+
+ if display_name:
+ utils.validate_display_name(display_name)
+ copied_endpoint_proto.display_name = display_name
+ update_mask.append("display_name")
+
+ if description:
+ copied_endpoint_proto.description = description
+ update_mask.append("description")
+
+ if labels:
+ utils.validate_labels(labels)
+ copied_endpoint_proto.labels = labels
+ update_mask.append("labels")
+
+ if traffic_split:
+ update_mask.append("traffic_split")
+ copied_endpoint_proto.traffic_split = traffic_split
+
+ update_mask = field_mask_pb2.FieldMask(paths=update_mask)
+
+ _LOGGER.log_action_start_against_resource(
+ "Updating",
+ "endpoint",
+ self,
+ )
+
+ self._gca_resource = self.api_client.update_endpoint(
+ endpoint=copied_endpoint_proto,
+ update_mask=update_mask,
+ metadata=request_metadata,
+ timeout=update_request_timeout,
+ )
+
+ _LOGGER.log_action_completed_against_resource("endpoint", "updated", self)
+
+ return self
+
+ def predict(
+ self,
+ instances: List,
+ parameters: Optional[Dict] = None,
+ timeout: Optional[float] = None,
+ ) -> Prediction:
"""Make a prediction against this Endpoint.
Args:
@@ -1028,13 +1416,17 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Predict
][google.cloud.aiplatform.v1beta1.DeployedModel.model]
[PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
``parameters_schema_uri``.
+ timeout (float): Optional. The timeout for this request in seconds.
Returns:
prediction: Prediction with returned predictions and Model Id.
"""
self.wait()
prediction_response = self._prediction_client.predict(
- endpoint=self.resource_name, instances=instances, parameters=parameters
+ endpoint=self._gca_resource.name,
+ instances=instances,
+ parameters=parameters,
+ timeout=timeout,
)
return Prediction(
@@ -1050,6 +1442,7 @@ def explain(
instances: List[Dict],
parameters: Optional[Dict] = None,
deployed_model_id: Optional[str] = None,
+ timeout: Optional[float] = None,
) -> Prediction:
"""Make a prediction with explanations against this Endpoint.
@@ -1080,18 +1473,18 @@ def explain(
deployed_model_id (str):
Optional. If specified, this ExplainRequest will be served by the
chosen DeployedModel, overriding this Endpoint's traffic split.
+ timeout (float): Optional. The timeout for this request in seconds.
Returns:
prediction: Prediction with returned predictions, explanations and Model Id.
"""
self.wait()
- explain_response = self._prediction_client.select_version(
- compat.V1BETA1
- ).explain(
+ explain_response = self._prediction_client.explain(
endpoint=self.resource_name,
instances=instances,
parameters=parameters,
deployed_model_id=deployed_model_id,
+ timeout=timeout,
)
return Prediction(
@@ -1150,19 +1543,15 @@ def list(
credentials=credentials,
)
- def list_models(
- self,
- ) -> Sequence[
- Union[gca_endpoint_v1.DeployedModel, gca_endpoint_v1beta1.DeployedModel]
- ]:
+ def list_models(self) -> List[gca_endpoint_compat.DeployedModel]:
"""Returns a list of the models deployed to this Endpoint.
Returns:
- deployed_models (Sequence[aiplatform.gapic.DeployedModel]):
+ deployed_models (List[aiplatform.gapic.DeployedModel]):
A list of the models deployed in this Endpoint.
"""
self._sync_gca_resource()
- return self._gca_resource.deployed_models
+ return list(self._gca_resource.deployed_models)
def undeploy_all(self, sync: bool = True) -> "Endpoint":
"""Undeploys every model deployed to this Endpoint.
@@ -1175,8 +1564,13 @@ def undeploy_all(self, sync: bool = True) -> "Endpoint":
"""
self._sync_gca_resource()
- for deployed_model in self._gca_resource.deployed_models:
- self._undeploy(deployed_model_id=deployed_model.id, sync=sync)
+ models_to_undeploy = sorted( # Undeploy zero traffic models first
+ self._gca_resource.traffic_split.keys(),
+ key=lambda id: self._gca_resource.traffic_split[id],
+ )
+
+ for deployed_model in models_to_undeploy:
+ self._undeploy(deployed_model_id=deployed_model, sync=sync)
return self
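
Under the stricter undeploy contract above, a model still receiving traffic must be undeployed together with an updated split over the remaining models; `undeploy_all()` sidesteps this by draining zero-traffic models first. A sketch with placeholder IDs:

```python
from google.cloud import aiplatform

endpoint = aiplatform.Endpoint("projects/123/locations/us-central1/endpoints/789")

# Route all traffic to the remaining model while undeploying the other.
endpoint.undeploy(
    deployed_model_id="123456",
    traffic_split={"123456": 0, "234567": 100},
)

# Or remove every deployed model, zero-traffic ones first.
endpoint.undeploy_all(sync=True)
```
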
@@ -1204,20 +1598,24 @@ def delete(self, force: bool = False, sync: bool = True) -> None:
class Model(base.VertexAiResourceNounWithFutureManager):
client_class = utils.ModelClientWithOverride
- _is_client_prediction_client = False
_resource_noun = "models"
_getter_method = "get_model"
_list_method = "list_models"
_delete_method = "delete_model"
+ _parse_resource_name_method = "parse_model_path"
+ _format_resource_name_method = "model_path"
@property
- def uri(self):
- """Uri of the model."""
- return self._gca_resource.artifact_uri
+ def uri(self) -> Optional[str]:
+ """Path to the directory containing the Model artifact and any of its
+ supporting files. Not present for AutoML Models."""
+ self._assert_gca_resource_is_available()
+ return self._gca_resource.artifact_uri or None
@property
- def description(self):
+ def description(self) -> str:
"""Description of the model."""
+ self._assert_gca_resource_is_available()
return self._gca_resource.description
@property
@@ -1232,6 +1630,7 @@ def supported_export_formats(
{'tf-saved-model': []}
"""
+ self._assert_gca_resource_is_available()
return {
export_format.id: [
gca_model_compat.Model.ExportFormat.ExportableContent(content)
@@ -1240,6 +1639,104 @@ def supported_export_formats(
for export_format in self._gca_resource.supported_export_formats
}
+ @property
+ def supported_deployment_resources_types(
+ self,
+ ) -> List[aiplatform.gapic.Model.DeploymentResourcesType]:
+ """List of deployment resource types accepted for this Model.
+
+ When this Model is deployed, its prediction resources are described by
+ the `prediction_resources` field of the objects returned by
+ `Endpoint.list_models()`. Because not all Models support all resource
+ configuration types, the configuration types this Model supports are
+ listed here.
+
+ If no configuration types are listed, the Model cannot be
+ deployed to an `Endpoint` and does not support online predictions
+ (`Endpoint.predict()` or `Endpoint.explain()`). Such a Model can serve
+ predictions by using a `BatchPredictionJob`, if it has at least one entry
+ each in `Model.supported_input_storage_formats` and
+ `Model.supported_output_storage_formats`."""
+ self._assert_gca_resource_is_available()
+ return list(self._gca_resource.supported_deployment_resources_types)
+
+ @property
+ def supported_input_storage_formats(self) -> List[str]:
+ """The formats this Model supports in the `input_config` field of a
+ `BatchPredictionJob`. If `Model.predict_schemata.instance_schema_uri`
+ exists, the instances should be given as per that schema.
+
+ [Read the docs for more on batch prediction formats](https://cloud.google.com/vertex-ai/docs/predictions/batch-predictions#batch_request_input)
+
+ If this Model doesn't support any of these formats it means it cannot be
+ used with a `BatchPredictionJob`. However, if it has
+ `supported_deployment_resources_types`, it could serve online predictions
+ by using `Endpoint.predict()` or `Endpoint.explain()`.
+ """
+ self._assert_gca_resource_is_available()
+ return list(self._gca_resource.supported_input_storage_formats)
+
+ @property
+ def supported_output_storage_formats(self) -> List[str]:
+ """The formats this Model supports in the `output_config` field of a
+ `BatchPredictionJob`.
+
+ If both `Model.predict_schemata.instance_schema_uri` and
+ `Model.predict_schemata.prediction_schema_uri` exist, the predictions
+ are returned together with their instances. In other words, the
+ prediction has the original instance data first, followed by the actual
+ prediction content (as per the schema).
+
+ [Read the docs for more on batch prediction formats](https://cloud.google.com/vertex-ai/docs/predictions/batch-predictions)
+
+ If this Model doesn't support any of these formats, it cannot be
+ used with a `BatchPredictionJob`. However, if it has
+ `supported_deployment_resources_types`, it could serve online predictions
+ by using `Endpoint.predict()` or `Endpoint.explain()`.
+ """
+ self._assert_gca_resource_is_available()
+ return list(self._gca_resource.supported_output_storage_formats)
+
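# --- Illustrative usage (editor's sketch, not part of the patch) ---
# The new read-only properties make it easy to check how a Model can serve
# predictions before deploying it; the resource name below is hypothetical.
from google.cloud import aiplatform

model = aiplatform.Model("projects/my-project/locations/us-central1/models/456")

if model.supported_deployment_resources_types:
    endpoint = model.deploy(machine_type="n1-standard-4")  # online predictions
elif model.supported_input_storage_formats and model.supported_output_storage_formats:
    print("Batch-prediction only; use model.batch_predict(...).")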
+ @property
+ def predict_schemata(self) -> Optional[aiplatform.gapic.PredictSchemata]:
+ """The schemata that describe formats of the Model's predictions and
+ explanations, if available."""
+ self._assert_gca_resource_is_available()
+ return getattr(self._gca_resource, "predict_schemata")
+
+ @property
+ def training_job(self) -> Optional["aiplatform.training_jobs._TrainingJob"]:
+ """The TrainingJob that uploaded this Model, if any.
+
+ Raises:
+ api_core.exceptions.NotFound: If the Model's training job resource
+ cannot be found on the Vertex service.
+ """
+ self._assert_gca_resource_is_available()
+ job_name = getattr(self._gca_resource, "training_pipeline")
+
+ if not job_name:
+ return None
+
+ try:
+ return aiplatform.training_jobs._TrainingJob._get_and_return_subclass(
+ resource_name=job_name,
+ project=self.project,
+ location=self.location,
+ credentials=self.credentials,
+ )
+ except api_exceptions.NotFound:
+ raise api_exceptions.NotFound(
+ f"The training job used to create this model could not be found: {job_name}"
+ )
+
+ @property
+ def container_spec(self) -> Optional[aiplatform.gapic.ModelContainerSpec]:
+ """The specification of the container that is to be used when deploying
+ this Model. Not present for AutoML Models."""
+ self._assert_gca_resource_is_available()
+ return getattr(self._gca_resource, "container_spec")
+
def __init__(
self,
model_name: str,
@@ -1273,13 +1770,81 @@ def __init__(
)
self._gca_resource = self._get_gca_resource(resource_name=model_name)
+ def update(
+ self,
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ ) -> "Model":
+ """Updates a model.
+
+ Example usage:
+
+ my_model = my_model.update(
+ display_name='my-model',
+ description='my description',
+ labels={'key': 'value'},
+ )
+
+ Args:
+ display_name (str):
+ The display name of the Model. The name can be up to 128
+ characters long and can consist of any UTF-8 characters.
+ description (str):
+ The description of the model.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
+ Returns:
+ model: Updated model resource.
+ Raises:
+ ValueError: If `labels` is not the correct format.
+ """
+
+ self.wait()
+
+ current_model_proto = self.gca_resource
+ copied_model_proto = current_model_proto.__class__(current_model_proto)
+
+ update_mask: List[str] = []
+
+ if display_name:
+ utils.validate_display_name(display_name)
+
+ copied_model_proto.display_name = display_name
+ update_mask.append("display_name")
+
+ if description:
+ copied_model_proto.description = description
+ update_mask.append("description")
+
+ if labels:
+ utils.validate_labels(labels)
+
+ copied_model_proto.labels = labels
+ update_mask.append("labels")
+
+ update_mask = field_mask_pb2.FieldMask(paths=update_mask)
+
+ self.api_client.update_model(model=copied_model_proto, update_mask=update_mask)
+
+ self._sync_gca_resource()
+
+ return self
+
# TODO(b/170979552) Add support for predict schemata
# TODO(b/170979926) Add support for metadata and metadata schema
@classmethod
@base.optional_sync()
def upload(
cls,
- display_name: str,
serving_container_image_uri: str,
*,
artifact_uri: Optional[str] = None,
@@ -1295,11 +1860,15 @@ def upload(
prediction_schema_uri: Optional[str] = None,
explanation_metadata: Optional[explain.ExplanationMetadata] = None,
explanation_parameters: Optional[explain.ExplanationParameters] = None,
+ display_name: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
+ labels: Optional[Dict[str, str]] = None,
encryption_spec_key_name: Optional[str] = None,
+ staging_bucket: Optional[str] = None,
sync=True,
+ upload_request_timeout: Optional[float] = None,
) -> "Model":
"""Uploads a model and returns a Model representing the uploaded Model
resource.
@@ -1314,7 +1883,7 @@ def upload(
Args:
display_name (str):
- Required. The display name of the Model. The name can be up to 128
+ Optional. The display name of the Model. The name can be up to 128
characters long and can consist of any UTF-8 characters.
serving_container_image_uri (str):
Required. The URI of the Model serving container.
@@ -1421,6 +1990,16 @@ def upload(
credentials: Optional[auth_credentials.Credentials]=None,
Custom credentials to use to upload this model. Overrides credentials
set in aiplatform.init.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the model. Has the
@@ -1432,42 +2011,45 @@ def upload(
If set, this Model and all sub-resources of this Model will be secured by this key.
Overrides encryption_spec_key_name set in aiplatform.init.
+ staging_bucket (str):
+ Optional. Bucket to stage local model artifacts. Overrides
+ staging_bucket set in aiplatform.init.
+ upload_request_timeout (float):
+ Optional. The timeout for the upload request in seconds.
Returns:
model: Instantiated representation of the uploaded model resource.
Raises:
ValueError: If only `explanation_metadata` or `explanation_parameters`
is specified.
+ Also raised if the model directory does not contain a supported model file.
"""
+ if not display_name:
+ display_name = cls._generate_display_name()
utils.validate_display_name(display_name)
+ if labels:
+ utils.validate_labels(labels)
if bool(explanation_metadata) != bool(explanation_parameters):
raise ValueError(
"Both `explanation_metadata` and `explanation_parameters` should be specified or None."
)
- gca_endpoint = gca_endpoint_compat
- gca_model = gca_model_compat
- gca_env_var = gca_env_var_compat
- if explanation_metadata and explanation_parameters:
- gca_endpoint = gca_endpoint_v1beta1
- gca_model = gca_model_v1beta1
- gca_env_var = gca_env_var_v1beta1
-
api_client = cls._instantiate_client(location, credentials)
env = None
ports = None
if serving_container_environment_variables:
env = [
- gca_env_var.EnvVar(name=str(key), value=str(value))
+ gca_env_var_compat.EnvVar(name=str(key), value=str(value))
for key, value in serving_container_environment_variables.items()
]
if serving_container_ports:
ports = [
- gca_model.Port(container_port=port) for port in serving_container_ports
+ gca_model_compat.Port(container_port=port)
+ for port in serving_container_ports
]
- container_spec = gca_model.ModelContainerSpec(
+ container_spec = gca_model_compat.ModelContainerSpec(
image_uri=serving_container_image_uri,
command=serving_container_command,
args=serving_container_args,
@@ -1479,7 +2061,7 @@ def upload(
model_predict_schemata = None
if any([instance_schema_uri, parameters_schema_uri, prediction_schema_uri]):
- model_predict_schemata = gca_model.PredictSchemata(
+ model_predict_schemata = gca_model_compat.PredictSchemata(
instance_schema_uri=instance_schema_uri,
parameters_schema_uri=parameters_schema_uri,
prediction_schema_uri=prediction_schema_uri,
@@ -1490,21 +2072,51 @@ def upload(
encryption_spec_key_name=encryption_spec_key_name,
)
- managed_model = gca_model.Model(
+ managed_model = gca_model_compat.Model(
display_name=display_name,
description=description,
container_spec=container_spec,
predict_schemata=model_predict_schemata,
+ labels=labels,
encryption_spec=encryption_spec,
)
+ if artifact_uri and not artifact_uri.startswith("gs://"):
+ model_dir = pathlib.Path(artifact_uri)
+ # Validating the model directory
+ if not model_dir.exists():
+ raise ValueError(f"artifact_uri path does not exist: '{artifact_uri}'")
+ PREBUILT_IMAGE_RE = r"(us|europe|asia)-docker\.pkg\.dev/vertex-ai/prediction/"
+ if re.match(PREBUILT_IMAGE_RE, serving_container_image_uri):
+ if not model_dir.is_dir():
+ raise ValueError(
+ f"artifact_uri path must be a directory: '{artifact_uri}' when using prebuilt image '{serving_container_image_uri}'"
+ )
+ if not any(
+ (model_dir / file_name).exists()
+ for file_name in _SUPPORTED_MODEL_FILE_NAMES
+ ):
+ raise ValueError(
+ "artifact_uri directory does not contain any supported model files. "
+ f"When using a prebuilt serving image, the upload method only supports the following model files: '{_SUPPORTED_MODEL_FILE_NAMES}'"
+ )
+
+ # Uploading the model
+ staged_data_uri = gcs_utils.stage_local_data_in_gcs(
+ data_path=str(model_dir),
+ staging_gcs_dir=staging_bucket,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+ artifact_uri = staged_data_uri
+
if artifact_uri:
managed_model.artifact_uri = artifact_uri
# Override explanation_spec if both required fields are provided
if explanation_metadata and explanation_parameters:
- api_client = api_client.select_version(compat.V1BETA1)
- explanation_spec = gca_endpoint.explanation.ExplanationSpec()
+ explanation_spec = gca_endpoint_compat.explanation.ExplanationSpec()
explanation_spec.metadata = explanation_metadata
explanation_spec.parameters = explanation_parameters
managed_model.explanation_spec = explanation_spec
@@ -1512,6 +2124,7 @@ def upload(
lro = api_client.upload_model(
parent=initializer.global_config.common_location_path(project, location),
model=managed_model,
+ timeout=upload_request_timeout,
)
_LOGGER.log_create_with_lro(cls, lro)
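# --- Illustrative usage (editor's sketch, not part of the patch) ---
# With this change, a non-"gs://" artifact_uri is validated locally and staged
# to GCS before upload. The bucket, local path, and image tag are assumptions.
from google.cloud import aiplatform

model = aiplatform.Model.upload(
    serving_container_image_uri=(
        "us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-0:latest"
    ),
    artifact_uri="my_local_model_dir",        # local path; staged automatically
    staging_bucket="gs://my-staging-bucket",  # overrides aiplatform.init
    labels={"team": "ml"},                    # display_name is auto-generated
    upload_request_timeout=1800.0,
)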
@@ -1542,6 +2155,9 @@ def deploy(
metadata: Optional[Sequence[Tuple[str, str]]] = (),
encryption_spec_key_name: Optional[str] = None,
sync=True,
+ deploy_request_timeout: Optional[float] = None,
+ autoscaling_target_cpu_utilization: Optional[int] = None,
+ autoscaling_target_accelerator_duty_cycle: Optional[int] = None,
) -> Endpoint:
"""Deploys model to endpoint. Endpoint will be created if unspecified.
@@ -1624,6 +2240,15 @@ def deploy(
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ deploy_request_timeout (float):
+ Optional. The timeout for the deploy request in seconds.
+ autoscaling_target_cpu_utilization (int):
+ Optional. Target CPU Utilization to use for Autoscaling Replicas.
+ A default value of 60 will be used if not specified.
+ autoscaling_target_accelerator_duty_cycle (int):
+ Optional. Target Accelerator Duty Cycle.
+ If specified, accelerator_type and accelerator_count must also be set.
+ A default value of 60 will be used if not specified.
Returns:
endpoint ("Endpoint"):
Endpoint with the deployed model.
@@ -1657,6 +2282,9 @@ def deploy(
encryption_spec_key_name=encryption_spec_key_name
or initializer.global_config.encryption_spec_key_name,
sync=sync,
+ deploy_request_timeout=deploy_request_timeout,
+ autoscaling_target_cpu_utilization=autoscaling_target_cpu_utilization,
+ autoscaling_target_accelerator_duty_cycle=autoscaling_target_accelerator_duty_cycle,
)
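# --- Illustrative usage (editor's sketch, not part of the patch) ---
# The new autoscaling knobs in deploy(); the resource name is hypothetical,
# and autoscaling_target_accelerator_duty_cycle additionally requires
# accelerator_type and accelerator_count.
from google.cloud import aiplatform

model = aiplatform.Model("projects/my-project/locations/us-central1/models/456")
endpoint = model.deploy(
    machine_type="n1-standard-4",
    min_replica_count=1,
    max_replica_count=5,
    autoscaling_target_cpu_utilization=70,  # defaults to 60 when unset
    deploy_request_timeout=1800.0,
)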
@base.optional_sync(return_input_arg="endpoint", bind_future_to_self=False)
@@ -1677,6 +2305,9 @@ def _deploy(
metadata: Optional[Sequence[Tuple[str, str]]] = (),
encryption_spec_key_name: Optional[str] = None,
sync: bool = True,
+ deploy_request_timeout: Optional[float] = None,
+ autoscaling_target_cpu_utilization: Optional[int] = None,
+ autoscaling_target_accelerator_duty_cycle: Optional[int] = None,
) -> Endpoint:
"""Deploys model to endpoint. Endpoint will be created if unspecified.
@@ -1759,6 +2390,15 @@ def _deploy(
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ deploy_request_timeout (float):
+ Optional. The timeout for the deploy request in seconds.
+ autoscaling_target_cpu_utilization (int):
+ Optional. Target CPU Utilization to use for Autoscaling Replicas.
+ A default value of 60 will be used if not specified.
+ autoscaling_target_accelerator_duty_cycle (int):
+ Optional. Target Accelerator Duty Cycle.
+ If specified, accelerator_type and accelerator_count must also be set.
+ A default value of 60 will be used if not specified.
Returns:
endpoint ("Endpoint"):
Endpoint with the deployed model.
@@ -1779,7 +2419,7 @@ def _deploy(
Endpoint._deploy_call(
endpoint.api_client,
endpoint.resource_name,
- self.resource_name,
+ self,
endpoint._gca_resource.traffic_split,
deployed_model_display_name=deployed_model_display_name,
traffic_percentage=traffic_percentage,
@@ -1793,6 +2433,9 @@ def _deploy(
explanation_metadata=explanation_metadata,
explanation_parameters=explanation_parameters,
metadata=metadata,
+ deploy_request_timeout=deploy_request_timeout,
+ autoscaling_target_cpu_utilization=autoscaling_target_cpu_utilization,
+ autoscaling_target_accelerator_duty_cycle=autoscaling_target_accelerator_duty_cycle,
)
_LOGGER.log_action_completed_against_resource("model", "deployed", endpoint)
@@ -1803,7 +2446,7 @@ def _deploy(
def batch_predict(
self,
- job_display_name: str,
+ job_display_name: Optional[str] = None,
gcs_source: Optional[Union[str, Sequence[str]]] = None,
bigquery_source: Optional[str] = None,
instances_format: str = "jsonl",
@@ -1819,10 +2462,12 @@ def batch_predict(
generate_explanation: Optional[bool] = False,
explanation_metadata: Optional[explain.ExplanationMetadata] = None,
explanation_parameters: Optional[explain.ExplanationParameters] = None,
- labels: Optional[dict] = None,
+ labels: Optional[Dict[str, str]] = None,
credentials: Optional[auth_credentials.Credentials] = None,
encryption_spec_key_name: Optional[str] = None,
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
+ batch_size: Optional[int] = None,
) -> jobs.BatchPredictionJob:
"""Creates a batch prediction job using this Model and outputs
prediction results to the provided destination prefix in the specified
@@ -1840,22 +2485,20 @@ def batch_predict(
Args:
job_display_name (str):
- Required. The user-defined name of the BatchPredictionJob.
+ Optional. The user-defined name of the BatchPredictionJob.
The name can be up to 128 characters long and can consist
of any UTF-8 characters.
gcs_source: Optional[Sequence[str]] = None
Google Cloud Storage URI(-s) to your instances to run
batch prediction on. They must match `instances_format`.
- May contain wildcards. For more information on wildcards, see
- https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
bigquery_source: Optional[str] = None
BigQuery URI to a table, up to 2000 characters long. For example:
- `projectId.bqDatasetId.bqTableId`
+ `bq://projectId.bqDatasetId.bqTableId`
instances_format: str = "jsonl"
- Required. The format in which instances are given, must be one
- of "jsonl", "csv", "bigquery", "tf-record", "tf-record-gzip",
- or "file-list". Default is "jsonl" when using `gcs_source`. If a
- `bigquery_source` is provided, this is overriden to "bigquery".
+ The format in which instances are provided. Must be one
+ of the formats listed in `Model.supported_input_storage_formats`.
+ Default is "jsonl" when using `gcs_source`. If a `bigquery_source`
+ is provided, this is overridden to "bigquery".
gcs_destination_prefix: Optional[str] = None
The Google Cloud Storage location of the directory where the
output is to be written to. In the given directory a new
@@ -1879,29 +2522,33 @@ def batch_predict(
which as value has ``google.rpc.Status`` containing only ``code`` and ``message`` fields.
containing only ``code`` and ``message`` fields.
bigquery_destination_prefix: Optional[str] = None
- The BigQuery project location where the output is to be
- written to. In the given project a new dataset is created
- with name
- ``prediction__`` where
- is made BigQuery-dataset-name compatible (for example, most
- special characters become underscores), and timestamp is in
- YYYY_MM_DDThh_mm_ss_sssZ "based on ISO-8601" format. In the
- dataset two tables will be created, ``predictions``, and
- ``errors``. If the Model has both ``instance`` and ``prediction``
- schemata defined then the tables have columns as follows:
- The ``predictions`` table contains instances for which the
- prediction succeeded, it has columns as per a concatenation
- of the Model's instance and prediction schemata. The
- ``errors`` table contains rows for which the prediction has
- failed, it has instance columns, as per the instance schema,
- followed by a single "errors" column, which as values has
- ```google.rpc.Status`` `__ represented as a STRUCT,
- and containing only ``code`` and ``message``.
+ The BigQuery URI to a project or table, up to 2000 characters long.
+ When only the project is specified, the Dataset and Table are created.
+ When the full table reference is specified, the Dataset must exist and
+ table must not exist. Accepted forms: ``bq://projectId`` or
+ ``bq://projectId.bqDatasetId`` or
+ ``bq://projectId.bqDatasetId.bqTableId``. If no Dataset is specified,
+ a new one is created with the name
+ ``prediction_<model-display-name>_<job-create-time>``
+ where the table name is made BigQuery-dataset-name compatible
+ (for example, most special characters become underscores), and
+ timestamp is in YYYY_MM_DDThh_mm_ss_sssZ "based on ISO-8601"
+ format. In the dataset two tables will be created, ``predictions``,
+ and ``errors``. If the Model has both ``instance`` and
+ ``prediction`` schemata defined then the tables have columns as
+ follows: The ``predictions`` table contains instances for which
+ the prediction succeeded, it has columns as per a concatenation
+ of the Model's instance and prediction schemata. The ``errors``
+ table contains rows for which the prediction has failed, it has
+ instance columns, as per the instance schema, followed by a single
+ "errors" column, which as values has ```google.rpc.Status`` `__
+ represented as a STRUCT, and containing only ``code`` and ``message``.
predictions_format: str = "jsonl"
- Required. The format in which Vertex AI gives the
- predictions, must be one of "jsonl", "csv", or "bigquery".
+ Required. The format in which Vertex AI outputs the
+ predictions, must be one of the formats specified in
+ `Model.supported_output_storage_formats`.
Default is "jsonl" when using `gcs_destination_prefix`. If a
- `bigquery_destination_prefix` is provided, this is overriden to
+ `bigquery_destination_prefix` is provided, this is overridden to
"bigquery".
model_parameters: Optional[Dict] = None
Optional. The parameters that govern the predictions. The schema of
@@ -1954,7 +2601,7 @@ def batch_predict(
a field of the `explanation_parameters` object is not populated, the
corresponding field of the `Model.explanation_parameters` object is inherited.
For more details, see `Ref docs `
- labels: Optional[dict] = None
+ labels: Optional[Dict[str, str]] = None
Optional. The labels with user-defined metadata to organize your
BatchPredictionJobs. Label keys and values can be no longer than
64 characters (Unicode codepoints), can only contain lowercase
@@ -1975,15 +2622,23 @@ def batch_predict(
If set, this Model and all sub-resources of this Model will be secured by this key.
Overrides encryption_spec_key_name set in aiplatform.init.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+ batch_size (int):
+ Optional. The number of records (e.g. instances) given in each batch
+ to a machine replica. Consider the machine type and the size of a
+ single record when setting this parameter: a higher value speeds up
+ the batch operation's execution, but a value that is too high can make
+ a whole batch no longer fit in a machine's memory, failing the whole operation.
+ The default value is 64.
Returns:
(jobs.BatchPredictionJob):
Instantiated representation of the created batch prediction job.
"""
- self.wait()
return jobs.BatchPredictionJob.create(
job_display_name=job_display_name,
- model_name=self.resource_name,
+ model_name=self,
instances_format=instances_format,
predictions_format=predictions_format,
gcs_source=gcs_source,
@@ -1996,6 +2651,7 @@ def batch_predict(
accelerator_count=accelerator_count,
starting_replica_count=starting_replica_count,
max_replica_count=max_replica_count,
+ batch_size=batch_size,
generate_explanation=generate_explanation,
explanation_metadata=explanation_metadata,
explanation_parameters=explanation_parameters,
@@ -2005,6 +2661,7 @@ def batch_predict(
credentials=credentials or self.credentials,
encryption_spec_key_name=encryption_spec_key_name,
sync=sync,
+ create_request_timeout=create_request_timeout,
)
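# --- Illustrative usage (editor's sketch, not part of the patch) ---
# batch_predict() with the new batch_size and create_request_timeout
# arguments; the project, dataset, and table names are hypothetical.
from google.cloud import aiplatform

model = aiplatform.Model("projects/my-project/locations/us-central1/models/456")
job = model.batch_predict(
    bigquery_source="bq://my-project.my_dataset.my_table",
    bigquery_destination_prefix="bq://my-project",
    machine_type="n1-standard-4",
    batch_size=32,                 # default is 64
    create_request_timeout=600.0,
    sync=False,                    # return immediately; job runs in background
)
job.wait()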
@classmethod
@@ -2119,11 +2776,13 @@ def export_model(
Details of the completed export with output destination paths to
the artifacts or container image.
Raises:
- ValueError if model does not support exporting.
+ ValueError: If model does not support exporting.
- ValueError if invalid arguments or export formats are provided.
+ ValueError: If invalid arguments or export formats are provided.
"""
+ self.wait()
+
# Model does not support exporting
if not self.supported_export_formats:
raise ValueError(f"The model `{self.resource_name}` is not exportable.")
@@ -2172,8 +2831,8 @@ def export_model(
)
if image_destination:
- output_config.image_destination = gca_io_compat.ContainerRegistryDestination(
- output_uri=image_destination
+ output_config.image_destination = (
+ gca_io_compat.ContainerRegistryDestination(output_uri=image_destination)
)
_LOGGER.log_action_start_against_resource("Exporting", "model", self)
@@ -2192,3 +2851,655 @@ def export_model(
_LOGGER.log_action_completed_against_resource("model", "exported", self)
return json_format.MessageToDict(operation_future.metadata.output_info._pb)
+
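# --- Illustrative usage (editor's sketch, not part of the patch) ---
# export_model() now waits for the Model resource before exporting; the
# destination bucket below is an assumption.
from google.cloud import aiplatform

model = aiplatform.Model("projects/my-project/locations/us-central1/models/456")
if "tf-saved-model" in model.supported_export_formats:
    output_info = model.export_model(
        export_format_id="tf-saved-model",
        artifact_destination="gs://my-bucket/exports/",
    )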
+ @classmethod
+ @base.optional_sync()
+ def upload_xgboost_model_file(
+ cls,
+ model_file_path: str,
+ xgboost_version: Optional[str] = None,
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ instance_schema_uri: Optional[str] = None,
+ parameters_schema_uri: Optional[str] = None,
+ prediction_schema_uri: Optional[str] = None,
+ explanation_metadata: Optional[explain.ExplanationMetadata] = None,
+ explanation_parameters: Optional[explain.ExplanationParameters] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ labels: Optional[Dict[str, str]] = None,
+ encryption_spec_key_name: Optional[str] = None,
+ staging_bucket: Optional[str] = None,
+ sync=True,
+ upload_request_timeout: Optional[float] = None,
+ ) -> "Model":
+ """Uploads a model and returns a Model representing the uploaded Model
+ resource.
+
+ Note: This function is *experimental* and can be changed in the future.
+
+ Example usage::
+
+ my_model = Model.upload_xgboost_model_file(
+ model_file_path="iris.xgboost_model.bst"
+ )
+
+ Args:
+ model_file_path (str): Required. Local file path of the model.
+ xgboost_version (str): Optional. The version of the XGBoost serving container.
+ Supported versions: ["0.82", "0.90", "1.1", "1.2", "1.3", "1.4"].
+ If the version is not specified, the latest version is used.
+ display_name (str):
+ Optional. The display name of the Model. The name can be up to 128
+ characters long and can consist of any UTF-8 characters.
+ description (str):
+ The description of the model.
+ instance_schema_uri (str):
+ Optional. Points to a YAML file stored on Google Cloud
+ Storage describing the format of a single instance, which
+ are used in
+ ``PredictRequest.instances``,
+ ``ExplainRequest.instances``
+ and
+ ``BatchPredictionJob.input_config``.
+ The schema is defined as an OpenAPI 3.0.2 `Schema
+ Object `__.
+ AutoML Models always have this field populated by AI
+ Platform. Note: The URI given on output will be immutable
+ and probably different, including the URI scheme, than the
+ one given on input. The output URI will point to a location
+ where the user only has a read access.
+ parameters_schema_uri (str):
+ Optional. Points to a YAML file stored on Google Cloud
+ Storage describing the parameters of prediction and
+ explanation via
+ ``PredictRequest.parameters``,
+ ``ExplainRequest.parameters``
+ and
+ ``BatchPredictionJob.model_parameters``.
+ The schema is defined as an OpenAPI 3.0.2 `Schema
+ Object `__.
+ AutoML Models always have this field populated by AI
+ Platform, if no parameters are supported it is set to an
+ empty string. Note: The URI given on output will be
+ immutable and probably different, including the URI scheme,
+ than the one given on input. The output URI will point to a
+ location where the user only has a read access.
+ prediction_schema_uri (str):
+ Optional. Points to a YAML file stored on Google Cloud
+ Storage describing the format of a single prediction
+ produced by this Model, which are returned via
+ ``PredictResponse.predictions``,
+ ``ExplainResponse.explanations``,
+ and
+ ``BatchPredictionJob.output_config``.
+ The schema is defined as an OpenAPI 3.0.2 `Schema
+ Object `__.
+ AutoML Models always have this field populated by AI
+ Platform. Note: The URI given on output will be immutable
+ and probably different, including the URI scheme, than the
+ one given on input. The output URI will point to a location
+ where the user only has a read access.
+ explanation_metadata (explain.ExplanationMetadata):
+ Optional. Metadata describing the Model's input and output for explanation.
+ Both `explanation_metadata` and `explanation_parameters` must be
+ passed together when used. For more details, see
+ `Ref docs `
+ explanation_parameters (explain.ExplanationParameters):
+ Optional. Parameters to configure explaining for Model's predictions.
+ For more details, see `Ref docs `
+ project: Optional[str]=None,
+ Project to upload this model to. Overrides project set in
+ aiplatform.init.
+ location: Optional[str]=None,
+ Location to upload this model to. Overrides location set in
+ aiplatform.init.
+ credentials: Optional[auth_credentials.Credentials]=None,
+ Custom credentials to use to upload this model. Overrides credentials
+ set in aiplatform.init.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
+ encryption_spec_key_name (Optional[str]):
+ Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the model. Has the
+ form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+ The key needs to be in the same region as where the compute
+ resource is created.
+
+ If set, this Model and all sub-resources of this Model will be secured by this key.
+
+ Overrides encryption_spec_key_name set in aiplatform.init.
+ staging_bucket (str):
+ Optional. Bucket to stage local model artifacts. Overrides
+ staging_bucket set in aiplatform.init.
+ upload_request_timeout (float):
+ Optional. The timeout for the upload request in seconds.
+ Returns:
+ model: Instantiated representation of the uploaded model resource.
+ Raises:
+ ValueError: If only `explanation_metadata` or `explanation_parameters`
+ is specified.
+ Also raised if the model directory does not contain a supported model file.
+ """
+ if not display_name:
+ display_name = cls._generate_display_name("XGBoost model")
+
+ XGBOOST_SUPPORTED_MODEL_FILE_EXTENSIONS = [
+ ".pkl",
+ ".joblib",
+ ".bst",
+ ]
+
+ container_image_uri = aiplatform.helpers.get_prebuilt_prediction_container_uri(
+ region=location,
+ framework="xgboost",
+ framework_version=xgboost_version or "1.4",
+ accelerator="cpu",
+ )
+
+ model_file_path_obj = pathlib.Path(model_file_path)
+ if not model_file_path_obj.is_file():
+ raise ValueError(
+ f"model_file_path path must point to a file: '{model_file_path}'"
+ )
+
+ model_file_extension = model_file_path_obj.suffix
+ if model_file_extension not in XGBOOST_SUPPORTED_MODEL_FILE_EXTENSIONS:
+ _LOGGER.warning(
+ f"Only the following XGBoost model file extensions are currently supported: '{XGBOOST_SUPPORTED_MODEL_FILE_EXTENSIONS}'"
+ )
+ _LOGGER.warning(
+ "Treating the model file as a binary serialized XGBoost Booster."
+ )
+ model_file_extension = ".bst"
+
+ # Preparing model directory
+ # We cannot clean up the directory immediately after calling Model.upload since
+ # that call may be asynchronous and return before the model file has been read.
+ # To work around this, we make this method asynchronous (decorate with @base.optional_sync)
+ # but call Model.upload with sync=True.
+ with tempfile.TemporaryDirectory() as prepared_model_dir:
+ prepared_model_file_path = pathlib.Path(prepared_model_dir) / (
+ "model" + model_file_extension
+ )
+ shutil.copy(model_file_path_obj, prepared_model_file_path)
+
+ return cls.upload(
+ serving_container_image_uri=container_image_uri,
+ artifact_uri=prepared_model_dir,
+ display_name=display_name,
+ description=description,
+ instance_schema_uri=instance_schema_uri,
+ parameters_schema_uri=parameters_schema_uri,
+ prediction_schema_uri=prediction_schema_uri,
+ explanation_metadata=explanation_metadata,
+ explanation_parameters=explanation_parameters,
+ project=project,
+ location=location,
+ credentials=credentials,
+ labels=labels,
+ encryption_spec_key_name=encryption_spec_key_name,
+ staging_bucket=staging_bucket,
+ sync=True,
+ upload_request_timeout=upload_request_timeout,
+ )
+
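# --- Illustrative end-to-end sketch (editor's addition, not part of the
# patch); assumes xgboost is installed and the staging bucket exists.
import numpy as np
import xgboost as xgb
from google.cloud import aiplatform

X = np.random.rand(20, 4)
y = np.random.randint(0, 2, size=20)
booster = xgb.train({"objective": "binary:logistic"}, xgb.DMatrix(X, label=y))
booster.save_model("model.bst")  # binary serialized Booster

model = aiplatform.Model.upload_xgboost_model_file(
    model_file_path="model.bst",
    xgboost_version="1.4",  # omit to use the latest supported version
    staging_bucket="gs://my-staging-bucket",
)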
+ @classmethod
+ @base.optional_sync()
+ def upload_scikit_learn_model_file(
+ cls,
+ model_file_path: str,
+ sklearn_version: Optional[str] = None,
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ instance_schema_uri: Optional[str] = None,
+ parameters_schema_uri: Optional[str] = None,
+ prediction_schema_uri: Optional[str] = None,
+ explanation_metadata: Optional[explain.ExplanationMetadata] = None,
+ explanation_parameters: Optional[explain.ExplanationParameters] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ labels: Optional[Dict[str, str]] = None,
+ encryption_spec_key_name: Optional[str] = None,
+ staging_bucket: Optional[str] = None,
+ sync=True,
+ upload_request_timeout: Optional[float] = None,
+ ) -> "Model":
+ """Uploads a model and returns a Model representing the uploaded Model
+ resource.
+
+ Note: This function is *experimental* and can be changed in the future.
+
+ Example usage::
+
+ my_model = Model.upload_scikit_learn_model_file(
+ model_file_path="iris.sklearn_model.joblib"
+ )
+
+ Args:
+ model_file_path (str): Required. Local file path of the model.
+ sklearn_version (str):
+ Optional. The version of the Scikit-learn serving container.
+ Supported versions: ["0.20", "0.22", "0.23", "0.24", "1.0"].
+ If the version is not specified, the latest version is used.
+ display_name (str):
+ Optional. The display name of the Model. The name can be up to 128
+ characters long and can consist of any UTF-8 characters.
+ description (str):
+ The description of the model.
+ instance_schema_uri (str):
+ Optional. Points to a YAML file stored on Google Cloud
+ Storage describing the format of a single instance, which
+ are used in
+ ``PredictRequest.instances``,
+ ``ExplainRequest.instances``
+ and
+ ``BatchPredictionJob.input_config``.
+ The schema is defined as an OpenAPI 3.0.2 `Schema
+ Object `__.
+ AutoML Models always have this field populated by AI
+ Platform. Note: The URI given on output will be immutable
+ and probably different, including the URI scheme, than the
+ one given on input. The output URI will point to a location
+ where the user only has a read access.
+ parameters_schema_uri (str):
+ Optional. Points to a YAML file stored on Google Cloud
+ Storage describing the parameters of prediction and
+ explanation via
+ ``PredictRequest.parameters``,
+ ``ExplainRequest.parameters``
+ and
+ ``BatchPredictionJob.model_parameters``.
+ The schema is defined as an OpenAPI 3.0.2 `Schema
+ Object `__.
+ AutoML Models always have this field populated by AI
+ Platform, if no parameters are supported it is set to an
+ empty string. Note: The URI given on output will be
+ immutable and probably different, including the URI scheme,
+ than the one given on input. The output URI will point to a
+ location where the user only has a read access.
+ prediction_schema_uri (str):
+ Optional. Points to a YAML file stored on Google Cloud
+ Storage describing the format of a single prediction
+ produced by this Model, which are returned via
+ ``PredictResponse.predictions``,
+ ``ExplainResponse.explanations``,
+ and
+ ``BatchPredictionJob.output_config``.
+ The schema is defined as an OpenAPI 3.0.2 `Schema
+ Object `__.
+ AutoML Models always have this field populated by AI
+ Platform. Note: The URI given on output will be immutable
+ and probably different, including the URI scheme, than the
+ one given on input. The output URI will point to a location
+ where the user only has a read access.
+ explanation_metadata (explain.ExplanationMetadata):
+ Optional. Metadata describing the Model's input and output for explanation.
+ Both `explanation_metadata` and `explanation_parameters` must be
+ passed together when used. For more details, see
+ `Ref docs `
+ explanation_parameters (explain.ExplanationParameters):
+ Optional. Parameters to configure explaining for Model's predictions.
+ For more details, see `Ref docs `
+ project: Optional[str]=None,
+ Project to upload this model to. Overrides project set in
+ aiplatform.init.
+ location: Optional[str]=None,
+ Location to upload this model to. Overrides location set in
+ aiplatform.init.
+ credentials: Optional[auth_credentials.Credentials]=None,
+ Custom credentials to use to upload this model. Overrides credentials
+ set in aiplatform.init.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
+ encryption_spec_key_name (Optional[str]):
+ Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the model. Has the
+ form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+ The key needs to be in the same region as where the compute
+ resource is created.
+
+ If set, this Model and all sub-resources of this Model will be secured by this key.
+
+ Overrides encryption_spec_key_name set in aiplatform.init.
+ staging_bucket (str):
+ Optional. Bucket to stage local model artifacts. Overrides
+ staging_bucket set in aiplatform.init.
+ upload_request_timeout (float):
+ Optional. The timeout for the upload request in seconds.
+ Returns:
+ model: Instantiated representation of the uploaded model resource.
+ Raises:
+ ValueError: If only `explanation_metadata` or `explanation_parameters`
+ is specified.
+ Also raised if the model directory does not contain a supported model file.
+ """
+ if not display_name:
+ display_name = cls._generate_display_name("Scikit-Learn model")
+
+ SKLEARN_SUPPORTED_MODEL_FILE_EXTENSIONS = [
+ ".pkl",
+ ".joblib",
+ ]
+
+ container_image_uri = aiplatform.helpers.get_prebuilt_prediction_container_uri(
+ region=location,
+ framework="sklearn",
+ framework_version=sklearn_version or "1.0",
+ accelerator="cpu",
+ )
+
+ model_file_path_obj = pathlib.Path(model_file_path)
+ if not model_file_path_obj.is_file():
+ raise ValueError(
+ f"model_file_path path must point to a file: '{model_file_path}'"
+ )
+
+ model_file_extension = model_file_path_obj.suffix
+ if model_file_extension not in SKLEARN_SUPPORTED_MODEL_FILE_EXTENSIONS:
+ _LOGGER.warning(
+ f"Only the following Scikit-learn model file extensions are currently supported: '{SKLEARN_SUPPORTED_MODEL_FILE_EXTENSIONS}'"
+ )
+ _LOGGER.warning(
+ "Treating the model file as a pickle serialized Scikit-learn model."
+ )
+ model_file_extension = ".pkl"
+
+ # Preparing model directory
+ # We cannot clean up the directory immediately after calling Model.upload since
+ # that call may be asynchronous and return before the model file has been read.
+ # To work around this, we make this method asynchronous (decorate with @base.optional_sync)
+ # but call Model.upload with sync=True.
+ with tempfile.TemporaryDirectory() as prepared_model_dir:
+ prepared_model_file_path = pathlib.Path(prepared_model_dir) / (
+ "model" + model_file_extension
+ )
+ shutil.copy(model_file_path_obj, prepared_model_file_path)
+
+ return cls.upload(
+ serving_container_image_uri=container_image_uri,
+ artifact_uri=prepared_model_dir,
+ display_name=display_name,
+ description=description,
+ instance_schema_uri=instance_schema_uri,
+ parameters_schema_uri=parameters_schema_uri,
+ prediction_schema_uri=prediction_schema_uri,
+ explanation_metadata=explanation_metadata,
+ explanation_parameters=explanation_parameters,
+ project=project,
+ location=location,
+ credentials=credentials,
+ labels=labels,
+ encryption_spec_key_name=encryption_spec_key_name,
+ staging_bucket=staging_bucket,
+ sync=True,
+ upload_request_timeout=upload_request_timeout,
+ )
+
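# --- Illustrative end-to-end sketch (editor's addition, not part of the
# patch); assumes scikit-learn is installed and the staging bucket exists.
import joblib
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from google.cloud import aiplatform

X, y = load_iris(return_X_y=True)
clf = LogisticRegression(max_iter=200).fit(X, y)
joblib.dump(clf, "model.joblib")

model = aiplatform.Model.upload_scikit_learn_model_file(
    model_file_path="model.joblib",
    sklearn_version="1.0",
    staging_bucket="gs://my-staging-bucket",
)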
+ @classmethod
+ def upload_tensorflow_saved_model(
+ cls,
+ saved_model_dir: str,
+ tensorflow_version: Optional[str] = None,
+ use_gpu: bool = False,
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ instance_schema_uri: Optional[str] = None,
+ parameters_schema_uri: Optional[str] = None,
+ prediction_schema_uri: Optional[str] = None,
+ explanation_metadata: Optional[explain.ExplanationMetadata] = None,
+ explanation_parameters: Optional[explain.ExplanationParameters] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ labels: Optional[Dict[str, str]] = None,
+ encryption_spec_key_name: Optional[str] = None,
+ staging_bucket: Optional[str] = None,
+ sync=True,
+ upload_request_timeout: Optional[float] = None,
+ ) -> "Model":
+ """Uploads a model and returns a Model representing the uploaded Model
+ resource.
+
+ Note: This function is *experimental* and can be changed in the future.
+
+ Example usage::
+
+ my_model = Model.upload_tensorflow_saved_model(
+ saved_model_dir="iris_tensorflow_saved_model"
+ )
+
+ Args:
+ saved_model_dir (str): Required.
+ Local directory of the Tensorflow SavedModel.
+ tensorflow_version (str):
+ Optional. The version of the Tensorflow serving container.
+ Supported versions: ["0.15", "2.1", "2.2", "2.3", "2.4", "2.5", "2.6", "2.7"].
+ If the version is not specified, the latest version is used.
+ use_gpu (bool): Whether to use GPU for model serving.
+ display_name (str):
+ Optional. The display name of the Model. The name can be up to 128
+ characters long and can consist of any UTF-8 characters.
+ description (str):
+ The description of the model.
+ instance_schema_uri (str):
+ Optional. Points to a YAML file stored on Google Cloud
+ Storage describing the format of a single instance, which
+ are used in
+ ``PredictRequest.instances``,
+ ``ExplainRequest.instances``
+ and
+ ``BatchPredictionJob.input_config``.
+ The schema is defined as an OpenAPI 3.0.2 `Schema
+ Object `__.
+ AutoML Models always have this field populated by AI
+ Platform. Note: The URI given on output will be immutable
+ and probably different, including the URI scheme, than the
+ one given on input. The output URI will point to a location
+ where the user only has a read access.
+ parameters_schema_uri (str):
+ Optional. Points to a YAML file stored on Google Cloud
+ Storage describing the parameters of prediction and
+ explanation via
+ ``PredictRequest.parameters``,
+ ``ExplainRequest.parameters``
+ and
+ ``BatchPredictionJob.model_parameters``.
+ The schema is defined as an OpenAPI 3.0.2 `Schema
+ Object `__.
+ AutoML Models always have this field populated by AI
+ Platform, if no parameters are supported it is set to an
+ empty string. Note: The URI given on output will be
+ immutable and probably different, including the URI scheme,
+ than the one given on input. The output URI will point to a
+ location where the user only has a read access.
+ prediction_schema_uri (str):
+ Optional. Points to a YAML file stored on Google Cloud
+ Storage describing the format of a single prediction
+ produced by this Model, which are returned via
+ ``PredictResponse.predictions``,
+ ``ExplainResponse.explanations``,
+ and
+ ``BatchPredictionJob.output_config``.
+ The schema is defined as an OpenAPI 3.0.2 `Schema
+ Object `__.
+ AutoML Models always have this field populated by AI
+ Platform. Note: The URI given on output will be immutable
+ and probably different, including the URI scheme, than the
+ one given on input. The output URI will point to a location
+ where the user only has a read access.
+ explanation_metadata (explain.ExplanationMetadata):
+ Optional. Metadata describing the Model's input and output for explanation.
+ Both `explanation_metadata` and `explanation_parameters` must be
+ passed together when used. For more details, see
+ `Ref docs `
+ explanation_parameters (explain.ExplanationParameters):
+ Optional. Parameters to configure explaining for Model's predictions.
+ For more details, see `Ref docs `
+ project: Optional[str]=None,
+ Project to upload this model to. Overrides project set in
+ aiplatform.init.
+ location: Optional[str]=None,
+ Location to upload this model to. Overrides location set in
+ aiplatform.init.
+ credentials: Optional[auth_credentials.Credentials]=None,
+ Custom credentials to use to upload this model. Overrides credentials
+ set in aiplatform.init.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
+ encryption_spec_key_name (Optional[str]):
+ Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the model. Has the
+ form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+ The key needs to be in the same region as where the compute
+ resource is created.
+
+ If set, this Model and all sub-resources of this Model will be secured by this key.
+
+ Overrides encryption_spec_key_name set in aiplatform.init.
+ staging_bucket (str):
+ Optional. Bucket to stage local model artifacts. Overrides
+ staging_bucket set in aiplatform.init.
+ upload_request_timeout (float):
+ Optional. The timeout for the upload request in seconds.
+ Returns:
+ model: Instantiated representation of the uploaded model resource.
+ Raises:
+ ValueError: If only `explanation_metadata` or `explanation_parameters`
+ is specified.
+ Also raised if the model directory does not contain a supported model file.
+ """
+ if not display_name:
+ display_name = cls._generate_display_name("Tensorflow model")
+
+ container_image_uri = aiplatform.helpers.get_prebuilt_prediction_container_uri(
+ region=location,
+ framework="tensorflow",
+ framework_version=tensorflow_version or "2.7",
+ accelerator="gpu" if use_gpu else "cpu",
+ )
+
+ return cls.upload(
+ serving_container_image_uri=container_image_uri,
+ artifact_uri=saved_model_dir,
+ display_name=display_name,
+ description=description,
+ instance_schema_uri=instance_schema_uri,
+ parameters_schema_uri=parameters_schema_uri,
+ prediction_schema_uri=prediction_schema_uri,
+ explanation_metadata=explanation_metadata,
+ explanation_parameters=explanation_parameters,
+ project=project,
+ location=location,
+ credentials=credentials,
+ labels=labels,
+ encryption_spec_key_name=encryption_spec_key_name,
+ staging_bucket=staging_bucket,
+ sync=sync,
+ upload_request_timeout=upload_request_timeout,
+ )
+
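# --- Illustrative end-to-end sketch (editor's addition, not part of the
# patch); assumes TensorFlow is installed and the staging bucket exists.
import tensorflow as tf
from google.cloud import aiplatform

keras_model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
tf.saved_model.save(keras_model, "my_saved_model_dir")

model = aiplatform.Model.upload_tensorflow_saved_model(
    saved_model_dir="my_saved_model_dir",
    tensorflow_version="2.7",
    use_gpu=False,
    staging_bucket="gs://my-staging-bucket",
)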
+ def list_model_evaluations(
+ self,
+ ) -> List["model_evaluation.ModelEvaluation"]:
+ """List all Model Evaluation resources associated with this model.
+
+ Example Usage:
+
+ my_model = Model(
+ model_name="projects/123/locations/us-central1/models/456"
+ )
+
+ my_evaluations = my_model.list_model_evaluations()
+
+ Returns:
+ List[model_evaluation.ModelEvaluation]: List of ModelEvaluation resources
+ for the model.
+ """
+
+ self.wait()
+
+ return model_evaluation.ModelEvaluation._list(
+ parent=self.resource_name,
+ credentials=self.credentials,
+ )
+
+ def get_model_evaluation(
+ self,
+ evaluation_id: Optional[str] = None,
+ ) -> Optional[model_evaluation.ModelEvaluation]:
+ """Returns a ModelEvaluation resource and instantiates its representation.
+ If no evaluation_id is passed, it will return the first evaluation associated
+ with this model.
+
+ Example usage:
+
+ my_model = Model(
+ model_name="projects/123/locations/us-central1/models/456"
+ )
+
+ my_evaluation = my_model.get_model_evaluation(
+ evaluation_id="789"
+ )
+
+ # If no arguments are passed, this returns the first evaluation for the model
+ my_evaluation = my_model.get_model_evaluation()
+
+ Args:
+ evaluation_id (str):
+ Optional. The ID of the model evaluation to retrieve.
+ Returns:
+ model_evaluation.ModelEvaluation: Instantiated representation of the
+ ModelEvaluation resource.
+ """
+
+ evaluations = self.list_model_evaluations()
+
+ if not evaluation_id:
+ if len(evaluations) > 1:
+ _LOGGER.warning(
+ f"Your model has more than one model evaluation, this is returning only one evaluation resource: {evaluations[0].resource_name}"
+ )
+ return evaluations[0] if evaluations else None
+ else:
+ resource_uri_parts = self._parse_resource_name(self.resource_name)
+ evaluation_resource_name = (
+ model_evaluation.ModelEvaluation._format_resource_name(
+ **resource_uri_parts,
+ evaluation=evaluation_id,
+ )
+ )
+
+ return model_evaluation.ModelEvaluation(
+ evaluation_name=evaluation_resource_name,
+ credentials=self.credentials,
+ )
diff --git a/google/cloud/aiplatform/pipeline_jobs.py b/google/cloud/aiplatform/pipeline_jobs.py
new file mode 100644
index 0000000000..f6fcc3a0af
--- /dev/null
+++ b/google/cloud/aiplatform/pipeline_jobs.py
@@ -0,0 +1,778 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import datetime
+import logging
+import time
+import re
+from typing import Any, Dict, List, Optional, Union
+
+from google.auth import credentials as auth_credentials
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform import utils
+from google.cloud.aiplatform.metadata import artifact
+from google.cloud.aiplatform.metadata import context
+from google.cloud.aiplatform.metadata import execution
+from google.cloud.aiplatform.metadata import constants as metadata_constants
+from google.cloud.aiplatform.metadata import experiment_resources
+from google.cloud.aiplatform.metadata import utils as metadata_utils
+from google.cloud.aiplatform.utils import yaml_utils
+from google.cloud.aiplatform.utils import pipeline_utils
+from google.protobuf import json_format
+
+from google.cloud.aiplatform.compat.types import (
+ pipeline_job as gca_pipeline_job,
+ pipeline_state as gca_pipeline_state,
+)
+
+_LOGGER = base.Logger(__name__)
+
+_PIPELINE_COMPLETE_STATES = set(
+ [
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED,
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED,
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_CANCELLED,
+ gca_pipeline_state.PipelineState.PIPELINE_STATE_PAUSED,
+ ]
+)
+
+_PIPELINE_ERROR_STATES = set([gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED])
+
+# Pattern for valid names used as a Vertex resource name.
+_VALID_NAME_PATTERN = re.compile("^[a-z][-a-z0-9]{0,127}$")
+
+# Pattern for an Artifact Registry URL.
+_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")
+
+
+def _get_current_time() -> datetime.datetime:
+ """Gets the current timestamp."""
+ return datetime.datetime.now()
+
+
+def _set_enable_caching_value(
+ pipeline_spec: Dict[str, Any], enable_caching: bool
+) -> None:
+ """Sets pipeline tasks caching options.
+
+ Args:
+ pipeline_spec (Dict[str, Any]):
+ Required. The dictionary of pipeline spec.
+ enable_caching (bool):
+ Required. Whether to enable caching.
+ """
+ for component in [pipeline_spec["root"]] + list(
+ pipeline_spec["components"].values()
+ ):
+ if "dag" in component:
+ for task in component["dag"]["tasks"].values():
+ task["cachingOptions"] = {"enableCache": enable_caching}
+
+
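# --- Editor's sketch of the helper's effect on a toy spec dict (real specs
# come from a KFP/TFX compiler); runs standalone against the function above.
spec = {
    "root": {"dag": {"tasks": {"train": {}}}},
    "components": {"comp-train": {"dag": {"tasks": {"step": {}}}}},
}
_set_enable_caching_value(spec, enable_caching=False)
assert spec["root"]["dag"]["tasks"]["train"]["cachingOptions"] == {
    "enableCache": False
}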
+class PipelineJob(
+ base.VertexAiStatefulResource,
+ experiment_resources._ExperimentLoggable,
+ experiment_loggable_schemas=(
+ experiment_resources._ExperimentLoggableSchema(
+ title=metadata_constants.SYSTEM_PIPELINE_RUN
+ ),
+ ),
+):
+
+ client_class = utils.PipelineJobClientWithOverride
+ _resource_noun = "pipelineJobs"
+ _delete_method = "delete_pipeline_job"
+ _getter_method = "get_pipeline_job"
+ _list_method = "list_pipeline_jobs"
+ _parse_resource_name_method = "parse_pipeline_job_path"
+ _format_resource_name_method = "pipeline_job_path"
+
+ # Required by the done() method
+ _valid_done_states = _PIPELINE_COMPLETE_STATES
+
+ def __init__(
+ self,
+ # TODO(b/223262536): Make the display_name parameter optional in the next major release
+ display_name: str,
+ template_path: str,
+ job_id: Optional[str] = None,
+ pipeline_root: Optional[str] = None,
+ parameter_values: Optional[Dict[str, Any]] = None,
+ enable_caching: Optional[bool] = None,
+ encryption_spec_key_name: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ failure_policy: Optional[str] = None,
+ ):
+ """Retrieves a PipelineJob resource and instantiates its
+ representation.
+
+ Args:
+ display_name (str):
+ Required. The user-defined name of this Pipeline.
+ template_path (str):
+ Required. The path of PipelineJob or PipelineSpec JSON or YAML file. It
+ can be a local path, a Google Cloud Storage URI (e.g. "gs://project.name"),
+ or an Artifact Registry URI (e.g.
+ "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest").
+ job_id (str):
+ Optional. The unique ID of the job run.
+ If not specified, pipeline name + timestamp will be used.
+ pipeline_root (str):
+ Optional. The root of the pipeline outputs. Default to be staging bucket.
+ parameter_values (Dict[str, Any]):
+ Optional. The mapping from runtime parameter names to its values that
+ control the pipeline run.
+ enable_caching (bool):
+ Optional. Whether to turn on caching for the run.
+
+ If this is not set, defaults to the compile time settings, which
+ are True for all tasks by default, while users may specify
+ different caching options for individual tasks.
+
+ If this is set, the setting applies to all tasks in the pipeline.
+
+ Overrides the compile time settings.
+ encryption_spec_key_name (str):
+ Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the job. Has the
+ form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+ The key needs to be in the same region as where the compute
+ resource is created.
+
+ If this is set, then all
+ resources created by the PipelineJob will
+ be encrypted with the provided encryption key.
+
+ Overrides encryption_spec_key_name set in aiplatform.init.
+ labels (Dict[str, str]):
+ Optional. The user defined metadata to organize PipelineJob.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to create this PipelineJob.
+ Overrides credentials set in aiplatform.init.
+ project (str):
+ Optional. The project that you want to run this PipelineJob in. If not set,
+ the project set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to create PipelineJob. If not set,
+ location set in aiplatform.init will be used.
+ failure_policy (str):
+ Optional. The failure policy - "slow" or "fast".
+ Currently, the default of a pipeline is that the pipeline will continue to
+ run until no more tasks can be executed, also known as
+ PIPELINE_FAILURE_POLICY_FAIL_SLOW (corresponds to "slow").
+ However, if a pipeline is set to
+ PIPELINE_FAILURE_POLICY_FAIL_FAST (corresponds to "fast"),
+ it will stop scheduling any new tasks when a task has failed. Any
+ scheduled tasks will continue to completion.
+
+ Raises:
+ ValueError: If job_id or labels have incorrect format.
+ """
+ if not display_name:
+ display_name = self.__class__._generate_display_name()
+ utils.validate_display_name(display_name)
+
+ if labels:
+ utils.validate_labels(labels)
+
+ super().__init__(project=project, location=location, credentials=credentials)
+
+ self._parent = initializer.global_config.common_location_path(
+ project=project, location=location
+ )
+
+ # this loads both .yaml and .json files because YAML is a superset of JSON
+ pipeline_json = yaml_utils.load_yaml(
+ template_path, self.project, self.credentials
+ )
+
+ # Pipeline_json can be either PipelineJob or PipelineSpec.
+ if pipeline_json.get("pipelineSpec") is not None:
+ pipeline_job = pipeline_json
+ pipeline_root = (
+ pipeline_root
+ or pipeline_job["pipelineSpec"].get("defaultPipelineRoot")
+ or pipeline_job["runtimeConfig"].get("gcsOutputDirectory")
+ or initializer.global_config.staging_bucket
+ )
+ else:
+ pipeline_job = {
+ "pipelineSpec": pipeline_json,
+ "runtimeConfig": {},
+ }
+ pipeline_root = (
+ pipeline_root
+ or pipeline_job["pipelineSpec"].get("defaultPipelineRoot")
+ or initializer.global_config.staging_bucket
+ )
+ builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
+ pipeline_job
+ )
+ builder.update_pipeline_root(pipeline_root)
+ builder.update_runtime_parameters(parameter_values)
+ builder.update_failure_policy(failure_policy)
+ runtime_config_dict = builder.build()
+
+ runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
+ json_format.ParseDict(runtime_config_dict, runtime_config)
+
+ pipeline_name = pipeline_job["pipelineSpec"]["pipelineInfo"]["name"]
+ self.job_id = job_id or "{pipeline_name}-{timestamp}".format(
+ pipeline_name=re.sub("[^-0-9a-z]+", "-", pipeline_name.lower())
+ .lstrip("-")
+ .rstrip("-"),
+ timestamp=_get_current_time().strftime("%Y%m%d%H%M%S"),
+ )
+ if not _VALID_NAME_PATTERN.match(self.job_id):
+ raise ValueError(
+ f"Generated job ID: {self.job_id} is illegal as a Vertex pipelines job ID. "
+ "Expecting an ID following the regex pattern "
+ f'"{_VALID_NAME_PATTERN.pattern[1:-1]}"'
+ )
+
+ if enable_caching is not None:
+ _set_enable_caching_value(pipeline_job["pipelineSpec"], enable_caching)
+
+ pipeline_job_args = {
+ "display_name": display_name,
+ "pipeline_spec": pipeline_job["pipelineSpec"],
+ "labels": labels,
+ "runtime_config": runtime_config,
+ "encryption_spec": initializer.global_config.get_encryption_spec(
+ encryption_spec_key_name=encryption_spec_key_name
+ ),
+ }
+
+ if _VALID_AR_URL.match(template_path):
+ pipeline_job_args["template_uri"] = template_path
+
+ self._gca_resource = gca_pipeline_job.PipelineJob(**pipeline_job_args)
+
+ @base.optional_sync()
+ def run(
+ self,
+ service_account: Optional[str] = None,
+ network: Optional[str] = None,
+ sync: Optional[bool] = True,
+ create_request_timeout: Optional[float] = None,
+ ) -> None:
+ """Run this configured PipelineJob and monitor the job until completion.
+
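+ Example Usage (illustrative; the template path, pipeline root and
+ service account below are placeholders, not values from this change):
+
+ job = aiplatform.PipelineJob(
+ display_name='my-pipeline',
+ template_path='compiled_pipeline.json',
+ pipeline_root='gs://my-bucket/pipeline-root'
+ )
+ job.run(service_account='my-sa@my-project.iam.gserviceaccount.com')
+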
+ Args:
+ service_account (str):
+ Optional. Specifies the service account for workload run-as account.
+ Users submitting jobs must have act-as permission on this run-as account.
+ network (str):
+ Optional. The full name of the Compute Engine network to which the job
+ should be peered. For example, projects/12345/global/networks/myVPC.
+
+ Private services access must already be configured for the network.
+ If left unspecified, the job is not peered with any network.
+ sync (bool):
+ Optional. Whether to execute this method synchronously. If False,
+ this method will unblock and the run will be executed in a concurrent Future.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+ """
+ self.submit(
+ service_account=service_account,
+ network=network,
+ create_request_timeout=create_request_timeout,
+ )
+
+ self._block_until_complete()
+
+ def submit(
+ self,
+ service_account: Optional[str] = None,
+ network: Optional[str] = None,
+ create_request_timeout: Optional[float] = None,
+ *,
+ experiment: Optional[Union[str, experiment_resources.Experiment]] = None,
+ ) -> None:
+ """Run this configured PipelineJob.
+
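+ Example Usage (illustrative; assumes `job` is a configured PipelineJob
+ and that the named Vertex AI experiment already exists):
+
+ job.submit(experiment='my-experiment')
+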
+ Args:
+ service_account (str):
+ Optional. Specifies the service account for workload run-as account.
+ Users submitting jobs must have act-as permission on this run-as account.
+ network (str):
+ Optional. The full name of the Compute Engine network to which the job
+ should be peered. For example, projects/12345/global/networks/myVPC.
+
+ Private services access must already be configured for the network.
+ If left unspecified, the job is not peered with any network.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+ experiment (Union[str, experiments_resource.Experiment]):
+ Optional. The Vertex AI experiment name or instance to associate to this PipelineJob.
+
+ Metrics produced by the PipelineJob as system.Metric Artifacts
+ will be associated as metrics to the current Experiment Run.
+
+ Pipeline parameters will be associated as parameters to the
+ current Experiment Run.
+ """
+ if service_account:
+ self._gca_resource.service_account = service_account
+
+ if network:
+ self._gca_resource.network = network
+
+ # Prevents logs from being suppressed on TFX pipelines
+ if self._gca_resource.pipeline_spec.get("sdkVersion", "").startswith("tfx"):
+ _LOGGER.setLevel(logging.INFO)
+
+ if experiment:
+ self._validate_experiment(experiment)
+
+ _LOGGER.log_create_with_lro(self.__class__)
+
+ self._gca_resource = self.api_client.create_pipeline_job(
+ parent=self._parent,
+ pipeline_job=self._gca_resource,
+ pipeline_job_id=self.job_id,
+ timeout=create_request_timeout,
+ )
+
+ _LOGGER.log_create_complete_with_getter(
+ self.__class__, self._gca_resource, "pipeline_job"
+ )
+
+ _LOGGER.info("View Pipeline Job:\n%s" % self._dashboard_uri())
+
+ if experiment:
+ self._associate_to_experiment(experiment)
+
+ def wait(self):
+ """Wait for this PipelineJob to complete."""
+ if self._latest_future is None:
+ self._block_until_complete()
+ else:
+ super().wait()
+
+ @property
+ def pipeline_spec(self):
+ return self._gca_resource.pipeline_spec
+
+ @property
+ def state(self) -> Optional[gca_pipeline_state.PipelineState]:
+ """Current pipeline state."""
+ self._sync_gca_resource()
+ return self._gca_resource.state
+
+ @property
+ def task_details(self) -> List[gca_pipeline_job.PipelineTaskDetail]:
+ self._sync_gca_resource()
+ return list(self._gca_resource.job_detail.task_details)
+
+ @property
+ def has_failed(self) -> bool:
+ """Returns True if pipeline has failed.
+
+ False otherwise.
+ """
+ return self.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED
+
+ def _dashboard_uri(self) -> str:
+ """Helper method to compose the dashboard uri where pipeline can be
+ viewed."""
+ fields = self._parse_resource_name(self.resource_name)
+ url = f"https://console.cloud.google.com/vertex-ai/locations/{fields['location']}/pipelines/runs/{fields['pipeline_job']}?project={fields['project']}"
+ return url
+
+ def _block_until_complete(self):
+ """Helper method to block and check on job until complete."""
+ # Used these numbers so failures surface fast
+ wait = 5 # start at five seconds
+ log_wait = 5
+ max_wait = 60 * 5 # 5 minute wait
+ multiplier = 2 # scale wait by 2 every iteration
+
+ previous_time = time.time()
+ while self.state not in _PIPELINE_COMPLETE_STATES:
+ current_time = time.time()
+ if current_time - previous_time >= log_wait:
+ _LOGGER.info(
+ "%s %s current state:\n%s"
+ % (
+ self.__class__.__name__,
+ self._gca_resource.name,
+ self._gca_resource.state,
+ )
+ )
+ log_wait = min(log_wait * multiplier, max_wait)
+ previous_time = current_time
+ time.sleep(wait)
+
+ # Error is only populated when the pipeline state is
+ # PIPELINE_STATE_FAILED or PIPELINE_STATE_CANCELLED.
+ if self._gca_resource.state in _PIPELINE_ERROR_STATES:
+ raise RuntimeError("Job failed with:\n%s" % self._gca_resource.error)
+ else:
+ _LOGGER.log_action_completed_against_resource("run", "completed", self)
+
+ @classmethod
+ def get(
+ cls,
+ resource_name: str,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> "PipelineJob":
+ """Get a Vertex AI Pipeline Job for the given resource_name.
+
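+ Example Usage (illustrative; the resource name below is a placeholder):
+
+ job = aiplatform.PipelineJob.get(
+ resource_name='projects/123/locations/us-central1/pipelineJobs/456'
+ )
+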
+ Args:
+ resource_name (str):
+ Required. A fully-qualified resource name or ID.
+ project (str):
+ Optional. Project to retrieve the PipelineJob from. If not set,
+ project set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve the PipelineJob from. If not set,
+ location set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve this PipelineJob.
+ Overrides credentials set in aiplatform.init.
+
+ Returns:
+ A Vertex AI PipelineJob.
+ """
+ self = cls._empty_constructor(
+ project=project,
+ location=location,
+ credentials=credentials,
+ resource_name=resource_name,
+ )
+
+ self._gca_resource = self._get_gca_resource(resource_name=resource_name)
+
+ return self
+
+ def cancel(self) -> None:
+ """Starts asynchronous cancellation on the PipelineJob. The server
+ makes a best effort to cancel the job, but success is not guaranteed.
+ On successful cancellation, the PipelineJob is not deleted; instead it
+ becomes a job with state set to `CANCELLED`.
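+
+ Example Usage (illustrative; assumes `job` is a running PipelineJob):
+
+ job.cancel()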
+ """
+ self.api_client.cancel_pipeline_job(name=self.resource_name)
+
+ @classmethod
+ def list(
+ cls,
+ filter: Optional[str] = None,
+ order_by: Optional[str] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> List["PipelineJob"]:
+ """List all instances of this PipelineJob resource.
+
+ Example Usage:
+
+ aiplatform.PipelineJob.list(
+ filter='display_name="experiment_a27"',
+ order_by='create_time desc'
+ )
+
+ Args:
+ filter (str):
+ Optional. An expression for filtering the results of the request.
+ For field names both snake_case and camelCase are supported.
+ order_by (str):
+ Optional. A comma-separated list of fields to order by, sorted in
+ ascending order. Use "desc" after a field name for descending.
+ Supported fields: `display_name`, `create_time`, `update_time`
+ project (str):
+ Optional. Project to retrieve list from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve list from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve list. Overrides
+ credentials set in aiplatform.init.
+
+ Returns:
+ List[PipelineJob] - A list of PipelineJob resource objects
+ """
+
+ return cls._list_with_local_order(
+ filter=filter,
+ order_by=order_by,
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+
+ def wait_for_resource_creation(self) -> None:
+ """Waits until resource has been created."""
+ self._wait_for_resource_creation()
+
+ def done(self) -> bool:
+ """Helper method that return True is PipelineJob is done. False otherwise."""
+ if not self._gca_resource:
+ return False
+
+ return self.state in _PIPELINE_COMPLETE_STATES
+
+ def _has_failed(self) -> bool:
+ """Return True if PipelineJob has Failed."""
+ if not self._gca_resource:
+ return False
+
+ return self.state in _PIPELINE_ERROR_STATES
+
+ def _get_context(self) -> context._Context:
+ """Returns the PipelineRun Context for this PipelineJob in the MetadataStore.
+
+ Returns:
+ system.PipelineRun Context instance that represents this PipelineJob.
+
+ Raises:
+ RuntimeError: If the pipeline has failed or the system.PipelineRun context is not found.
+ """
+ self.wait_for_resource_creation()
+ pipeline_run_context = self._gca_resource.job_detail.pipeline_run_context
+
+ # PipelineJob context is created asynchronously so we need to poll until it exists.
+ while not self.done():
+ pipeline_run_context = self._gca_resource.job_detail.pipeline_run_context
+ if pipeline_run_context:
+ break
+ time.sleep(1)
+
+ if not pipeline_run_context:
+ if self._has_failed():
+ raise RuntimeError(
+ f"Cannot associate PipelineJob to Experiment: {self.gca_resource.error}"
+ )
+ else:
+ raise RuntimeError(
+ "Cannot associate PipelineJob to Experiment because PipelineJob context could not be found."
+ )
+
+ return context._Context(
+ resource=pipeline_run_context,
+ project=self.project,
+ location=self.location,
+ credentials=self.credentials,
+ )
+
+ @classmethod
+ def _query_experiment_row(
+ cls, node: context._Context
+ ) -> experiment_resources._ExperimentRow:
+ """Queries the PipelineJob metadata as an experiment run parameter and metric row.
+
+ Parameters are retrieved from the system.Run Execution.metadata of the PipelineJob.
+
+ Metrics are retrieved from the system.Metric Artifacts.metadata produced by this PipelineJob.
+
+ Args:
+ node (context._Context):
+ Required. system.PipelineRun context that represents a PipelineJob run.
+ Returns:
+ Experiment run row representing this PipelineJob.
+ """
+
+ system_run_executions = execution.Execution.list(
+ project=node.project,
+ location=node.location,
+ credentials=node.credentials,
+ filter=metadata_utils._make_filter_string(
+ in_context=[node.resource_name],
+ schema_title=metadata_constants.SYSTEM_RUN,
+ ),
+ )
+
+ metric_artifacts = artifact.Artifact.list(
+ project=node.project,
+ location=node.location,
+ credentials=node.credentials,
+ filter=metadata_utils._make_filter_string(
+ in_context=[node.resource_name],
+ schema_title=metadata_constants.SYSTEM_METRICS,
+ ),
+ )
+
+ row = experiment_resources._ExperimentRow(
+ experiment_run_type=node.schema_title, name=node.display_name
+ )
+
+ if system_run_executions:
+ row.params = {
+ key[len(metadata_constants.PIPELINE_PARAM_PREFIX) :]: value
+ for key, value in system_run_executions[0].metadata.items()
+ }
+ row.state = system_run_executions[0].state.name
+
+ for metric_artifact in metric_artifacts:
+ if row.metrics:
+ row.metrics.update(metric_artifact.metadata)
+ else:
+ row.metrics = metric_artifact.metadata
+
+ return row
+
+ def clone(
+ self,
+ display_name: Optional[str] = None,
+ job_id: Optional[str] = None,
+ pipeline_root: Optional[str] = None,
+ parameter_values: Optional[Dict[str, Any]] = None,
+ enable_caching: Optional[bool] = None,
+ encryption_spec_key_name: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ ) -> "PipelineJob":
+ """Returns a new PipelineJob object with the same settings as the original one.
+
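+ Example Usage (illustrative; assumes `job` is an existing PipelineJob and
+ that 'learning_rate' is a parameter defined in its pipeline spec):
+
+ cloned_job = job.clone(
+ display_name='my-cloned-pipeline',
+ parameter_values={'learning_rate': 0.1}
+ )
+ cloned_job.run()
+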
+ Args:
+ display_name (str):
+ Optional. The user-defined name of this cloned Pipeline.
+ If not specified, the original pipeline's display name will be used.
+ job_id (str):
+ Optional. The unique ID of the job run.
+ If not specified, "cloned" + pipeline name + timestamp will be used.
+ pipeline_root (str):
+ Optional. The root of the pipeline outputs. Defaults to the same
+ staging bucket as the original pipeline.
+ parameter_values (Dict[str, Any]):
+ Optional. The mapping from runtime parameter names to their values that
+ control the pipeline run. Defaults to the same values as the original
+ PipelineJob.
+ enable_caching (bool):
+ Optional. Whether to turn on caching for the run.
+ If this is not set, defaults to the same as the original pipeline.
+ If this is set, the setting applies to all tasks in the pipeline.
+ encryption_spec_key_name (str):
+ Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the job. Has the
+ form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+ The key needs to be in the same region as where the compute resource is created.
+ If this is set, then all resources created by the PipelineJob
+ will be encrypted with the provided encryption key.
+ If not specified, the encryption_spec of the original PipelineJob will be used.
+ labels (Dict[str, str]):
+ Optional. The user-defined metadata to organize the PipelineJob.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to create this PipelineJob.
+ Overrides credentials set in aiplatform.init.
+ project (str):
+ Optional. The project that you want to run this PipelineJob in.
+ If not set, the project set in the original PipelineJob will be used.
+ location (str):
+ Optional. Location to create PipelineJob.
+ If not set, the location set in the original PipelineJob will be used.
+
+ Returns:
+ A Vertex AI PipelineJob.
+
+ Raises:
+ ValueError: If job_id or labels have incorrect format.
+ """
+ ## Initialize an empty PipelineJob
+ if not project:
+ project = self.project
+ if not location:
+ location = self.location
+ if not credentials:
+ credentials = self.credentials
+
+ cloned = self.__class__._empty_constructor(
+ project=project,
+ location=location,
+ credentials=credentials,
+ )
+ cloned._parent = initializer.global_config.common_location_path(
+ project=project, location=location
+ )
+
+ ## Get gca_resource from original PipelineJob
+ pipeline_job = json_format.MessageToDict(self._gca_resource._pb)
+
+ ## Set pipeline_spec
+ pipeline_spec = pipeline_job["pipelineSpec"]
+ if "deploymentConfig" in pipeline_spec:
+ del pipeline_spec["deploymentConfig"]
+
+ ## Set caching
+ if enable_caching is not None:
+ _set_enable_caching_value(pipeline_spec, enable_caching)
+
+ ## Set job_id
+ pipeline_name = pipeline_spec["pipelineInfo"]["name"]
+ cloned.job_id = job_id or "cloned-{pipeline_name}-{timestamp}".format(
+ pipeline_name=re.sub("[^-0-9a-z]+", "-", pipeline_name.lower())
+ .lstrip("-")
+ .rstrip("-"),
+ timestamp=_get_current_time().strftime("%Y%m%d%H%M%S"),
+ )
+ if not _VALID_NAME_PATTERN.match(cloned.job_id):
+ raise ValueError(
+ f"Generated job ID: {cloned.job_id} is illegal as a Vertex pipelines job ID. "
+ "Expecting an ID following the regex pattern "
+ f'"{_VALID_NAME_PATTERN.pattern[1:-1]}"'
+ )
+
+ ## Set display_name, labels and encryption_spec
+ if display_name:
+ utils.validate_display_name(display_name)
+ elif not display_name and "displayName" in pipeline_job:
+ display_name = pipeline_job["displayName"]
+
+ if labels:
+ utils.validate_labels(labels)
+ elif not labels and "labels" in pipeline_job:
+ labels = pipeline_job["labels"]
+
+ if encryption_spec_key_name or "encryptionSpec" not in pipeline_job:
+ encryption_spec = initializer.global_config.get_encryption_spec(
+ encryption_spec_key_name=encryption_spec_key_name
+ )
+ else:
+ encryption_spec = pipeline_job["encryptionSpec"]
+
+ ## Set runtime_config
+ builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
+ pipeline_job
+ )
+ builder.update_pipeline_root(pipeline_root)
+ builder.update_runtime_parameters(parameter_values)
+ runtime_config_dict = builder.build()
+ runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
+ json_format.ParseDict(runtime_config_dict, runtime_config)
+
+ ## Create gca_resource for cloned PipelineJob
+ cloned._gca_resource = gca_pipeline_job.PipelineJob(
+ display_name=display_name,
+ pipeline_spec=pipeline_spec,
+ labels=labels,
+ runtime_config=runtime_config,
+ encryption_spec=encryption_spec,
+ )
+
+ return cloned
diff --git a/google/cloud/aiplatform/schema.py b/google/cloud/aiplatform/schema.py
index a1da75d9e6..96a7a50bbd 100644
--- a/google/cloud/aiplatform/schema.py
+++ b/google/cloud/aiplatform/schema.py
@@ -23,6 +23,7 @@ class definition:
custom_task = "gs://google-cloud-aiplatform/schema/trainingjob/definition/custom_task_1.0.0.yaml"
automl_tabular = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml"
automl_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_time_series_forecasting_1.0.0.yaml"
+ seq2seq_plus_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/seq2seq_plus_time_series_forecasting_1.0.0.yaml"
automl_image_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml"
automl_image_object_detection = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml"
automl_text_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml"
diff --git a/google/cloud/aiplatform/tensorboard/__init__.py b/google/cloud/aiplatform/tensorboard/__init__.py
index a6fbe4122f..58eb7c3640 100644
--- a/google/cloud/aiplatform/tensorboard/__init__.py
+++ b/google/cloud/aiplatform/tensorboard/__init__.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2021 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,3 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+
+from google.cloud.aiplatform.tensorboard.tensorboard_resource import (
+ Tensorboard,
+ TensorboardExperiment,
+ TensorboardRun,
+ TensorboardTimeSeries,
+)
+
+
+__all__ = (
+ "Tensorboard",
+ "TensorboardExperiment",
+ "TensorboardRun",
+ "TensorboardTimeSeries",
+)
diff --git a/google/cloud/aiplatform/tensorboard/plugins/tf_profiler/profile_uploader.py b/google/cloud/aiplatform/tensorboard/plugins/tf_profiler/profile_uploader.py
new file mode 100644
index 0000000000..5fd3f58e4d
--- /dev/null
+++ b/google/cloud/aiplatform/tensorboard/plugins/tf_profiler/profile_uploader.py
@@ -0,0 +1,596 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Upload profile sessions to Vertex AI Tensorboard."""
+from collections import defaultdict
+import datetime
+import functools
+import os
+import re
+from typing import (
+ DefaultDict,
+ Dict,
+ Generator,
+ List,
+ Optional,
+ Set,
+ Tuple,
+)
+
+import grpc
+from tensorboard.uploader import upload_tracker
+from tensorboard.uploader import util
+from tensorboard.uploader.proto import server_info_pb2
+from tensorboard.util import tb_logging
+import tensorflow as tf
+
+from google.cloud import storage
+from google.cloud.aiplatform.compat.services import tensorboard_service_client
+from google.cloud.aiplatform.compat.types import tensorboard_data
+from google.cloud.aiplatform.compat.types import tensorboard_service
+from google.cloud.aiplatform.compat.types import tensorboard_time_series
+from google.cloud.aiplatform.tensorboard import uploader_utils
+from google.protobuf import timestamp_pb2 as timestamp
+
+TensorboardServiceClient = tensorboard_service_client.TensorboardServiceClient
+
+logger = tb_logging.get_logger()
+
+
+class ProfileRequestSender(uploader_utils.RequestSender):
+ """Helper class for building requests for the profiler plugin.
+
+ While the profile plugin does create event files when a profile run is performed
+ for a new training run, these event files do not contain any values
+ like other events do. Instead, the plugin will create subdirectories and profiling
+ files within these subdirectories.
+
+ To verify a profile event, the subdirectories need to be searched for
+ valid profile directories and files.
+
+ This class is not threadsafe. Use external synchronization if
+ calling its methods concurrently.
+ """
+
+ PLUGIN_NAME = "profile"
+ PROFILE_PATH = "plugins/profile"
+
+ def __init__(
+ self,
+ experiment_resource_name: str,
+ api: TensorboardServiceClient,
+ upload_limits: server_info_pb2.UploadLimits,
+ blob_rpc_rate_limiter: util.RateLimiter,
+ blob_storage_bucket: storage.Bucket,
+ blob_storage_folder: str,
+ tracker: upload_tracker.UploadTracker,
+ logdir: str,
+ source_bucket: Optional[storage.Bucket],
+ ):
+ """Constructs ProfileRequestSender for the given experiment resource.
+
+ Args:
+ experiment_resource_name (str):
+ Required. Name of the experiment resource of the form:
+ projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}
+ api (TensorboardServiceClient):
+ Required. Tensorboard service stub used to interact with experiment resource.
+ upload_limits (server_info_pb2.UploadLimits):
+ Required. Upload limits for API calls.
+ blob_rpc_rate_limiter (util.RateLimiter):
+ Required. A `RateLimiter` to use to limit write RPC frequency.
+ Note this limit applies at the level of single RPCs in the Scalar and
+ Tensor case, but at the level of an entire blob upload in the Blob
+ case-- which may require a few preparatory RPCs and a stream of chunks.
+ Note the chunk stream is internally rate-limited by backpressure from
+ the server, so it is not a concern that we do not explicitly rate-limit
+ within the stream here.
+ blob_storage_bucket (storage.Bucket):
+ Required. A `storage.Bucket` to send all blob files to.
+ blob_storage_folder (str):
+ Required. Name of the folder to save blob files to within the blob_storage_bucket.
+ tracker (upload_tracker.UploadTracker):
+ Required. Upload tracker to track information about uploads.
+ logdir (str):
+ Required. The log directory for the request sender to search.
+ source_bucket (Optional[storage.Bucket]):
+ Optional. The user's specified `storage.Bucket` to save events to. If a user is uploading from
+ a local directory, this can be None.
+ """
+ self._experiment_resource_name = experiment_resource_name
+ self._api = api
+ self._logdir = logdir
+ self._tag_metadata = {}
+ self._tracker = tracker
+ self._one_platform_resource_manager = uploader_utils.OnePlatformResourceManager(
+ experiment_resource_name=experiment_resource_name, api=api
+ )
+
+ self._run_to_file_request_sender: Dict[str, _FileRequestSender] = {}
+ self._run_to_profile_loaders: Dict[str, _ProfileSessionLoader] = {}
+
+ self._file_request_sender_factory = functools.partial(
+ _FileRequestSender,
+ api=api,
+ rpc_rate_limiter=blob_rpc_rate_limiter,
+ max_blob_request_size=upload_limits.max_blob_request_size,
+ max_blob_size=upload_limits.max_blob_size,
+ blob_storage_bucket=blob_storage_bucket,
+ source_bucket=source_bucket,
+ blob_storage_folder=blob_storage_folder,
+ tracker=self._tracker,
+ )
+
+ def _is_valid_event(self, run_name: str) -> bool:
+ """Determines whether a valid profile session has occurred.
+
+ Profile events are determined by whether a corresponding directory has
+ been created for the profile plugin.
+
+ Args:
+ run_name (str):
+ Required. String representing the run name.
+
+ Returns:
+ True if is a valid profile plugin event, False otherwise.
+ """
+
+ return tf.io.gfile.isdir(self._profile_dir(run_name))
+
+ def _profile_dir(self, run_name: str) -> str:
+ """Converts run name to full profile path.
+
+ Args:
+ run_name (str):
+ Required. Name of training run.
+
+ Returns:
+ Full path for run name.
+ """
+ return os.path.join(self._logdir, run_name, self.PROFILE_PATH)
+
+ def send_request(self, run_name: str):
+ """Accepts run_name and sends an RPC request if an event is detected.
+
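+ Example Usage (illustrative; assumes `sender` is a fully configured
+ ProfileRequestSender and 'train' is a run directory under its logdir):
+
+ sender.send_request(run_name='train')
+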
+ Args:
+ run_name (str):
+ Required. Name of the training run.
+ """
+
+ if not self._is_valid_event(run_name):
+ logger.warning("No such profile run for %s", run_name)
+ return
+
+ # Create a profile loader if one has not been created yet.
+ # This will store any new runs that occur within the training.
+ if run_name not in self._run_to_profile_loaders:
+ self._run_to_profile_loaders[run_name] = _ProfileSessionLoader(
+ self._profile_dir(run_name)
+ )
+
+ tb_run = self._one_platform_resource_manager.get_run_resource_name(run_name)
+
+ if run_name not in self._run_to_file_request_sender:
+ self._run_to_file_request_sender[
+ run_name
+ ] = self._file_request_sender_factory(tb_run)
+
+ # Loop through any of the profiling sessions within this training run.
+ # A training run can have multiple profile sessions.
+ for prof_session, files in self._run_to_profile_loaders[
+ run_name
+ ].prof_sessions_to_files():
+ event_time = datetime.datetime.strptime(prof_session, "%Y_%m_%d_%H_%M_%S")
+ # Timestamp.FromDatetime mutates the message in place and returns
+ # None, so construct the message before converting.
+ event_timestamp = timestamp.Timestamp()
+ event_timestamp.FromDatetime(event_time)
+
+ # Files are flushed implicitly once they are added.
+ self._run_to_file_request_sender[run_name].add_files(
+ files=files,
+ tag=prof_session,
+ plugin=self.PLUGIN_NAME,
+ event_timestamp=event_timestamp,
+ )
+
+
+class _ProfileSessionLoader(object):
+ """Loader for a profile session within a training run.
+
+ The term 'session' refers to an instance of a profile, where
+ one may have multiple profile sessions under a training run.
+ """
+
+ # A regular expression for the naming of a profiling path.
+ PROF_PATH_REGEX = r".*\/plugins\/profile\/[0-9]{4}_[0-9]{2}_[0-9]{2}_[0-9]{2}_[0-9]{2}_[0-9]{2}\/?$"
+
+ def __init__(
+ self,
+ path: str,
+ ):
+ """Create a loader for profiling sessions with a training run.
+
+ Args:
+ path (str):
+ Required. Path to the training run, which contains one or more profiling
+ sessions. Path should end with 'plugins/profile'.
+ """
+ self._path = path
+ self._prof_session_to_files: DefaultDict[str, Set[str]] = defaultdict(set)
+
+ def _path_filter(self, path: str) -> bool:
+ """Determine which paths we should upload.
+
+ Paths written by profiler should be of form:
+ /some/path/to/dir/plugins/profile/%Y_%m_%d_%H_%M_%S
+
+ Args:
+ path (str):
+ Required. String representing a full directory path.
+
+ Returns:
+ True if valid path and path matches the filter, False otherwise.
+ """
+ return tf.io.gfile.isdir(path) and bool(re.match(self.PROF_PATH_REGEX, path))
+
+ def _path_to_files(self, prof_session: str, path: str) -> List[str]:
+ """Generates files that have not yet been tracked.
+
+ Files are generated by the profiler and are added to an internal
+ dictionary. For files that have not yet been uploaded, we return these
+ files.
+
+ Args:
+ prof_session (str):
+ Required. The profiling session name.
+ path (str):
+ Required. Directory of the profiling session.
+
+ Returns:
+ files (List[str]):
+ Files that have not been tracked yet.
+ """
+
+ files = []
+ for prof_file in tf.io.gfile.listdir(path):
+ full_file_path = os.path.join(path, prof_file)
+ if full_file_path not in self._prof_session_to_files[prof_session]:
+ files.append(full_file_path)
+
+ self._prof_session_to_files[prof_session].update(files)
+ return files
+
+ def prof_sessions_to_files(self) -> Generator[Tuple[str, List[str]], None, None]:
+ """Map files to a profile session.
+
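+ Example Usage (illustrative; assumes `loader` is a _ProfileSessionLoader
+ pointed at a 'plugins/profile' directory, and `handle_files` is a
+ hypothetical consumer):
+
+ for prof_session, files in loader.prof_sessions_to_files():
+ handle_files(prof_session, files)
+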
+ Yields:
+ A tuple containing the profiling session name and a list of files
+ that have not yet been tracked.
+ """
+
+ prof_sessions = tf.io.gfile.listdir(self._path)
+
+ for prof_session in prof_sessions:
+ # Remove trailing slashes in path names
+ prof_session = (
+ prof_session if not prof_session.endswith("/") else prof_session[:-1]
+ )
+
+ full_path = os.path.join(self._path, prof_session)
+ if not self._path_filter(full_path):
+ continue
+
+ files = self._path_to_files(prof_session, full_path)
+
+ if files:
+ yield (prof_session, files)
+
+
+class _FileRequestSender(object):
+ """Uploader for file based items.
+
+ This sender is closely related to the `_BlobRequestSender`; however, it expects
+ file paths instead of blob files, so data is not read in directly and files are
+ instead copied between buckets. Additionally, this sender does not take event files
+ as the other request sender objects do. The sender takes files from either local
+ storage or a GCS bucket and uploads them to the tensorboard bucket.
+
+ This class is not threadsafe. Use external synchronization if calling its
+ methods concurrently.
+ """
+
+ def __init__(
+ self,
+ run_resource_id: str,
+ api: TensorboardServiceClient,
+ rpc_rate_limiter: util.RateLimiter,
+ max_blob_request_size: int,
+ max_blob_size: int,
+ blob_storage_bucket: storage.Bucket,
+ blob_storage_folder: str,
+ tracker: upload_tracker.UploadTracker,
+ source_bucket: Optional[storage.Bucket] = None,
+ ):
+ """Creates a _FileRequestSender object.
+
+ Args:
+ run_resource_id (str):
+ Required. Name of the run resource of the form:
+ projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}
+ api (TensorboardServiceClient):
+ Required. TensorboardServiceStub for calling various tensorboard services.
+ rpc_rate_limiter (util.RateLimiter):
+ Required. A `RateLimiter` to use to limit write RPC frequency.
+ Note this limit applies at the level of single RPCs in the Scalar and
+ Tensor case, but at the level of an entire blob upload in the Blob
+ case-- which may require a few preparatory RPCs and a stream of chunks.
+ Note the chunk stream is internally rate-limited by backpressure from
+ the server, so it is not a concern that we do not explicitly rate-limit
+ within the stream here.
+ max_blob_request_size (int):
+ Required. Maximum request size to send.
+ max_blob_size (int):
+ Required. Maximum size in bytes of the blobs to send.
+ blob_storage_bucket (storage.Bucket):
+ Required. Bucket to send event files to.
+ blob_storage_folder (str):
+ Required. The folder to save blob files to.
+ tracker (upload_tracker.UploadTracker):
+ Required. Track any uploads to backend.
+ source_bucket (storage.Bucket):
+ Optional. The source bucket to upload from. If not set, use local filesystem instead.
+ """
+ self._run_resource_id = run_resource_id
+ self._api = api
+ self._rpc_rate_limiter = rpc_rate_limiter
+ self._max_blob_request_size = max_blob_request_size
+ self._max_blob_size = max_blob_size
+ self._tracker = tracker
+ self._time_series_resource_manager = uploader_utils.TimeSeriesResourceManager(
+ run_resource_id, api
+ )
+
+ self._bucket = blob_storage_bucket
+ self._folder = blob_storage_folder
+ self._source_bucket = source_bucket
+
+ self._new_request()
+
+ def _new_request(self):
+ """Declares the previous event complete."""
+ self._files = []
+ self._tag = None
+ self._plugin = None
+ self._event_timestamp = None
+
+ def add_files(
+ self,
+ files: List[str],
+ tag: str,
+ plugin: str,
+ event_timestamp: timestamp.Timestamp,
+ ):
+ """Attempts to add the given file to the current request.
+
+ If a file does not exist, it is ignored and the remaining files are
+ checked for existence. After checking the files, an RPC is
+ immediately sent.
+
+ Files are flushed immediately, as opposed to some of the other request senders.
+
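+ Example Usage (illustrative; the file path, tag and timestamp below are
+ placeholders):
+
+ sender.add_files(
+ files=['gs://my-bucket/logdir/train/plugins/profile/2022_01_01_00_00_00/local.trace'],
+ tag='2022_01_01_00_00_00',
+ plugin='profile',
+ event_timestamp=timestamp.Timestamp()
+ )
+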
+ Args:
+ files (List[str]):
+ Required. The paths of the files to upload.
+ tag (str):
+ Required. A unique identifier for the blob sequence.
+ plugin (str):
+ Required. Name of the plugin making the request.
+ event_timestamp (timestamp.Timestamp):
+ Required. The time the event is created.
+ """
+
+ for prof_file in files:
+ if not tf.io.gfile.exists(prof_file):
+ logger.warning(
+ "The file provided does not exist. "
+ "Will not be uploading file %s.",
+ prof_file,
+ )
+ else:
+ self._files.append(prof_file)
+
+ self._tag = tag
+ self._plugin = plugin
+ self._event_timestamp = event_timestamp
+ self.flush()
+ self._new_request()
+
+ def flush(self):
+ """Sends the current file fully, and clears it to make way for the next."""
+ if not self._files:
+ return
+
+ time_series_proto = self._time_series_resource_manager.get_or_create(
+ self._tag,
+ lambda: tensorboard_time_series.TensorboardTimeSeries(
+ display_name=self._tag,
+ value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.BLOB_SEQUENCE,
+ plugin_name=self._plugin,
+ ),
+ )
+ m = re.match(
+ ".*/tensorboards/(.*)/experiments/(.*)/runs/(.*)/timeSeries/(.*)",
+ time_series_proto.name,
+ )
+ blob_path_prefix = "tensorboard-{}/{}/{}/{}".format(m[1], m[2], m[3], m[4])
+ blob_path_prefix = (
+ "{}/{}".format(self._folder, blob_path_prefix)
+ if self._folder
+ else blob_path_prefix
+ )
+ sent_blob_ids = []
+
+ for prof_file in self._files:
+ self._rpc_rate_limiter.tick()
+ file_size = tf.io.gfile.stat(prof_file).length
+ with self._tracker.blob_tracker(file_size) as blob_tracker:
+ if not self._file_too_large(prof_file):
+ blob_id = self._upload(prof_file, blob_path_prefix)
+ sent_blob_ids.append(str(blob_id))
+ blob_tracker.mark_uploaded(blob_id is not None)
+
+ data_point = tensorboard_data.TimeSeriesDataPoint(
+ blobs=tensorboard_data.TensorboardBlobSequence(
+ values=[
+ tensorboard_data.TensorboardBlob(id=blob_id)
+ for blob_id in sent_blob_ids
+ ]
+ ),
+ wall_time=self._event_timestamp,
+ )
+
+ time_series_data_proto = tensorboard_data.TimeSeriesData(
+ tensorboard_time_series_id=time_series_proto.name.split("/")[-1],
+ value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.BLOB_SEQUENCE,
+ values=[data_point],
+ )
+ request = tensorboard_service.WriteTensorboardRunDataRequest(
+ time_series_data=[time_series_data_proto]
+ )
+
+ _prune_empty_time_series_from_blob(request)
+ if not request.time_series_data:
+ return
+
+ with uploader_utils.request_logger(request):
+ try:
+ self._api.write_tensorboard_run_data(
+ tensorboard_run=self._run_resource_id,
+ time_series_data=request.time_series_data,
+ )
+ except grpc.RpcError as e:
+ logger.error("Upload call failed with error %s", e)
+
+ def _file_too_large(self, filename: str) -> bool:
+ """Determines if a file is too large to upload.
+
+ Args:
+ filename (str):
+ Required. The filename to check.
+
+ Returns:
+ True if too large, False otherwise.
+ """
+
+ file_size = tf.io.gfile.stat(filename).length
+ if file_size > self._max_blob_size:
+ logger.warning(
+ "Blob too large; skipping. Size %d exceeds limit of %d bytes.",
+ file_size,
+ self._max_blob_size,
+ )
+ return True
+ return False
+
+ def _upload(self, filename: str, blob_path_prefix: Optional[str] = None) -> str:
+ """Copies files between either a local directory or a bucket and the tenant bucket.
+
+ Args:
+ filename (str):
+ Required. The full path of the file to upload.
+ blob_path_prefix (str):
+ Optional. Path prefix for the location to store the file.
+
+ Returns:
+ blob_id (str):
+ The base path of the file.
+ """
+ blob_id = os.path.basename(filename)
+ blob_path = (
+ "{}/{}".format(blob_path_prefix, blob_id) if blob_path_prefix else blob_id
+ )
+
+ # A source bucket indicates the files are stored on Cloud Storage.
+ if self._source_bucket:
+ self._copy_between_buckets(filename, blob_path)
+ else:
+ self._upload_from_local(filename, blob_path)
+
+ return blob_id
+
+ def _copy_between_buckets(self, filename: str, blob_path: str):
+ """Move files between the user's bucket and the tenant bucket.
+
+ Args:
+ filename (str):
+ Required. Full path of the file to upload.
+ blob_path (str):
+ Required. A bucket path to upload the file to.
+
+ """
+ blob_name = _get_blob_from_file(filename)
+
+ source_blob = self._source_bucket.blob(blob_name)
+
+ self._source_bucket.copy_blob(
+ source_blob,
+ self._bucket,
+ blob_path,
+ )
+
+ def _upload_from_local(self, filename: str, blob_path: str):
+ """Uploads a local file to the tenant bucket.
+
+ Args:
+ filename (str):
+ Required. Full path of the file to upload.
+ blob_path (str):
+ Required. A bucket path to upload the file to.
+ """
+ blob = self._bucket.blob(blob_path)
+ blob.upload_from_filename(filename)
+
+
+def _get_blob_from_file(fp: str) -> Optional[str]:
+ """Gets blob name from a storage bucket.
+
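+ Example Usage (illustrative; the path below is a placeholder):
+
+ _get_blob_from_file('gs://my-bucket/logs/run1/trace.json')
+ # returns 'logs/run1/trace.json'
+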
+ Args:
+ fp (str):
+ Required. A file path.
+
+ Returns:
+ blob_name (str):
+ Optional. Base blob file name if it exists, else None
+ """
+ m = re.match(r"gs:\/\/.*?\/(.*)", fp)
+ if not m:
+ logger.warning("Could not get the blob name from file %s", fp)
+ return None
+ return m[1]
+
+
+def _prune_empty_time_series_from_blob(
+ request: tensorboard_service.WriteTensorboardRunDataRequest,
+):
+ """Removes empty time_series from request if there are no blob files.'
+
+ Args:
+ request (tensorboard_service.WriteTensorboardRunDataRequest):
+ Required. A write request for blob files.
+ """
+ for time_series_idx, time_series_data in reversed(
+ list(enumerate(request.time_series_data))
+ ):
+ if not any(x.blobs for x in time_series_data.values):
+ del request.time_series_data[time_series_idx]
diff --git a/google/cloud/aiplatform/tensorboard/tensorboard_resource.py b/google/cloud/aiplatform/tensorboard/tensorboard_resource.py
new file mode 100644
index 0000000000..76e89ca6bd
--- /dev/null
+++ b/google/cloud/aiplatform/tensorboard/tensorboard_resource.py
@@ -0,0 +1,1289 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+
+from google.auth import credentials as auth_credentials
+from google.protobuf import field_mask_pb2
+from google.protobuf import timestamp_pb2
+
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform import utils
+from google.cloud.aiplatform.compat.types import tensorboard as gca_tensorboard
+from google.cloud.aiplatform.compat.types import (
+ tensorboard_data as gca_tensorboard_data,
+)
+from google.cloud.aiplatform.compat.types import (
+ tensorboard_experiment as gca_tensorboard_experiment,
+)
+from google.cloud.aiplatform.compat.types import tensorboard_run as gca_tensorboard_run
+from google.cloud.aiplatform.compat.types import (
+ tensorboard_service as gca_tensorboard_service,
+)
+from google.cloud.aiplatform.compat.types import (
+ tensorboard_time_series as gca_tensorboard_time_series,
+)
+
+_LOGGER = base.Logger(__name__)
+
+
+class _TensorboardServiceResource(base.VertexAiResourceNounWithFutureManager):
+ client_class = utils.TensorboardClientWithOverride
+
+
+class Tensorboard(_TensorboardServiceResource):
+ """Managed tensorboard resource for Vertex AI."""
+
+ _resource_noun = "tensorboards"
+ _getter_method = "get_tensorboard"
+ _list_method = "list_tensorboards"
+ _delete_method = "delete_tensorboard"
+ _parse_resource_name_method = "parse_tensorboard_path"
+ _format_resource_name_method = "tensorboard_path"
+
+ def __init__(
+ self,
+ tensorboard_name: str,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ):
+ """Retrieves an existing managed tensorboard given a tensorboard name or ID.
+
+ Args:
+ tensorboard_name (str):
+ Required. A fully-qualified tensorboard resource name or tensorboard ID.
+ Example: "projects/123/locations/us-central1/tensorboards/456" or
+ "456" when project and location are initialized or passed.
+ project (str):
+ Optional. Project to retrieve tensorboard from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve tensorboard from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve this Tensorboard. Overrides
+ credentials set in aiplatform.init.
+ """
+
+ super().__init__(
+ project=project,
+ location=location,
+ credentials=credentials,
+ resource_name=tensorboard_name,
+ )
+ self._gca_resource = self._get_gca_resource(resource_name=tensorboard_name)
+
+ @classmethod
+ def create(
+ cls,
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ encryption_spec_key_name: Optional[str] = None,
+ create_request_timeout: Optional[float] = None,
+ ) -> "Tensorboard":
+ """Creates a new tensorboard.
+
+ Example Usage:
+
+ tb = aiplatform.Tensorboard.create(
+ display_name='my display name',
+ description='my description',
+ labels={
+ 'key1': 'value1',
+ 'key2': 'value2'
+ }
+ )
+
+ Args:
+ display_name (str):
+ Optional. The user-defined name of the Tensorboard.
+ The name can be up to 128 characters long and can consist
+ of any UTF-8 characters.
+ description (str):
+ Optional. Description of this Tensorboard.
+ labels (Dict[str, str]):
+ Optional. Labels with user-defined metadata to organize your Tensorboards.
+ Label keys and values can be no longer than 64 characters
+ (Unicode codepoints), can only contain lowercase letters, numeric
+ characters, underscores and dashes. International characters are allowed.
+ No more than 64 user labels can be associated with one Tensorboard
+ (System labels are excluded).
+ See https://goo.gl/xmQnxf for more information and examples of labels.
+ System reserved label keys are prefixed with "aiplatform.googleapis.com/"
+ and are immutable.
+ project (str):
+ Optional. Project to create this tensorboard in. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location to create this tensorboard in. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to create this tensorboard. Overrides
+ credentials set in aiplatform.init.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ encryption_spec_key_name (str):
+ Optional. Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the tensorboard. Has the
+ form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+ The key needs to be in the same region as where the compute
+ resource is created.
+
+ If set, this Tensorboard and all sub-resources of this Tensorboard will be secured by this key.
+
+ Overrides encryption_spec_key_name set in aiplatform.init.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+
+ Returns:
+ tensorboard (Tensorboard):
+ Instantiated representation of the managed tensorboard resource.
+ """
+ if not display_name:
+ display_name = cls._generate_display_name()
+
+ utils.validate_display_name(display_name)
+ if labels:
+ utils.validate_labels(labels)
+
+ api_client = cls._instantiate_client(location=location, credentials=credentials)
+
+ parent = initializer.global_config.common_location_path(
+ project=project, location=location
+ )
+
+ encryption_spec = initializer.global_config.get_encryption_spec(
+ encryption_spec_key_name=encryption_spec_key_name
+ )
+
+ gapic_tensorboard = gca_tensorboard.Tensorboard(
+ display_name=display_name,
+ description=description,
+ labels=labels,
+ encryption_spec=encryption_spec,
+ )
+
+ create_tensorboard_lro = api_client.create_tensorboard(
+ parent=parent,
+ tensorboard=gapic_tensorboard,
+ metadata=request_metadata,
+ timeout=create_request_timeout,
+ )
+
+ _LOGGER.log_create_with_lro(cls, create_tensorboard_lro)
+
+ created_tensorboard = create_tensorboard_lro.result()
+
+ _LOGGER.log_create_complete(cls, created_tensorboard, "tb")
+
+ return cls(
+ tensorboard_name=created_tensorboard.name,
+ credentials=credentials,
+ )
+
+ def update(
+ self,
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
+ encryption_spec_key_name: Optional[str] = None,
+ ) -> "Tensorboard":
+ """Updates an existing tensorboard.
+
+ Example Usage:
+
+ tb = aiplatform.Tensorboard(tensorboard_name='123456')
+ tb.update(
+ display_name='update my display name',
+ description='update my description',
+ )
+
+ Args:
+ display_name (str):
+ Optional. User-defined name of the Tensorboard.
+ The name can be up to 128 characters long and can consist
+ of any UTF-8 characters.
+ description (str):
+ Optional. Description of this Tensorboard.
+ labels (Dict[str, str]):
+ Optional. Labels with user-defined metadata to organize your Tensorboards.
+ Label keys and values can be no longer than 64 characters
+ (Unicode codepoints), can only contain lowercase letters, numeric
+ characters, underscores and dashes. International characters are allowed.
+ No more than 64 user labels can be associated with one Tensorboard
+ (System labels are excluded).
+ See https://goo.gl/xmQnxf for more information and examples of labels.
+ System reserved label keys are prefixed with "aiplatform.googleapis.com/"
+ and are immutable.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ encryption_spec_key_name (str):
+ Optional. Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the tensorboard. Has the
+ form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+ The key needs to be in the same region as where the compute
+ resource is created.
+
+ If set, this Tensorboard and all sub-resources of this Tensorboard will be secured by this key.
+
+ Overrides encryption_spec_key_name set in aiplatform.init.
+
+ Returns:
+ Tensorboard: The managed tensorboard resource.
+ """
+ update_mask = list()
+
+ if display_name:
+ utils.validate_display_name(display_name)
+ update_mask.append("display_name")
+
+ if description:
+ update_mask.append("description")
+
+ if labels:
+ utils.validate_labels(labels)
+ update_mask.append("labels")
+
+ encryption_spec = None
+ if encryption_spec_key_name:
+ encryption_spec = initializer.global_config.get_encryption_spec(
+ encryption_spec_key_name=encryption_spec_key_name,
+ )
+ update_mask.append("encryption_spec")
+
+ update_mask = field_mask_pb2.FieldMask(paths=update_mask)
+
+ gapic_tensorboard = gca_tensorboard.Tensorboard(
+ name=self.resource_name,
+ display_name=display_name,
+ description=description,
+ labels=labels,
+ encryption_spec=encryption_spec,
+ )
+
+ _LOGGER.log_action_start_against_resource(
+ "Updating",
+ "tensorboard",
+ self,
+ )
+
+ update_tensorboard_lro = self.api_client.update_tensorboard(
+ tensorboard=gapic_tensorboard,
+ update_mask=update_mask,
+ metadata=request_metadata,
+ )
+
+ _LOGGER.log_action_started_against_resource_with_lro(
+ "Update", "tensorboard", self.__class__, update_tensorboard_lro
+ )
+
+ update_tensorboard_lro.result()
+
+ _LOGGER.log_action_completed_against_resource("tensorboard", "updated", self)
+
+ return self
+
+
+class TensorboardExperiment(_TensorboardServiceResource):
+ """Managed tensorboard resource for Vertex AI."""
+
+ _resource_noun = "experiments"
+ _getter_method = "get_tensorboard_experiment"
+ _list_method = "list_tensorboard_experiments"
+ _delete_method = "delete_tensorboard_experiment"
+ _parse_resource_name_method = "parse_tensorboard_experiment_path"
+ _format_resource_name_method = "tensorboard_experiment_path"
+
+ def __init__(
+ self,
+ tensorboard_experiment_name: str,
+ tensorboard_id: Optional[str] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ):
+ """Retrieves an existing tensorboard experiment given a tensorboard experiment name or ID.
+
+ Example Usage:
+
+ tb_exp = aiplatform.TensorboardExperiment(
+ tensorboard_experiment_name= "projects/123/locations/us-central1/tensorboards/456/experiments/678"
+ )
+
+ tb_exp = aiplatform.TensorboardExperiment(
+ tensorboard_experiment_name= "678"
+ tensorboard_id = "456"
+ )
+
+ Args:
+ tensorboard_experiment_name (str):
+ Required. A fully-qualified tensorboard experiment resource name or resource ID.
+ Example: "projects/123/locations/us-central1/tensorboards/456/experiments/678" or
+ "678" when tensorboard_id is passed and project and location are initialized or passed.
+ tensorboard_id (str):
+ Optional. A tensorboard resource ID.
+ project (str):
+ Optional. Project to retrieve tensorboard from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve tensorboard from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve this Tensorboard. Overrides
+ credentials set in aiplatform.init.
+ """
+
+ super().__init__(
+ project=project,
+ location=location,
+ credentials=credentials,
+ resource_name=tensorboard_experiment_name,
+ )
+ self._gca_resource = self._get_gca_resource(
+ resource_name=tensorboard_experiment_name,
+ parent_resource_name_fields={Tensorboard._resource_noun: tensorboard_id}
+ if tensorboard_id
+ else tensorboard_id,
+ )
+
+ @classmethod
+ def create(
+ cls,
+ tensorboard_experiment_id: str,
+ tensorboard_name: str,
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ request_metadata: Sequence[Tuple[str, str]] = (),
+ create_request_timeout: Optional[float] = None,
+ ) -> "TensorboardExperiment":
+ """Creates a new TensorboardExperiment.
+
+ Example Usage:
+
+ tb_exp = aiplatform.TensorboardExperiment.create(
+ tensorboard_experiment_id='my-experiment'
+ tensorboard_id='456'
+ display_name='my display name',
+ description='my description',
+ labels={
+ 'key1': 'value1',
+ 'key2': 'value2'
+ }
+ )
+
+ Args:
+ tensorboard_experiment_id (str):
+ Required. The ID to use for the Tensorboard experiment,
+ which will become the final component of the Tensorboard
+ experiment's resource name.
+
+ This value should be 1-128 characters, and valid
+ characters are /[a-z][0-9]-/.
+
+ This corresponds to the ``tensorboard_experiment_id`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ tensorboard_name (str):
+ Required. The resource name or ID of the Tensorboard to create
+ the TensorboardExperiment in. Format of resource name:
+ ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
+ display_name (str):
+ Optional. The user-defined name of the Tensorboard Experiment.
+ The name can be up to 128 characters long and can consist
+ of any UTF-8 characters.
+ description (str):
+ Optional. Description of this Tensorboard Experiment.
+ labels (Dict[str, str]):
+ Optional. Labels with user-defined metadata to organize your Tensorboards.
+ Label keys and values can be no longer than 64 characters
+ (Unicode codepoints), can only contain lowercase letters, numeric
+ characters, underscores and dashes. International characters are allowed.
+ No more than 64 user labels can be associated with one Tensorboard
+ (System labels are excluded).
+ See https://goo.gl/xmQnxf for more information and examples of labels.
+ System reserved label keys are prefixed with "aiplatform.googleapis.com/"
+ and are immutable.
+ project (str):
+ Optional. Project to create this experiment in. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location to create this experiment in. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to create this experiment. Overrides
+ credentials set in aiplatform.init.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+ Returns:
+ TensorboardExperiment: The TensorboardExperiment resource.
+ """
+
+ if display_name:
+ utils.validate_display_name(display_name)
+
+ if labels:
+ utils.validate_labels(labels)
+
+ api_client = cls._instantiate_client(location=location, credentials=credentials)
+
+ parent = utils.full_resource_name(
+ resource_name=tensorboard_name,
+ resource_noun=Tensorboard._resource_noun,
+ parse_resource_name_method=Tensorboard._parse_resource_name,
+ format_resource_name_method=Tensorboard._format_resource_name,
+ project=project,
+ location=location,
+ )
+
+ gapic_tensorboard_experiment = gca_tensorboard_experiment.TensorboardExperiment(
+ display_name=display_name,
+ description=description,
+ labels=labels,
+ )
+
+ _LOGGER.log_create_with_lro(cls)
+
+ tensorboard_experiment = api_client.create_tensorboard_experiment(
+ parent=parent,
+ tensorboard_experiment=gapic_tensorboard_experiment,
+ tensorboard_experiment_id=tensorboard_experiment_id,
+ metadata=request_metadata,
+ timeout=create_request_timeout,
+ )
+
+ _LOGGER.log_create_complete(cls, tensorboard_experiment, "tb experiment")
+
+ return cls(
+ tensorboard_experiment_name=tensorboard_experiment.name,
+ credentials=credentials,
+ )
+
+ @classmethod
+ def list(
+ cls,
+ tensorboard_name: str,
+ filter: Optional[str] = None,
+ order_by: Optional[str] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> List["TensorboardExperiment"]:
+ """List TensorboardExperiemnts in a Tensorboard resource.
+
+ Example Usage:
+
+ aiplatform.TensorboardExperiment.list(
+ tensorboard_name='projects/my-project/locations/us-central1/tensorboards/123'
+ )
+
+ Args:
+ tensorboard_name(str):
+ Required. The resource name or resource ID of the
+ Tensorboard to list TensorboardExperiments from.
+ Format, if resource name:
+ 'projects/{project}/locations/{location}/tensorboards/{tensorboard}'
+ filter (str):
+ Optional. An expression for filtering the results of the request.
+ For field names both snake_case and camelCase are supported.
+ order_by (str):
+ Optional. A comma-separated list of fields to order by, sorted in
+ ascending order. Use "desc" after a field name for descending.
+ Supported fields: `display_name`, `create_time`, `update_time`
+ project (str):
+ Optional. Project to retrieve list from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve list from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve list. Overrides
+ credentials set in aiplatform.init.
+ Returns:
+ List[TensorboardExperiment] - A list of TensorboardExperiments
+ """
+
+ parent = utils.full_resource_name(
+ resource_name=tensorboard_name,
+ resource_noun=Tensorboard._resource_noun,
+ parse_resource_name_method=Tensorboard._parse_resource_name,
+ format_resource_name_method=Tensorboard._format_resource_name,
+ project=project,
+ location=location,
+ )
+
+ return super()._list(
+ filter=filter,
+ order_by=order_by,
+ project=project,
+ location=location,
+ credentials=credentials,
+ parent=parent,
+ )
+
+
+class TensorboardRun(_TensorboardServiceResource):
+ """Managed tensorboard resource for Vertex AI."""
+
+ _resource_noun = "runs"
+ _getter_method = "get_tensorboard_run"
+ _list_method = "list_tensorboard_runs"
+ _delete_method = "delete_tensorboard_run"
+ _parse_resource_name_method = "parse_tensorboard_run_path"
+ _format_resource_name_method = "tensorboard_run_path"
+
+ def __init__(
+ self,
+ tensorboard_run_name: str,
+ tensorboard_id: Optional[str] = None,
+ tensorboard_experiment_id: Optional[str] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ):
+ """Retrieves an existing tensorboard run given a tensorboard run name or ID.
+
+ Example Usage:
+
+ tb_run = aiplatform.TensorboardRun(
+ tensorboard_run_name= "projects/123/locations/us-central1/tensorboards/456/experiments/678/run/8910"
+ )
+
+ tb_run = aiplatform.TensorboardRun(
+ tensorboard_run_name= "8910",
+ tensorboard_id = "456",
+ tensorboard_experiment_id = "678"
+ )
+
+ Args:
+ tensorboard_run_name (str):
+ Required. A fully-qualified tensorboard run resource name or resource ID.
+ Example: "projects/123/locations/us-central1/tensorboards/456/experiments/678/runs/8910" or
+ "8910" when tensorboard_id and tensorboard_experiment_id are passed
+ and project and location are initialized or passed.
+ tensorboard_id (str):
+ Optional. A tensorboard resource ID.
+ tensorboard_experiment_id (str):
+ Optional. A tensorboard experiment resource ID.
+ project (str):
+ Optional. Project to retrieve tensorboard from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve tensorboard from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve this Tensorboard. Overrides
+ credentials set in aiplatform.init.
+ Raises:
+ ValueError: if only one of tensorboard_id or tensorboard_experiment_id is provided.
+ """
+ if bool(tensorboard_id) != bool(tensorboard_experiment_id):
+ raise ValueError(
+ "Both tensorboard_id and tensorboard_experiment_id must be provided or neither should be provided."
+ )
+
+ super().__init__(
+ project=project,
+ location=location,
+ credentials=credentials,
+ resource_name=tensorboard_run_name,
+ )
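+ # If a bare run ID was given, the parent Tensorboard and
+ # TensorboardExperiment IDs are used to assemble the full resource name.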
+ self._gca_resource = self._get_gca_resource(
+ resource_name=tensorboard_run_name,
+ parent_resource_name_fields={
+ Tensorboard._resource_noun: tensorboard_id,
+ TensorboardExperiment._resource_noun: tensorboard_experiment_id,
+ }
+ if tensorboard_id
+ else tensorboard_id,
+ )
+
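+ # Cache a mapping of time series display names to resource IDs so
+ # scalar writes can resolve IDs without extra API calls.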
+ self._time_series_display_name_to_id_mapping = (
+ self._get_time_series_display_name_to_id_mapping()
+ )
+
+ @classmethod
+ def create(
+ cls,
+ tensorboard_run_id: str,
+ tensorboard_experiment_name: str,
+ tensorboard_id: Optional[str] = None,
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ request_metadata: Sequence[Tuple[str, str]] = (),
+ create_request_timeout: Optional[float] = None,
+ ) -> "TensorboardRun":
+ """Creates a new tensorboard run.
+
+ Example Usage:
+
+ tb_run = aiplatform.TensorboardRun.create(
+ tensorboard_run_id='my-run',
+ tensorboard_experiment_name='my-experiment',
+ tensorboard_id='456',
+ display_name='my display name',
+ description='my description',
+ labels={
+ 'key1': 'value1',
+ 'key2': 'value2'
+ }
+ )
+
+ Args:
+ tensorboard_run_id (str):
+ Required. The ID to use for the Tensorboard run, which
+ will become the final component of the Tensorboard run's
+ resource name.
+
+ This value should be 1-128 characters, and valid
+ characters are /[a-z][0-9]-/.
+ tensorboard_experiment_name (str):
+ Required. The resource name or ID of the TensorboardExperiment
+ to create the TensorboardRun in. Resource name format:
+ ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}``
+
+ If resource ID is provided then tensorboard_id must be provided.
+ tensorboard_id (str):
+ Optional. The resource ID of the Tensorboard to create the TensorboardRun in.
+ display_name (str):
+ Optional. The user-defined name of the Tensorboard Run.
+ This value must be unique among all TensorboardRuns belonging to the
+ same parent TensorboardExperiment.
+
+ If not provided, tensorboard_run_id will be used.
+ description (str):
+ Optional. Description of this Tensorboard Run.
+ labels (Dict[str, str]):
+ Optional. Labels with user-defined metadata to organize your Tensorboards.
+ Label keys and values can be no longer than 64 characters
+ (Unicode codepoints), can only contain lowercase letters, numeric
+ characters, underscores and dashes. International characters are allowed.
+ No more than 64 user labels can be associated with one Tensorboard
+ (System labels are excluded).
+ See https://goo.gl/xmQnxf for more information and examples of labels.
+ System reserved label keys are prefixed with "aiplatform.googleapis.com/"
+ and are immutable.
+ project (str):
+ Optional. Project to create this resource in. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location to create this resource in. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to create this resource. Overrides
+ credentials set in aiplatform.init.
+ request_metadata (Sequence[Tuple[str, str]]):
+ Optional. Strings which should be sent along with the request as metadata.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+ Returns:
+ TensorboardRun: The TensorboardRun resource.
+ """
+ if display_name:
+ utils.validate_display_name(display_name)
+
+ if labels:
+ utils.validate_labels(labels)
+
+ display_name = display_name or tensorboard_run_id
+
+ api_client = cls._instantiate_client(location=location, credentials=credentials)
+
+ parent = utils.full_resource_name(
+ resource_name=tensorboard_experiment_name,
+ resource_noun=TensorboardExperiment._resource_noun,
+ parse_resource_name_method=TensorboardExperiment._parse_resource_name,
+ format_resource_name_method=TensorboardExperiment._format_resource_name,
+ parent_resource_name_fields={Tensorboard._resource_noun: tensorboard_id},
+ project=project,
+ location=location,
+ )
+
+ gapic_tensorboard_run = gca_tensorboard_run.TensorboardRun(
+ display_name=display_name,
+ description=description,
+ labels=labels,
+ )
+
+ _LOGGER.log_create_with_lro(cls)
+
+ tensorboard_run = api_client.create_tensorboard_run(
+ parent=parent,
+ tensorboard_run=gapic_tensorboard_run,
+ tensorboard_run_id=tensorboard_run_id,
+ metadata=request_metadata,
+ timeout=create_request_timeout,
+ )
+
+ _LOGGER.log_create_complete(cls, tensorboard_run, "tb_run")
+
+ return cls(
+ tensorboard_run_name=tensorboard_run.name,
+ credentials=credentials,
+ )
+
+ @classmethod
+ def list(
+ cls,
+ tensorboard_experiment_name: str,
+ tensorboard_id: Optional[str] = None,
+ filter: Optional[str] = None,
+ order_by: Optional[str] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> List["TensorboardRun"]:
+ """List all instances of TensorboardRun in TensorboardExperiment.
+
+ Example Usage:
+
+ aiplatform.TensorboardRun.list(
+ tensorboard_experiment_name='projects/my-project/locations/us-central1/tensorboards/123/experiments/456'
+ )
+
+ Args:
+ tensorboard_experiment_name (str):
+ Required. The resource name or resource ID of the
+ TensorboardExperiment to list
+ TensorboardRuns from. Format, if resource name:
+ 'projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}'
+
+ If resource ID is provided then tensorboard_id must be provided.
+ tensorboard_id (str):
+ Optional. The resource ID of the Tensorboard that contains the TensorboardExperiment
+ to list TensorboardRuns from.
+ filter (str):
+ Optional. An expression for filtering the results of the request.
+ For field names both snake_case and camelCase are supported.
+ order_by (str):
+ Optional. A comma-separated list of fields to order by, sorted in
+ ascending order. Use "desc" after a field name for descending.
+ Supported fields: `display_name`, `create_time`, `update_time`
+ project (str):
+ Optional. Project to retrieve list from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve list from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve list. Overrides
+ credentials set in aiplatform.init.
+ Returns:
+ List[TensorboardRun] - A list of TensorboardRuns
+ """
+
+ parent = utils.full_resource_name(
+ resource_name=tensorboard_experiment_name,
+ resource_noun=TensorboardExperiment._resource_noun,
+ parse_resource_name_method=TensorboardExperiment._parse_resource_name,
+ format_resource_name_method=TensorboardExperiment._format_resource_name,
+ parent_resource_name_fields={Tensorboard._resource_noun: tensorboard_id},
+ project=project,
+ location=location,
+ )
+
+ tensorboard_runs = super()._list(
+ filter=filter,
+ order_by=order_by,
+ project=project,
+ location=location,
+ credentials=credentials,
+ parent=parent,
+ )
+
+ for tensorboard_run in tensorboard_runs:
+ tensorboard_run._sync_time_series_display_name_to_id_mapping()
+
+ return tensorboard_runs
+
+ def write_tensorboard_scalar_data(
+ self,
+ time_series_data: Dict[str, float],
+ step: int,
+ wall_time: Optional[timestamp_pb2.Timestamp] = None,
+ ):
+ """Writes tensorboard scalar data to this run.
+
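+ Example Usage:
+
+ # Illustrative only: assumes a time series with display name 'loss'
+ # has already been created in this run.
+ tb_run.write_tensorboard_scalar_data(
+ time_series_data={'loss': 0.5},
+ step=10,
+ )
+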
+ Args:
+ time_series_data (Dict[str, float]):
+ Required. Dictionary where keys are TensorboardTimeSeries display names and values are the scalar values.
+ step (int):
+ Required. Step index of this data point within the run.
+ wall_time (timestamp_pb2.Timestamp):
+ Optional. Wall clock timestamp when this data point is
+ generated by the end user.
+
+ If not provided, this will be generated based on the value from time.time()
+ """
+
+ if not wall_time:
+ wall_time = utils.get_timestamp_proto()
+
+ ts_data = []
+
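+ # Refresh the local display-name-to-ID cache if any requested series is
+ # unknown; it may have been created outside this object.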
+ if any(
+ key not in self._time_series_display_name_to_id_mapping
+ for key in time_series_data.keys()
+ ):
+ self._sync_time_series_display_name_to_id_mapping()
+
+ for display_name, value in time_series_data.items():
+ time_series_id = self._time_series_display_name_to_id_mapping.get(
+ display_name
+ )
+
+ if not time_series_id:
+ raise RuntimeError(
+ f"TensorboardTimeSeries with display name {display_name} has not been created in TensorboardRun {self.resource_name}."
+ )
+
+ ts_data.append(
+ gca_tensorboard_data.TimeSeriesData(
+ tensorboard_time_series_id=time_series_id,
+ value_type=gca_tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR,
+ values=[
+ gca_tensorboard_data.TimeSeriesDataPoint(
+ scalar=gca_tensorboard_data.Scalar(value=value),
+ wall_time=wall_time,
+ step=step,
+ )
+ ],
+ )
+ )
+
+ self.api_client.write_tensorboard_run_data(
+ tensorboard_run=self.resource_name, time_series_data=ts_data
+ )
+
+ def _get_time_series_display_name_to_id_mapping(self) -> Dict[str, str]:
+ """Returns a mapping of the TimeSeries display names to resource IDs for this Run.
+
+ Returns:
+ Dict[str, str] - Dictionary mapping TensorboardTimeSeries display names to
+ resource IDs of TensorboardTimeSeries in this TensorboardRun."""
+ time_series = TensorboardTimeSeries.list(
+ tensorboard_run_name=self.resource_name, credentials=self.credentials
+ )
+
+ return {ts.display_name: ts.name for ts in time_series}
+
+ def _sync_time_series_display_name_to_id_mapping(self):
+ """Updates the local map of TimeSeries diplay name to resource ID."""
+ self._time_series_display_name_to_id_mapping = (
+ self._get_time_series_display_name_to_id_mapping()
+ )
+
+ def create_tensorboard_time_series(
+ self,
+ display_name: str,
+ value_type: Union[
+ gca_tensorboard_time_series.TensorboardTimeSeries.ValueType, str
+ ] = "SCALAR",
+ plugin_name: str = "scalars",
+ plugin_data: Optional[bytes] = None,
+ description: Optional[str] = None,
+ ) -> "TensorboardTimeSeries":
+ """Creates a new tensorboard time series.
+
+ Example Usage:
+
+ tb_ts = tensorboard_run.create_tensorboard_time_series(
+ display_name='my display name',
+ description='my description'
+ )
+
+ Args:
+ display_name (str):
+ Required. User provided name of this
+ TensorboardTimeSeries. This value should be
+ unique among all TensorboardTimeSeries resources
+ belonging to the same TensorboardRun resource
+ (parent resource).
+ value_type (Union[gca_tensorboard_time_series.TensorboardTimeSeries.ValueType, str]):
+ Optional. Type of TensorboardTimeSeries value. One of 'SCALAR', 'TENSOR', 'BLOB_SEQUENCE'.
+ plugin_name (str):
+ Optional. Name of the plugin this time series pertains to, such as Scalar, Tensor, or Blob.
+ plugin_data (bytes):
+ Optional. Data of the current plugin, with the size limited to 65KB.
+ description (str):
+ Optional. Description of this TensorboardTimeSeries.
+ Returns:
+ TensorboardTimeSeries: The TensorboardTimeSeries resource.
+ """
+
+ tb_time_series = TensorboardTimeSeries.create(
+ display_name=display_name,
+ tensorboard_run_name=self.resource_name,
+ value_type=value_type,
+ plugin_name=plugin_name,
+ plugin_data=plugin_data,
+ description=description,
+ credentials=self.credentials,
+ )
+
+ self._time_series_display_name_to_id_mapping[
+ tb_time_series.display_name
+ ] = tb_time_series.name
+
+ return tb_time_series
+
+ def read_time_series_data(self) -> Dict[str, gca_tensorboard_data.TimeSeriesData]:
+ """Read the time series data of this run.
+
+ ```
+ time_series_data = tensorboard_run.read_time_series_data()
+
+ print(time_series_data['loss'].values[-1].scalar.value)
+ ```
+
+ Returns:
+ Dictionary of time series metric id to TimeSeriesData.
+ """
+ self._sync_time_series_display_name_to_id_mapping()
+
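+ # Invert the display-name-to-ID cache so response entries, which are
+ # keyed by time series ID, can be mapped back to display names.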
+ resource_name_parts = self._parse_resource_name(self.resource_name)
+ inverted_mapping = {
+ resource_id: display_name
+ for display_name, resource_id in self._time_series_display_name_to_id_mapping.items()
+ }
+
+ time_series_resource_names = [
+ TensorboardTimeSeries._format_resource_name(
+ time_series=resource_id, **resource_name_parts
+ )
+ for resource_id in inverted_mapping.keys()
+ ]
+
+ resource_name_parts.pop("experiment")
+ resource_name_parts.pop("run")
+
+ tensorboard_resource_name = Tensorboard._format_resource_name(
+ **resource_name_parts
+ )
+
+ read_response = self.api_client.batch_read_tensorboard_time_series_data(
+ request=gca_tensorboard_service.BatchReadTensorboardTimeSeriesDataRequest(
+ tensorboard=tensorboard_resource_name,
+ time_series=time_series_resource_names,
+ )
+ )
+
+ return {
+ inverted_mapping[data.tensorboard_time_series_id]: data
+ for data in read_response.time_series_data
+ }
+
+
+class TensorboardTimeSeries(_TensorboardServiceResource):
+ """Managed tensorboard resource for Vertex AI."""
+
+ _resource_noun = "timeSeries"
+ _getter_method = "get_tensorboard_time_series"
+ _list_method = "list_tensorboard_time_series"
+ _delete_method = "delete_tensorboard_time_series"
+ _parse_resource_name_method = "parse_tensorboard_time_series_path"
+ _format_resource_name_method = "tensorboard_time_series_path"
+
+ def __init__(
+ self,
+ tensorboard_time_series_name: str,
+ tensorboard_id: Optional[str] = None,
+ tensorboard_experiment_id: Optional[str] = None,
+ tensorboard_run_id: Optional[str] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ):
+ """Retrieves an existing tensorboard time series given a tensorboard time series name or ID.
+
+ Example Usage:
+
+ tb_ts = aiplatform.TensorboardTimeSeries(
+ tensorboard_time_series_name="projects/123/locations/us-central1/tensorboards/456/experiments/789/run/1011/timeSeries/mse"
+ )
+
+ tb_ts = aiplatform.TensorboardTimeSeries(
+ tensorboard_time_series_name= "mse",
+ tensorboard_id = "456",
+ tensorboard_experiment_id = "789"
+ tensorboard_run_id = "1011"
+ )
+
+ Args:
+ tensorboard_time_series_name (str):
+ Required. A fully-qualified tensorboard time series resource name or resource ID.
+ Example: "projects/123/locations/us-central1/tensorboards/456/experiments/789/run/1011/timeSeries/mse" or
+ "mse" when tensorboard_id, tensorboard_experiment_id, tensorboard_run_id are passed
+ and project and location are initialized or passed.
+ tensorboard_id (str):
+ Optional. A tensorboard resource ID.
+ tensorboard_experiment_id (str):
+ Optional. A tensorboard experiment resource ID.
+ tensorboard_run_id (str):
+ Optional. A tensorboard run resource ID.
+ project (str):
+ Optional. Project to retrieve tensorboard from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve tensorboard from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve this Tensorboard. Overrides
+ credentials set in aiplatform.init.
+ Raises:
+ ValueError: if only some of tensorboard_id, tensorboard_experiment_id, or tensorboard_run_id are provided.
+ """
+ if not (
+ bool(tensorboard_id)
+ == bool(tensorboard_experiment_id)
+ == bool(tensorboard_run_id)
+ ):
+ raise ValueError(
+ "tensorboard_id, tensorboard_experiment_id, tensorboard_run_id must all be provided or none should be provided."
+ )
+
+ super().__init__(
+ project=project,
+ location=location,
+ credentials=credentials,
+ resource_name=tensorboard_time_series_name,
+ )
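+ # If a bare time series ID was given, the parent Tensorboard,
+ # TensorboardExperiment, and TensorboardRun IDs are used to assemble
+ # the full resource name.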
+ self._gca_resource = self._get_gca_resource(
+ resource_name=tensorboard_time_series_name,
+ parent_resource_name_fields={
+ Tensorboard._resource_noun: tensorboard_id,
+ TensorboardExperiment._resource_noun: tensorboard_experiment_id,
+ TensorboardRun._resource_noun: tensorboard_run_id,
+ }
+ if tensorboard_id
+ else tensorboard_id,
+ )
+
+ @classmethod
+ def create(
+ cls,
+ display_name: str,
+ tensorboard_run_name: str,
+ tensorboard_id: Optional[str] = None,
+ tensorboard_experiment_id: Optional[str] = None,
+ value_type: Union[
+ gca_tensorboard_time_series.TensorboardTimeSeries.ValueType, str
+ ] = "SCALAR",
+ plugin_name: str = "scalars",
+ plugin_data: Optional[bytes] = None,
+ description: Optional[str] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> "TensorboardTimeSeries":
+ """Creates a new tensorboard time series.
+
+ Example Usage:
+
+ tb_ts = aiplatform.TensorboardTimeSeries.create(
+ display_name='my display name',
+ tensorboard_run_name='my-run',
+ tensorboard_id='456',
+ tensorboard_experiment_id='my-experiment',
+ description='my description'
+ )
+
+ Args:
+ display_name (str):
+ Required. User provided name of this
+ TensorboardTimeSeries. This value should be
+ unique among all TensorboardTimeSeries resources
+ belonging to the same TensorboardRun resource
+ (parent resource).
+ tensorboard_run_name (str):
+ Required. The resource name or ID of the TensorboardRun
+ to create the TensorboardTimeSeries in. Resource name format:
+ ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}``
+
+ If resource ID is provided then tensorboard_id and tensorboard_experiment_id must be provided.
+ tensorboard_id (str):
+ Optional. The resource ID of the Tensorboard to create the TensorboardTimeSeries in.
+ tensorboard_experiment_id (str):
+ Optional. The ID of the TensorboardExperiment to create the TensorboardTimeSeries in.
+ value_type (Union[gca_tensorboard_time_series.TensorboardTimeSeries.ValueType, str]):
+ Optional. Type of TensorboardTimeSeries value. One of 'SCALAR', 'TENSOR', 'BLOB_SEQUENCE'.
+ plugin_name (str):
+ Optional. Name of the plugin this time series pertains to.
+ plugin_data (bytes):
+ Optional. Data of the current plugin, with the size limited to 65KB.
+ description (str):
+ Optional. Description of this TensorboardTimeSeries.
+ project (str):
+ Optional. Project to create this resource in. Overrides project set in
+ aiplatform.init.
+ location (str):
+ Optional. Location to create this resource in. Overrides location set in
+ aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to create this resource. Overrides
+ credentials set in aiplatform.init.
+ Returns:
+ TensorboardTimeSeries: The TensorboardTimeSeries resource.
+ """
+
+ if isinstance(value_type, str):
+ value_type = getattr(
+ gca_tensorboard_time_series.TensorboardTimeSeries.ValueType, value_type
+ )
+
+ api_client = cls._instantiate_client(location=location, credentials=credentials)
+
+ parent = utils.full_resource_name(
+ resource_name=tensorboard_run_name,
+ resource_noun=TensorboardRun._resource_noun,
+ parse_resource_name_method=TensorboardRun._parse_resource_name,
+ format_resource_name_method=TensorboardRun._format_resource_name,
+ parent_resource_name_fields={
+ Tensorboard._resource_noun: tensorboard_id,
+ TensorboardExperiment._resource_noun: tensorboard_experiment_id,
+ },
+ project=project,
+ location=location,
+ )
+
+ gapic_tensorboard_time_series = (
+ gca_tensorboard_time_series.TensorboardTimeSeries(
+ display_name=display_name,
+ description=description,
+ value_type=value_type,
+ plugin_name=plugin_name,
+ plugin_data=plugin_data,
+ )
+ )
+
+ _LOGGER.log_create_with_lro(cls)
+
+ tensorboard_time_series = api_client.create_tensorboard_time_series(
+ parent=parent, tensorboard_time_series=gapic_tensorboard_time_series
+ )
+
+ _LOGGER.log_create_complete(cls, tensorboard_time_series, "tb_time_series")
+
+ self = cls._empty_constructor(
+ project=project, location=location, credentials=credentials
+ )
+ self._gca_resource = tensorboard_time_series
+
+ return self
+
+ @classmethod
+ def list(
+ cls,
+ tensorboard_run_name: str,
+ tensorboard_id: Optional[str] = None,
+ tensorboard_experiment_id: Optional[str] = None,
+ filter: Optional[str] = None,
+ order_by: Optional[str] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> List["TensorboardTimeSeries"]:
+ """List all instances of TensorboardTimeSeries in TensorboardRun.
+
+ Example Usage:
+
+ aiplatform.TensorboardTimeSeries.list(
+ tensorboard_run_name='projects/my-project/locations/us-central1/tensorboards/123/experiments/my-experiment/runs/my-run'
+ )
+
+ Args:
+ tensorboard_run_name (str):
+ Required. The resource name or ID of the TensorboardRun
+ to list the TensorboardTimeSeries from. Resource name format:
+ ``projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}``
+
+ If resource ID is provided then tensorboard_id and tensorboard_experiment_id must be provided.
+ tensorboard_id (str):
+ Optional. The resource ID of the Tensorboard to list the TensorboardTimeSeries from.
+ tensorboard_experiment_id (str):
+ Optional. The ID of the TensorboardExperiment to list the TensorboardTimeSeries from.
+ filter (str):
+ Optional. An expression for filtering the results of the request.
+ For field names both snake_case and camelCase are supported.
+ order_by (str):
+ Optional. A comma-separated list of fields to order by, sorted in
+ ascending order. Use "desc" after a field name for descending.
+ Supported fields: `display_name`, `create_time`, `update_time`
+ project (str):
+ Optional. Project to retrieve list from. If not set, project
+ set in aiplatform.init will be used.
+ location (str):
+ Optional. Location to retrieve list from. If not set, location
+ set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to retrieve list. Overrides
+ credentials set in aiplatform.init.
+ Returns:
+ List[TensorboardTimeSeries] - A list of TensorboardTimeSeries
+ """
+
+ parent = utils.full_resource_name(
+ resource_name=tensorboard_run_name,
+ resource_noun=TensorboardRun._resource_noun,
+ parse_resource_name_method=TensorboardRun._parse_resource_name,
+ format_resource_name_method=TensorboardRun._format_resource_name,
+ parent_resource_name_fields={
+ Tensorboard._resource_noun: tensorboard_id,
+ TensorboardExperiment._resource_noun: tensorboard_experiment_id,
+ },
+ project=project,
+ location=location,
+ )
+
+ return super()._list(
+ filter=filter,
+ order_by=order_by,
+ project=project,
+ location=location,
+ credentials=credentials,
+ parent=parent,
+ )
diff --git a/google/cloud/aiplatform/tensorboard/uploader.py b/google/cloud/aiplatform/tensorboard/uploader.py
index 57dcbedf60..c9926eb18a 100644
--- a/google/cloud/aiplatform/tensorboard/uploader.py
+++ b/google/cloud/aiplatform/tensorboard/uploader.py
@@ -15,13 +15,22 @@
# limitations under the License.
#
"""Uploads a TensorBoard logdir to TensorBoard.gcp."""
-import contextlib
+import abc
+from collections import defaultdict
import functools
-import json
+import logging
import os
import time
import re
-from typing import Callable, Dict, FrozenSet, Generator, Iterable, Optional, Tuple
+from typing import (
+ Dict,
+ FrozenSet,
+ Generator,
+ Iterable,
+ Optional,
+ ContextManager,
+ Tuple,
+)
import uuid
import grpc
@@ -47,26 +56,17 @@
from google.api_core import exceptions
from google.cloud import storage
-from google.cloud.aiplatform.compat.services import tensorboard_service_client_v1beta1
-from google.cloud.aiplatform.compat.types import (
- tensorboard_data_v1beta1 as tensorboard_data,
-)
-from google.cloud.aiplatform.compat.types import (
- tensorboard_experiment_v1beta1 as tensorboard_experiment,
-)
-from google.cloud.aiplatform.compat.types import (
- tensorboard_run_v1beta1 as tensorboard_run,
-)
-from google.cloud.aiplatform.compat.types import (
- tensorboard_service_v1beta1 as tensorboard_service,
-)
-from google.cloud.aiplatform.compat.types import (
- tensorboard_time_series_v1beta1 as tensorboard_time_series,
-)
+from google.cloud.aiplatform.compat.services import tensorboard_service_client
+from google.cloud.aiplatform.compat.types import tensorboard_data
+from google.cloud.aiplatform.compat.types import tensorboard_experiment
+from google.cloud.aiplatform.compat.types import tensorboard_service
+from google.cloud.aiplatform.compat.types import tensorboard_time_series
+from google.cloud.aiplatform.tensorboard import uploader_utils
+from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import profile_uploader
from google.protobuf import message
from google.protobuf import timestamp_pb2 as timestamp
-TensorboardServiceClient = tensorboard_service_client_v1beta1.TensorboardServiceClient
+TensorboardServiceClient = tensorboard_service_client.TensorboardServiceClient
# Minimum length of a logdir polling cycle in seconds. Shorter cycles will
# sleep to avoid spinning over the logdir, which isn't great for disks and can
@@ -83,7 +83,7 @@
_DEFAULT_MIN_SCALAR_REQUEST_INTERVAL = 10
# Default maximum WriteTensorbordRunData request size in bytes.
-_DEFAULT_MAX_SCALAR_REQUEST_SIZE = 24 * (2 ** 10) # 24KiB
+_DEFAULT_MAX_SCALAR_REQUEST_SIZE = 128 * (2**10) # 128KiB
# Default minimum interval between initiating WriteTensorbordRunData RPCs in
# milliseconds.
@@ -94,16 +94,26 @@
_DEFAULT_MIN_BLOB_REQUEST_INTERVAL = 10
# Default maximum WriteTensorbordRunData request size in bytes.
-_DEFAULT_MAX_TENSOR_REQUEST_SIZE = 512 * (2 ** 10) # 512KiB
+_DEFAULT_MAX_TENSOR_REQUEST_SIZE = 512 * (2**10) # 512KiB
-_DEFAULT_MAX_BLOB_REQUEST_SIZE = 4 * (2 ** 20) - 256 * (2 ** 10) # 4MiB-256KiB
+_DEFAULT_MAX_BLOB_REQUEST_SIZE = 128 * (2**10) # 128KiB
# Default maximum tensor point size in bytes.
-_DEFAULT_MAX_TENSOR_POINT_SIZE = 16 * (2 ** 10) # 16KiB
+_DEFAULT_MAX_TENSOR_POINT_SIZE = 16 * (2**10) # 16KiB
-_DEFAULT_MAX_BLOB_SIZE = 10 * (2 ** 30) # 10GiB
+_DEFAULT_MAX_BLOB_SIZE = 10 * (2**30) # 10GiB
logger = tb_logging.get_logger()
+logger.setLevel(logging.WARNING)
+
+
+class RequestSender(object):
+ """A base class for additional request sender objects.
+
+ Currently just used for typing.
+ """
+
+ pass
class TensorBoardUploader(object):
@@ -178,6 +188,7 @@ def __init__(
self._logdir = logdir
self._allowed_plugins = frozenset(allowed_plugins)
self._run_name_prefix = run_name_prefix
+ self._is_brand_new_experiment = False
self._upload_limits = upload_limits
if not self._upload_limits:
@@ -200,11 +211,11 @@ def __init__(
)
self._upload_limits.max_blob_request_size = _DEFAULT_MAX_BLOB_REQUEST_SIZE
self._upload_limits.max_blob_size = _DEFAULT_MAX_BLOB_SIZE
-
self._description = description
self._verbosity = verbosity
self._one_shot = one_shot
- self._request_sender = None
+ self._dispatcher = None
+ self._additional_senders: Dict[str, uploader_utils.RequestSender] = {}
if logdir_poll_rate_limiter is None:
self._logdir_poll_rate_limiter = util.RateLimiter(
_MIN_LOGDIR_POLL_INTERVAL_SECS
@@ -248,8 +259,13 @@ def active_filter(secs):
self._logdir_loader = logdir_loader.LogdirLoader(
self._logdir, directory_loader_factory
)
+ self._logdir_loader_pre_create = logdir_loader.LogdirLoader(
+ self._logdir, directory_loader_factory
+ )
self._tracker = upload_tracker.UploadTracker(verbosity=self._verbosity)
+ self._create_additional_senders()
+
def _create_or_get_experiment(self) -> tensorboard_experiment.TensorboardExperiment:
"""Create an experiment or get an experiment.
@@ -271,6 +287,7 @@ def _create_or_get_experiment(self) -> tensorboard_experiment.TensorboardExperim
tensorboard_experiment=tb_experiment,
tensorboard_experiment_id=self._experiment_name,
)
+ self._is_brand_new_experiment = True
except exceptions.AlreadyExists:
logger.info("Creating experiment failed. Retrieving experiment.")
experiment_name = os.path.join(
@@ -284,6 +301,10 @@ def create_experiment(self):
experiment = self._create_or_get_experiment()
self._experiment = experiment
+ self._one_platform_resource_manager = uploader_utils.OnePlatformResourceManager(
+ self._experiment.name, self._api
+ )
+
self._request_sender = _BatchedRequestSender(
self._experiment.name,
self._api,
@@ -294,9 +315,48 @@ def create_experiment(self):
blob_rpc_rate_limiter=self._blob_rpc_rate_limiter,
blob_storage_bucket=self._blob_storage_bucket,
blob_storage_folder=self._blob_storage_folder,
+ one_platform_resource_manager=self._one_platform_resource_manager,
tracker=self._tracker,
)
+ # Update partials with experiment name
+ for sender in self._additional_senders.keys():
+ self._additional_senders[sender] = self._additional_senders[sender](
+ experiment_resource_name=self._experiment.name,
+ )
+
+ self._dispatcher = _Dispatcher(
+ request_sender=self._request_sender,
+ additional_senders=self._additional_senders,
+ )
+
+ def _create_additional_senders(self) -> Dict[str, uploader_utils.RequestSender]:
+ """Create any additional senders for non traditional event files.
+
+ Some items that are used for plugins do not process typical event files,
+ but need to be searched for and stored so that they can be used by the
+ plugin. If there are any items that cannot be searched for via the
+ `_BatchedRequestSender`, add them here.
+ """
+ if "profile" in self._allowed_plugins:
+ if not self._one_shot:
+ raise ValueError(
+ "Profile plugin currently only supported for one shot."
+ )
+ source_bucket = uploader_utils.get_source_bucket(self._logdir)
+
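+ # Construct the sender lazily via functools.partial; create_experiment()
+ # later completes the partial with the experiment resource name.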
+ self._additional_senders["profile"] = functools.partial(
+ profile_uploader.ProfileRequestSender,
+ api=self._api,
+ upload_limits=self._upload_limits,
+ blob_rpc_rate_limiter=self._blob_rpc_rate_limiter,
+ blob_storage_bucket=self._blob_storage_bucket,
+ blob_storage_folder=self._blob_storage_folder,
+ source_bucket=source_bucket,
+ tracker=self._tracker,
+ logdir=self._logdir,
+ )
+
def get_experiment_resource_name(self):
return self._experiment.name
@@ -308,8 +368,19 @@ def start_uploading(self):
ExperimentNotFoundError: If the experiment is deleted during the
course of the upload.
"""
- if self._request_sender is None:
+ if self._dispatcher is None:
raise RuntimeError("Must call create_experiment() before start_uploading()")
+
+ if self._one_shot:
+ if self._is_brand_new_experiment:
+ self._pre_create_runs_and_time_series()
+ else:
+ logger.warning(
+ "Please consider uploading to a new experiment instead of "
+ "an existing one, as the former allows for better upload "
+ "performance."
+ )
+
while True:
self._logdir_poll_rate_limiter.tick()
self._upload_once()
@@ -321,6 +392,58 @@ def start_uploading(self):
"without any uploadable data" % self._logdir
)
+ def _pre_create_runs_and_time_series(self):
+ """
+ Iterates though the log dir to collect TensorboardRuns and
+ TensorboardTimeSeries that need to be created, and creates them in batch
+ to speed up uploading later on.
+ """
+ self._logdir_loader_pre_create.synchronize_runs()
+ run_to_events = self._logdir_loader_pre_create.get_run_events()
+ if self._run_name_prefix:
+ run_to_events = {
+ self._run_name_prefix + k: v for k, v in run_to_events.items()
+ }
+
+ run_names = []
+ run_tag_name_to_time_series_proto = {}
+ for (run_name, events) in run_to_events.items():
+ run_names.append(run_name)
+ for event in events:
+ _filter_graph_defs(event)
+ for value in event.summary.value:
+ metadata, is_valid = self._request_sender.get_metadata_and_validate(
+ run_name, value
+ )
+ if not is_valid:
+ continue
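+ # Map the summary's data class to the matching
+ # TensorboardTimeSeries value type.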
+ if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
+ value_type = (
+ tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR
+ )
+ elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
+ value_type = (
+ tensorboard_time_series.TensorboardTimeSeries.ValueType.TENSOR
+ )
+ elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
+ value_type = (
+ tensorboard_time_series.TensorboardTimeSeries.ValueType.BLOB_SEQUENCE
+ )
+
+ run_tag_name_to_time_series_proto[
+ (run_name, value.tag)
+ ] = tensorboard_time_series.TensorboardTimeSeries(
+ display_name=value.tag,
+ value_type=value_type,
+ plugin_name=metadata.plugin_data.plugin_name,
+ plugin_data=metadata.plugin_data.content,
+ )
+
+ self._one_platform_resource_manager.batch_create_runs(run_names)
+ self._one_platform_resource_manager.batch_create_time_series(
+ run_tag_name_to_time_series_proto
+ )
+
def _upload_once(self):
"""Runs one upload cycle, sending zero or more RPCs."""
logger.info("Starting an upload cycle")
@@ -336,19 +459,15 @@ def _upload_once(self):
self._run_name_prefix + k: v for k, v in run_to_events.items()
}
with self._tracker.send_tracker():
- self._request_sender.send_requests(run_to_events)
-
-
-class ExperimentNotFoundError(RuntimeError):
- pass
+ self._dispatcher.dispatch_requests(run_to_events)
class PermissionDeniedError(RuntimeError):
pass
-class ExistingResourceNotFoundError(RuntimeError):
- """Resource could not be created or retrieved."""
+class ExperimentNotFoundError(RuntimeError):
+ pass
class _OutOfSpaceError(Exception):
@@ -387,6 +506,7 @@ def __init__(
blob_rpc_rate_limiter: util.RateLimiter,
blob_storage_bucket: storage.Bucket,
blob_storage_folder: str,
+ one_platform_resource_manager: uploader_utils.OnePlatformResourceManager,
tracker: upload_tracker.UploadTracker,
):
"""Constructs _BatchedRequestSender for the given experiment resource.
@@ -404,6 +524,8 @@ def __init__(
Note the chunk stream is internally rate-limited by backpressure from
the server, so it is not a concern that we do not explicitly rate-limit
within the stream here.
+ one_platform_resource_manager: An instance of the One Platform
+ resource management class.
tracker: Upload tracker to track information about uploads.
"""
self._experiment_resource_name = experiment_resource_name
@@ -411,27 +533,26 @@ def __init__(
self._tag_metadata = {}
self._allowed_plugins = frozenset(allowed_plugins)
self._tracker = tracker
- self._run_to_request_sender: Dict[str, _ScalarBatchedRequestSender] = {}
- self._run_to_tensor_request_sender: Dict[str, _TensorBatchedRequestSender] = {}
- self._run_to_blob_request_sender: Dict[str, _BlobRequestSender] = {}
- self._run_to_run_resource: Dict[str, tensorboard_run.TensorboardRun] = {}
- self._scalar_request_sender_factory = functools.partial(
- _ScalarBatchedRequestSender,
+ self._one_platform_resource_manager = one_platform_resource_manager
+ self._scalar_request_sender = _ScalarBatchedRequestSender(
+ experiment_resource_id=experiment_resource_name,
api=api,
rpc_rate_limiter=rpc_rate_limiter,
max_request_size=upload_limits.max_scalar_request_size,
tracker=self._tracker,
+ one_platform_resource_manager=self._one_platform_resource_manager,
)
- self._tensor_request_sender_factory = functools.partial(
- _TensorBatchedRequestSender,
+ self._tensor_request_sender = _TensorBatchedRequestSender(
+ experiment_resource_id=experiment_resource_name,
api=api,
rpc_rate_limiter=tensor_rpc_rate_limiter,
max_request_size=upload_limits.max_tensor_request_size,
max_tensor_point_size=upload_limits.max_tensor_point_size,
tracker=self._tracker,
+ one_platform_resource_manager=self._one_platform_resource_manager,
)
- self._blob_request_sender_factory = functools.partial(
- _BlobRequestSender,
+ self._blob_request_sender = _BlobRequestSender(
+ experiment_resource_id=experiment_resource_name,
api=api,
rpc_rate_limiter=blob_rpc_rate_limiter,
max_blob_request_size=upload_limits.max_blob_request_size,
@@ -439,10 +560,14 @@ def __init__(
blob_storage_bucket=blob_storage_bucket,
blob_storage_folder=blob_storage_folder,
tracker=self._tracker,
+ one_platform_resource_manager=self._one_platform_resource_manager,
)
- def send_requests(
- self, run_to_events: Dict[str, Generator[tf.compat.v1.Event, None, None]]
+ def send_request(
+ self,
+ run_name: str,
+ event: tf.compat.v1.Event,
+ value: tf.compat.v1.Summary.Value,
):
"""Accepts a stream of TF events and sends batched write RPCs.
@@ -450,127 +575,127 @@ def send_requests(
the type of data (Scalar vs Tensor vs Blob) being sent.
Args:
- run_to_events: Mapping from run name to generator of `tf.compat.v1.Event`
- values, as returned by `LogdirLoader.get_run_events`.
+ run_name: Name of the run retrieved by `LogdirLoader.get_run_events`
+ event: The `tf.compat.v1.Event` for the run
+ value: A single `tf.compat.v1.Summary.Value` from the event, where
+ there can be multiple values per event.
Raises:
RuntimeError: If no progress can be made because even a single
point is too large (say, due to a gigabyte-long tag name).
"""
+ metadata, is_valid = self.get_metadata_and_validate(run_name, value)
+ if not is_valid:
+ return
+ plugin_name = metadata.plugin_data.plugin_name
+ self._tracker.add_plugin_name(plugin_name)
- for (run_name, event, value) in self._run_values(run_to_events):
- time_series_key = (run_name, value.tag)
-
- # The metadata for a time series is memorized on the first event.
- # If later events arrive with a mismatching plugin_name, they are
- # ignored with a warning.
- metadata = self._tag_metadata.get(time_series_key)
- first_in_time_series = False
- if metadata is None:
- first_in_time_series = True
- metadata = value.metadata
- self._tag_metadata[time_series_key] = metadata
-
- plugin_name = metadata.plugin_data.plugin_name
- if value.HasField("metadata") and (
- plugin_name != value.metadata.plugin_data.plugin_name
- ):
- logger.warning(
- "Mismatching plugin names for %s. Expected %s, found %s.",
+ if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
+ self._scalar_request_sender.add_event(run_name, event, value, metadata)
+ elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
+ self._tensor_request_sender.add_event(run_name, event, value, metadata)
+ elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
+ self._blob_request_sender.add_event(run_name, event, value, metadata)
+
+ def flush(self):
+ """Flushes any events that have been stored."""
+ self._scalar_request_sender.flush()
+ self._tensor_request_sender.flush()
+ self._blob_request_sender.flush()
+
+ def get_metadata_and_validate(
+ self, run_name: str, value: tf.compat.v1.Summary.Value
+ ) -> Tuple[tf.compat.v1.SummaryMetadata, bool]:
+ """
+
+ :param run_name: Name of the run retrieved by
+ `LogdirLoader.get_run_events`
+ :param value: A single `tf.compat.v1.Summary.Value` from the event,
+ where there can be multiple values per event.
+ :return: (metadata, is_valid): the metadata derived from the value, and
+ whether the value itself is valid.
+ """
+
+ time_series_key = (run_name, value.tag)
+
+ # The metadata for a time series is memorized on the first event.
+ # If later events arrive with a mismatching plugin_name, they are
+ # ignored with a warning.
+ metadata = self._tag_metadata.get(time_series_key)
+ first_in_time_series = False
+ if metadata is None:
+ first_in_time_series = True
+ metadata = value.metadata
+ self._tag_metadata[time_series_key] = metadata
+
+ plugin_name = metadata.plugin_data.plugin_name
+ if value.HasField("metadata") and (
+ plugin_name != value.metadata.plugin_data.plugin_name
+ ):
+ logger.warning(
+ "Mismatching plugin names for %s. Expected %s, found %s.",
+ time_series_key,
+ metadata.plugin_data.plugin_name,
+ value.metadata.plugin_data.plugin_name,
+ )
+ return metadata, False
+ if plugin_name not in self._allowed_plugins:
+ if first_in_time_series:
+ logger.info(
+ "Skipping time series %r with unsupported plugin name %r",
time_series_key,
- metadata.plugin_data.plugin_name,
- value.metadata.plugin_data.plugin_name,
- )
- continue
- if plugin_name not in self._allowed_plugins:
- if first_in_time_series:
- logger.info(
- "Skipping time series %r with unsupported plugin name %r",
- time_series_key,
- plugin_name,
- )
- continue
- self._tracker.add_plugin_name(plugin_name)
- # If this is the first time we've seen this run create a new run resource
- # and an associated request sender.
- if run_name not in self._run_to_run_resource:
- self._create_or_get_run_resource(run_name)
- self._run_to_request_sender[
- run_name
- ] = self._scalar_request_sender_factory(
- self._run_to_run_resource[run_name].name
- )
- self._run_to_tensor_request_sender[
- run_name
- ] = self._tensor_request_sender_factory(
- self._run_to_run_resource[run_name].name
- )
- self._run_to_blob_request_sender[
- run_name
- ] = self._blob_request_sender_factory(
- self._run_to_run_resource[run_name].name
+ plugin_name,
)
+ return metadata, False
+ return metadata, True
- if metadata.data_class == summary_pb2.DATA_CLASS_SCALAR:
- self._run_to_request_sender[run_name].add_event(event, value, metadata)
- elif metadata.data_class == summary_pb2.DATA_CLASS_TENSOR:
- self._run_to_tensor_request_sender[run_name].add_event(
- event, value, metadata
- )
- elif metadata.data_class == summary_pb2.DATA_CLASS_BLOB_SEQUENCE:
- self._run_to_blob_request_sender[run_name].add_event(
- event, value, metadata
- )
- for scalar_request_sender in self._run_to_request_sender.values():
- scalar_request_sender.flush()
+class _Dispatcher(object):
+ """Dispatch the requests to the correct request senders."""
- for tensor_request_sender in self._run_to_tensor_request_sender.values():
- tensor_request_sender.flush()
+ def __init__(
+ self,
+ request_sender: _BatchedRequestSender,
+ additional_senders: Optional[Dict[str, uploader_utils.RequestSender]] = None,
+ ):
+ """Construct a _Dispatcher object for the TensorboardUploader.
- for blob_request_sender in self._run_to_blob_request_sender.values():
- blob_request_sender.flush()
+ Args:
+ request_sender: A `_BatchedRequestSender` for handling events.
+ additional_senders: A dictionary mapping a plugin name to additional
+ Senders.
+ """
+ self._request_sender = request_sender
- def _create_or_get_run_resource(self, run_name: str):
- """Creates a new Run Resource in current Tensorboard Experiment resource.
+ if not additional_senders:
+ additional_senders = {}
+ self._additional_senders = additional_senders
+
+ def _dispatch_additional_senders(
+ self,
+ run_name: str,
+ ):
+ """Dispatch events to any additional senders.
+
+ These senders process non-traditional event files for a specific plugin
+ and use a send_request function to process events.
Args:
- run_name: The display name of this run.
+ run_name: Name of the current training run.
"""
- tb_run = tensorboard_run.TensorboardRun()
- tb_run.display_name = run_name
- try:
- tb_run = self._api.create_tensorboard_run(
- parent=self._experiment_resource_name,
- tensorboard_run=tb_run,
- tensorboard_run_id=str(uuid.uuid4()),
- )
- except exceptions.InvalidArgument as e:
- # If the run name already exists then retrieve it
- if "already exist" in e.message:
- runs_pages = self._api.list_tensorboard_runs(
- parent=self._experiment_resource_name
- )
- for tb_run in runs_pages:
- if tb_run.display_name == run_name:
- break
-
- if tb_run.display_name != run_name:
- raise ExistingResourceNotFoundError(
- "Run with name %s already exists but is not resource list."
- % run_name
- )
- else:
- raise
-
- self._run_to_run_resource[run_name] = tb_run
+ for sender in self._additional_senders.values():
+ sender.send_request(run_name)
- def _run_values(
+ def dispatch_requests(
self, run_to_events: Dict[str, Generator[tf.compat.v1.Event, None, None]]
- ) -> Generator[
- Tuple[str, tf.compat.v1.Event, tf.compat.v1.Summary.Value], None, None
- ]:
- """Helper generator to create a single stream of work items.
+ ):
+ """Routes events to the appropriate sender.
+
+ Takes a mapping from strings to an event generator. The function routes
+ any events that should be handled by the `_BatchedRequestSender` and
+ non-traditional events that need to be handled differently, which are
+ stored as "_additional_senders". The `_request_sender` is then flushed
+ after all events are added.
Note that `dataclass_compat` may emit multiple variants of
the same event, for backwards compatibility. Thus this stream should
@@ -586,92 +711,17 @@ def _run_values(
Args:
run_to_events: Mapping from run name to generator of `tf.compat.v1.Event`
values, as returned by `LogdirLoader.get_run_events`.
-
- Yields:
- Tuple of run name, tf.compat.v1.Event, tf.compat.v1.Summary.Value per
- value.
"""
- # Note that this join in principle has deletion anomalies: if the input
- # stream contains runs with no events, or events with no values, we'll
- # lose that information. This is not a problem: we would need to prune
- # such data from the request anyway.
for (run_name, events) in run_to_events.items():
+ self._dispatch_additional_senders(run_name)
for event in events:
_filter_graph_defs(event)
for value in event.summary.value:
- yield (run_name, event, value)
-
-
-class _TimeSeriesResourceManager(object):
- """Helper class managing Time Series resources."""
-
- def __init__(self, run_resource_id: str, api: TensorboardServiceClient):
- """Constructor for _TimeSeriesResourceManager.
-
- Args:
- run_resource_id: The resource id for the run with the following format
- projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}
- api: TensorboardServiceStub
- """
- self._run_resource_id = run_resource_id
- self._api = api
- self._tag_to_time_series_proto: Dict[
- str, tensorboard_time_series.TensorboardTimeSeries
- ] = {}
-
- def get_or_create(
- self,
- tag_name: str,
- time_series_resource_creator: Callable[
- [], tensorboard_time_series.TensorboardTimeSeries
- ],
- ) -> tensorboard_time_series.TensorboardTimeSeries:
- """get a time series resource with given tag_name, and create a new one on
-
- OnePlatform if not present.
-
- Args:
- tag_name: The tag name of the time series in the Tensorboard log dir.
- time_series_resource_creator: A callable that produces a TimeSeries for
- creation.
- """
- if tag_name in self._tag_to_time_series_proto:
- return self._tag_to_time_series_proto[tag_name]
-
- time_series = time_series_resource_creator()
- time_series.display_name = tag_name
- try:
- time_series = self._api.create_tensorboard_time_series(
- parent=self._run_resource_id, tensorboard_time_series=time_series
- )
- except exceptions.InvalidArgument as e:
- # If the time series display name already exists then retrieve it
- if "already exist" in e.message:
- list_of_time_series = self._api.list_tensorboard_time_series(
- request=tensorboard_service.ListTensorboardTimeSeriesRequest(
- parent=self._run_resource_id,
- filter="display_name = {}".format(json.dumps(str(tag_name))),
- )
- )
- num = 0
- for ts in list_of_time_series:
- time_series = ts
- num += 1
- break
- if num != 1:
- raise ValueError(
- "More than one time series resource found with display_name: {}".format(
- tag_name
- )
- )
- else:
- raise
-
- self._tag_to_time_series_proto[tag_name] = time_series
- return time_series
+ self._request_sender.send_request(run_name, event, value)
+ self._request_sender.flush()
-class _ScalarBatchedRequestSender(object):
+class _BaseBatchedRequestSender(object):
"""Helper class for building requests that fit under a size limit.
This class accumulates a current request. `add_event(...)` may or may not
@@ -684,47 +734,49 @@ class _ScalarBatchedRequestSender(object):
def __init__(
self,
- run_resource_id: str,
+ experiment_resource_id: str,
api: TensorboardServiceClient,
rpc_rate_limiter: util.RateLimiter,
max_request_size: int,
tracker: upload_tracker.UploadTracker,
+ one_platform_resource_manager: uploader_utils.OnePlatformResourceManager,
):
- """Constructer for _ScalarBatchedRequestSender.
+ """Constructor for _BaseBatchedRequestSender.
Args:
- run_resource_id: The resource id for the run with the following format
- projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}
+ experiment_resource_id: The resource id for the experiment with the following format
+ projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}
api: TensorboardServiceStub
rpc_rate_limiter: until.RateLimiter to limit rate of this request sender
max_request_size: max number of bytes to send
tracker:
"""
- self._run_resource_id = run_resource_id
+ self._experiment_resource_id = experiment_resource_id
self._api = api
self._rpc_rate_limiter = rpc_rate_limiter
self._byte_budget_manager = _ByteBudgetManager(max_request_size)
self._tracker = tracker
+ self._one_platform_resource_manager = one_platform_resource_manager
# cache: map from Tensorboard tag to TimeSeriesData
# cleared whenever a new request is created
- self._tag_to_time_series_data: Dict[str, tensorboard_data.TimeSeriesData] = {}
-
- self._time_series_resource_manager = _TimeSeriesResourceManager(
- self._run_resource_id, self._api
- )
+ self._run_to_tag_to_time_series_data: Dict[
+ str, Dict[str, tensorboard_data.TimeSeriesData]
+ ] = defaultdict(defaultdict)
self._new_request()
def _new_request(self):
"""Allocates a new request and refreshes the budget."""
- self._request = tensorboard_service.WriteTensorboardRunDataRequest()
- self._tag_to_time_series_data.clear()
+ self._request = tensorboard_service.WriteTensorboardExperimentDataRequest(
+ tensorboard_experiment=self._experiment_resource_id
+ )
+ self._run_to_tag_to_time_series_data.clear()
self._num_values = 0
- self._request.tensorboard_run = self._run_resource_id
self._byte_budget_manager.reset(self._request)
def add_event(
self,
+ run_name: str,
event: tf.compat.v1.Event,
value: tf.compat.v1.Summary.Value,
metadata: tf.compat.v1.SummaryMetadata,
@@ -741,27 +793,32 @@ def add_event(
metadata: SummaryMetadata of the event.
"""
try:
- self._add_event_internal(event, value, metadata)
+ self._add_event_internal(run_name, event, value, metadata)
except _OutOfSpaceError:
self.flush()
# Try again. This attempt should never produce OutOfSpaceError
# because we just flushed.
try:
- self._add_event_internal(event, value, metadata)
+ self._add_event_internal(run_name, event, value, metadata)
except _OutOfSpaceError:
raise RuntimeError("add_event failed despite flush")
def _add_event_internal(
self,
+ run_name: str,
event: tf.compat.v1.Event,
value: tf.compat.v1.Summary.Value,
metadata: tf.compat.v1.SummaryMetadata,
):
self._num_values += 1
- time_series_data_proto = self._tag_to_time_series_data.get(value.tag)
+ time_series_data_proto = self._run_to_tag_to_time_series_data[run_name].get(
+ value.tag
+ )
if time_series_data_proto is None:
- time_series_data_proto = self._create_time_series_data(value.tag, metadata)
- self._create_point(time_series_data_proto, event, value)
+ time_series_data_proto = self._create_time_series_data(
+ run_name, value.tag, metadata
+ )
+ self._create_point(run_name, time_series_data_proto, event, value, metadata)
def flush(self):
"""Sends the active request after removing empty runs and tags.
@@ -769,19 +826,34 @@ def flush(self):
Starts a new, empty active request.
"""
request = self._request
- request.time_series_data = list(self._tag_to_time_series_data.values())
- _prune_empty_time_series(request)
- if not request.time_series_data:
+ has_data = False
+ for (
+ run_name,
+ tag_to_time_series_data,
+ ) in self._run_to_tag_to_time_series_data.items():
+ r = tensorboard_service.WriteTensorboardRunDataRequest(
+ tensorboard_run=self._one_platform_resource_manager.get_run_resource_name(
+ run_name
+ )
+ )
+ r.time_series_data = list(tag_to_time_series_data.values())
+ _prune_empty_time_series(r)
+ if not r.time_series_data:
+ continue
+ request.write_run_data_requests.extend([r])
+ has_data = True
+
+ if not has_data:
return
self._rpc_rate_limiter.tick()
- with _request_logger(request):
- with self._tracker.scalars_tracker(self._num_values):
+ with uploader_utils.request_logger(request):
+ with self._get_tracker():
try:
- self._api.write_tensorboard_run_data(
- tensorboard_run=self._run_resource_id,
- time_series_data=request.time_series_data,
+ self._api.write_tensorboard_experiment_data(
+ tensorboard_experiment=request.tensorboard_experiment,
+ write_run_data_requests=request.write_run_data_requests,
)
except grpc.RpcError as e:
if (
@@ -794,7 +866,7 @@ def flush(self):
self._new_request()
def _create_time_series_data(
- self, tag_name: str, metadata: tf.compat.v1.SummaryMetadata
+ self, run_name: str, tag_name: str, metadata: tf.compat.v1.SummaryMetadata
) -> tensorboard_data.TimeSeriesData:
"""Adds a time_series for the tag_name, if there's space.
@@ -808,52 +880,55 @@ def _create_time_series_data(
_OutOfSpaceError: If adding the tag would exceed the remaining
request budget.
"""
- time_series_data_proto = tensorboard_data.TimeSeriesData(
- tensorboard_time_series_id=self._time_series_resource_manager.get_or_create(
+ time_series_resource_name = (
+ self._one_platform_resource_manager.get_time_series_resource_name(
+ run_name,
tag_name,
lambda: tensorboard_time_series.TensorboardTimeSeries(
display_name=tag_name,
- value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR,
+ value_type=self._value_type,
plugin_name=metadata.plugin_data.plugin_name,
plugin_data=metadata.plugin_data.content,
),
- ).name.split("/")[-1],
- value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR,
+ )
+ )
+
+ time_series_data_proto = tensorboard_data.TimeSeriesData(
+ tensorboard_time_series_id=time_series_resource_name.split("/")[-1],
+ value_type=self._value_type,
)
- self._request.time_series_data.extend([time_series_data_proto])
self._byte_budget_manager.add_time_series(time_series_data_proto)
- self._tag_to_time_series_data[tag_name] = time_series_data_proto
+ self._run_to_tag_to_time_series_data[run_name][
+ tag_name
+ ] = time_series_data_proto
return time_series_data_proto
def _create_point(
self,
+ run_name: str,
time_series_proto: tensorboard_data.TimeSeriesData,
event: tf.compat.v1.Event,
value: tf.compat.v1.Summary.Value,
+ metadata: tf.compat.v1.SummaryMetadata,
):
"""Adds a scalar point to the given tag, if there's space.
Args:
time_series_proto: TimeSeriesData proto to which to add a point.
event: Enclosing `Event` proto with the step and wall time data.
- value: Scalar `Summary.Value` proto with the actual scalar data.
+ value: `Summary.Value` proto.
+ metadata: SummaryMetadata of the event.
Raises:
_OutOfSpaceError: If adding the point would exceed the remaining
request budget.
"""
- scalar_proto = tensorboard_data.Scalar(
- value=tensor_util.make_ndarray(value.tensor).item()
- )
- point = tensorboard_data.TimeSeriesDataPoint(
- step=event.step,
- scalar=scalar_proto,
- wall_time=timestamp.Timestamp(
- seconds=int(event.wall_time),
- nanos=int(round((event.wall_time % 1) * 10 ** 9)),
- ),
- )
+ point = self._create_data_point(run_name, event, value, metadata)
+
+        if point is None or not self._validate(point, event, value):
+ return
+
time_series_proto.values.extend([point])
try:
self._byte_budget_manager.add_point(point)
@@ -861,193 +936,196 @@ def _create_point(
time_series_proto.values.pop()
raise
+ @abc.abstractmethod
+ def _get_tracker(self) -> ContextManager:
+ """
+ :return: tracker function from upload_tracker.UploadTracker
+ """
+ pass
+
+ @property
+ @classmethod
+ @abc.abstractmethod
+ def _value_type(
+ cls,
+ ) -> tensorboard_time_series.TensorboardTimeSeries.ValueType:
+ """
+ :return: Value type of the time series.
+ """
+ pass
-class _TensorBatchedRequestSender(object):
- """Helper class for building WriteTensor() requests that fit under a size limit.
+ @abc.abstractmethod
+ def _create_data_point(
+ self,
+ run_name: str,
+ event: tf.compat.v1.Event,
+ value: tf.compat.v1.Summary.Value,
+ metadata: tf.compat.v1.SummaryMetadata,
+ ) -> tensorboard_data.TimeSeriesDataPoint:
+ """
+ Creates data point protos for sending to the OnePlatform API.
+ """
+ pass
+
+ def _validate(
+ self,
+ point: tensorboard_data.TimeSeriesDataPoint,
+ event: tf.compat.v1.Event,
+ value: tf.compat.v1.Summary.Value,
+    ) -> bool:
+        """
+        Validates a data point before it is added to the request sent to the
+        OnePlatform API. Returns True if the point should be included.
+        """
+ return True
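+
+    # Subclasses plug into the hooks above: _get_tracker supplies the progress
+    # tracker, _value_type fixes the time series value type, _create_data_point
+    # builds each point, and _validate may reject a point before this base
+    # class batches it into the write_tensorboard_experiment_data request.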
+
+
+class _ScalarBatchedRequestSender(_BaseBatchedRequestSender):
+ """Helper class for building requests that fit under a size limit.
This class accumulates a current request. `add_event(...)` may or may not
send the request (and start a new one). After all `add_event(...)` calls
are complete, a final call to `flush()` is needed to send the final request.
+
This class is not threadsafe. Use external synchronization if calling its
methods concurrently.
"""
+ _value_type = tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR
+
def __init__(
self,
- run_resource_id: str,
+ experiment_resource_id: str,
api: TensorboardServiceClient,
rpc_rate_limiter: util.RateLimiter,
max_request_size: int,
- max_tensor_point_size: int,
tracker: upload_tracker.UploadTracker,
+ one_platform_resource_manager: uploader_utils.OnePlatformResourceManager,
):
- """Constructer for _TensorBatchedRequestSender.
+ """Constructor for _ScalarBatchedRequestSender.
Args:
- run_resource_id: The resource id for the run with the following format
- projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}
+ experiment_resource_id: The resource id for the experiment with the following format
+ projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}
api: TensorboardServiceStub
            rpc_rate_limiter: util.RateLimiter to limit rate of this request sender
max_request_size: max number of bytes to send
            tracker: upload_tracker.UploadTracker for tracking upload progress.
+            one_platform_resource_manager: OnePlatformResourceManager for
+                managing One Platform run and time series resources.
"""
- self._run_resource_id = run_resource_id
- self._api = api
- self._rpc_rate_limiter = rpc_rate_limiter
- self._byte_budget_manager = _ByteBudgetManager(max_request_size)
- self._max_tensor_point_size = max_tensor_point_size
- self._tracker = tracker
-
- # cache: map from Tensorboard tag to TimeSeriesData
- # cleared whenever a new request is created
- self._tag_to_time_series_data: Dict[str, tensorboard_data.TimeSeriesData] = {}
-
- self._time_series_resource_manager = _TimeSeriesResourceManager(
- run_resource_id, api
+ super().__init__(
+ experiment_resource_id,
+ api,
+ rpc_rate_limiter,
+ max_request_size,
+ tracker,
+ one_platform_resource_manager,
)
- self._new_request()
-
- def _new_request(self):
- """Allocates a new request and refreshes the budget."""
- self._request = tensorboard_service.WriteTensorboardRunDataRequest()
- self._tag_to_time_series_data.clear()
- self._num_values = 0
- self._request.tensorboard_run = self._run_resource_id
- self._byte_budget_manager.reset(self._request)
- self._num_values = 0
- self._num_values_skipped = 0
- self._tensor_bytes = 0
- self._tensor_bytes_skipped = 0
-
- def add_event(
- self,
- event: tf.compat.v1.Event,
- value: tf.compat.v1.Summary.Value,
- metadata: tf.compat.v1.SummaryMetadata,
- ):
- """Attempts to add the given event to the current request.
- If the event cannot be added to the current request because the byte
- budget is exhausted, the request is flushed, and the event is added
- to the next request.
- """
- try:
- self._add_event_internal(event, value, metadata)
- except _OutOfSpaceError:
- self.flush()
- # Try again. This attempt should never produce OutOfSpaceError
- # because we just flushed.
- try:
- self._add_event_internal(event, value, metadata)
- except _OutOfSpaceError:
- raise RuntimeError("add_event failed despite flush")
+ def _get_tracker(self) -> ContextManager:
+ return self._tracker.scalars_tracker(self._num_values)
- def _add_event_internal(
+ def _create_data_point(
self,
+ run_name: str,
event: tf.compat.v1.Event,
value: tf.compat.v1.Summary.Value,
metadata: tf.compat.v1.SummaryMetadata,
- ):
- self._num_values += 1
- time_series_data_proto = self._tag_to_time_series_data.get(value.tag)
- if time_series_data_proto is None:
- time_series_data_proto = self._create_time_series_data(value.tag, metadata)
- self._create_point(time_series_data_proto, event, value)
-
- def flush(self):
- """Sends the active request after removing empty runs and tags.
+ ) -> tensorboard_data.TimeSeriesDataPoint:
+ scalar_proto = tensorboard_data.Scalar(
+ value=tensor_util.make_ndarray(value.tensor).item()
+ )
+ return tensorboard_data.TimeSeriesDataPoint(
+ step=event.step,
+ scalar=scalar_proto,
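+            # The float wall time is split into whole seconds plus a
+            # nanosecond remainder, as required by the protobuf Timestamp.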
+ wall_time=timestamp.Timestamp(
+ seconds=int(event.wall_time),
+ nanos=int(round((event.wall_time % 1) * 10**9)),
+ ),
+ )
- Starts a new, empty active request.
- """
- request = self._request
- request.time_series_data = list(self._tag_to_time_series_data.values())
- _prune_empty_time_series(request)
- if not request.time_series_data:
- return
- self._rpc_rate_limiter.tick()
+class _TensorBatchedRequestSender(_BaseBatchedRequestSender):
+ """Helper class for building WriteTensor() requests that fit under a size limit.
- with _request_logger(request):
- with self._tracker.tensors_tracker(
- self._num_values,
- self._num_values_skipped,
- self._tensor_bytes,
- self._tensor_bytes_skipped,
- ):
- try:
- self._api.write_tensorboard_run_data(
- tensorboard_run=self._run_resource_id,
- time_series_data=request.time_series_data,
- )
- except grpc.RpcError as e:
- if e.code() == grpc.StatusCode.NOT_FOUND:
- raise ExperimentNotFoundError()
- logger.error("Upload call failed with error %s", e)
+ This class accumulates a current request. `add_event(...)` may or may not
+ send the request (and start a new one). After all `add_event(...)` calls
+ are complete, a final call to `flush()` is needed to send the final request.
+ This class is not threadsafe. Use external synchronization if calling its
+ methods concurrently.
+ """
- self._new_request()
+ _value_type = tensorboard_time_series.TensorboardTimeSeries.ValueType.TENSOR
- def _create_time_series_data(
- self, tag_name: str, metadata: tf.compat.v1.SummaryMetadata
- ) -> tensorboard_data.TimeSeriesData:
- """Adds a time_series for the tag_name, if there's space.
+ def __init__(
+ self,
+ experiment_resource_id: str,
+ api: TensorboardServiceClient,
+ rpc_rate_limiter: util.RateLimiter,
+ max_request_size: int,
+ max_tensor_point_size: int,
+ tracker: upload_tracker.UploadTracker,
+ one_platform_resource_manager: uploader_utils.OnePlatformResourceManager,
+ ):
+ """Constructor for _TensorBatchedRequestSender.
Args:
- tag_name: String name of the tag to add (as `value.tag`).
- metadata: SummaryMetadata of the event.
-
- Returns:
- The TimeSeriesData in _request proto with the given tag name.
-
- Raises:
- _OutOfSpaceError: If adding the tag would exceed the remaining
- request budget.
+ experiment_resource_id: The resource id for the experiment with the following format
+ projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}
+ api: TensorboardServiceStub
+            rpc_rate_limiter: util.RateLimiter to limit rate of this request sender
+            max_request_size: max number of bytes to send
+            max_tensor_point_size: max number of bytes for a single tensor
+                point; larger points are skipped
+            tracker: upload_tracker.UploadTracker for tracking upload progress
+            one_platform_resource_manager: OnePlatformResourceManager for
+                managing One Platform run and time series resources
"""
- time_series_data_proto = tensorboard_data.TimeSeriesData(
- tensorboard_time_series_id=self._time_series_resource_manager.get_or_create(
- tag_name,
- lambda: tensorboard_time_series.TensorboardTimeSeries(
- display_name=tag_name,
- value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.TENSOR,
- plugin_name=metadata.plugin_data.plugin_name,
- plugin_data=metadata.plugin_data.content,
- ),
- ).name.split("/")[-1],
- value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.TENSOR,
+ super().__init__(
+ experiment_resource_id,
+ api,
+ rpc_rate_limiter,
+ max_request_size,
+ tracker,
+ one_platform_resource_manager,
)
+ self._max_tensor_point_size = max_tensor_point_size
- self._request.time_series_data.extend([time_series_data_proto])
- self._byte_budget_manager.add_time_series(time_series_data_proto)
- self._tag_to_time_series_data[tag_name] = time_series_data_proto
- return time_series_data_proto
+ def _new_request(self):
+ """Allocates a new request and refreshes the budget."""
+ super()._new_request()
+ self._num_values = 0
+ self._num_values_skipped = 0
+ self._tensor_bytes = 0
+ self._tensor_bytes_skipped = 0
- def _create_point(
+ def _get_tracker(self) -> ContextManager:
+ return self._tracker.tensors_tracker(
+ self._num_values,
+ self._num_values_skipped,
+ self._tensor_bytes,
+ self._tensor_bytes_skipped,
+ )
+
+ def _create_data_point(
self,
- time_series_proto: tensorboard_data.TimeSeriesData,
+ run_name: str,
event: tf.compat.v1.Event,
value: tf.compat.v1.Summary.Value,
- ):
- """Adds a tensor point to the given tag, if there's space.
-
- Args:
- tag_proto: `WriteTensorRequest.Tag` proto to which to add a point.
- event: Enclosing `Event` proto with the step and wall time data.
- value: Tensor `Summary.Value` proto with the actual tensor data.
-
- Raises:
- _OutOfSpaceError: If adding the point would exceed the remaining
- request budget.
- """
- point = tensorboard_data.TimeSeriesDataPoint(
+ metadata: tf.compat.v1.SummaryMetadata,
+ ) -> tensorboard_data.TimeSeriesDataPoint:
+ return tensorboard_data.TimeSeriesDataPoint(
step=event.step,
tensor=tensorboard_data.TensorboardTensor(
value=value.tensor.SerializeToString()
),
wall_time=timestamp.Timestamp(
seconds=int(event.wall_time),
- nanos=int(round((event.wall_time % 1) * 10 ** 9)),
+ nanos=int(round((event.wall_time % 1) * 10**9)),
),
)
+ def _validate(
+ self,
+ point: tensorboard_data.TimeSeriesDataPoint,
+ event: tf.compat.v1.Event,
+ value: tf.compat.v1.Summary.Value,
+    ) -> bool:
self._num_values += 1
tensor_size = len(point.tensor.value)
self._tensor_bytes += tensor_size
@@ -1059,32 +1137,19 @@ def _create_point(
)
self._num_values_skipped += 1
self._tensor_bytes_skipped += tensor_size
- return
-
- self._validate_tensor_value(
- value.tensor, value.tag, event.step, event.wall_time
- )
-
- time_series_proto.values.extend([point])
+ return False
try:
- self._byte_budget_manager.add_point(point)
- except _OutOfSpaceError:
- time_series_proto.values.pop()
- raise
-
- def _validate_tensor_value(self, tensor_proto, tag, step, wall_time):
- """Validate a TensorProto by attempting to parse it."""
- try:
- tensor_util.make_ndarray(tensor_proto)
+ tensor_util.make_ndarray(value.tensor)
except ValueError as error:
raise ValueError(
"The uploader failed to upload a tensor. This seems to be "
"due to a malformation in the tensor, which may be caused by "
"a bug in the process that wrote the tensor.\n\n"
"The tensor has tag '%s' and is at step %d and wall_time %.6f.\n\n"
- "Original error:\n%s" % (tag, step, wall_time, error)
+ "Original error:\n%s" % (value.tag, event.step, event.wall_time, error)
)
+ return True
class _ByteBudgetManager(object):
@@ -1108,7 +1173,9 @@ def __init__(self, max_bytes: int):
self._byte_budget = None # type: int
self._max_bytes = max_bytes
- def reset(self, base_request: tensorboard_service.WriteTensorboardRunDataRequest):
+ def reset(
+ self, base_request: tensorboard_service.WriteTensorboardExperimentDataRequest
+ ):
"""Resets the byte budget and calculates the cost of the base request.
Args:
@@ -1176,7 +1243,7 @@ def add_point(self, point_proto: tensorboard_data.TimeSeriesDataPoint):
self._byte_budget -= cost
-class _BlobRequestSender(object):
+class _BlobRequestSender(_BaseBatchedRequestSender):
"""Uploader for blob-type event data.
Unlike the other types, this class does not accumulate events in batches;
@@ -1187,9 +1254,11 @@ class _BlobRequestSender(object):
methods concurrently.
"""
+ _value_type = tensorboard_time_series.TensorboardTimeSeries.ValueType.BLOB_SEQUENCE
+
def __init__(
self,
- run_resource_id: str,
+ experiment_resource_id: str,
api: TensorboardServiceClient,
rpc_rate_limiter: util.RateLimiter,
max_blob_request_size: int,
@@ -1197,79 +1266,59 @@ def __init__(
blob_storage_bucket: storage.Bucket,
blob_storage_folder: str,
tracker: upload_tracker.UploadTracker,
+ one_platform_resource_manager: uploader_utils.OnePlatformResourceManager,
):
- self._run_resource_id = run_resource_id
- self._api = api
- self._rpc_rate_limiter = rpc_rate_limiter
- self._max_blob_request_size = max_blob_request_size
- self._max_blob_size = max_blob_size
- self._tracker = tracker
- self._time_series_resource_manager = _TimeSeriesResourceManager(
- run_resource_id, api
+ super().__init__(
+ experiment_resource_id,
+ api,
+ rpc_rate_limiter,
+ max_blob_request_size,
+ tracker,
+ one_platform_resource_manager,
)
-
+ self._max_blob_size = max_blob_size
self._bucket = blob_storage_bucket
self._folder = blob_storage_folder
- self._new_request()
-
def _new_request(self):
- """Declares the previous event complete."""
- self._event = None
- self._value = None
- self._metadata = None
+ super()._new_request()
+ self._blob_sizes = 0
- def add_event(
+ def _get_tracker(self) -> ContextManager:
+ return self._tracker.blob_tracker(0)
+
+ def _create_data_point(
self,
+ run_name: str,
event: tf.compat.v1.Event,
value: tf.compat.v1.Summary.Value,
metadata: tf.compat.v1.SummaryMetadata,
- ):
- """Attempts to add the given event to the current request.
-
- If the event cannot be added to the current request because the byte
- budget is exhausted, the request is flushed, and the event is added
- to the next request.
- """
- if self._value:
- raise RuntimeError("Tried to send blob while another is pending")
- self._event = event # provides step and possibly plugin_name
- self._value = value
- self._blobs = tensor_util.make_ndarray(self._value.tensor)
- if self._blobs.ndim == 1:
- self._metadata = metadata
- self.flush()
- else:
+ ) -> tensorboard_data.TimeSeriesDataPoint:
+ blobs = tensor_util.make_ndarray(value.tensor)
+ if blobs.ndim != 1:
logger.warning(
"A blob sequence must be represented as a rank-1 Tensor. "
"Provided data has rank %d, for run %s, tag %s, step %s ('%s' plugin) .",
- self._blobs.ndim,
- self._run_resource_id,
- self._value.tag,
- self._event.step,
+ blobs.ndim,
+ run_name,
+ value.tag,
+ event.step,
metadata.plugin_data.plugin_name,
)
- # Skip this upload.
- self._new_request()
-
- def flush(self):
- """Sends the current blob sequence fully, and clears it to make way for the next."""
- if not self._value:
- self._new_request()
- return
+ return None
- time_series_proto = self._time_series_resource_manager.get_or_create(
- self._value.tag,
- lambda: tensorboard_time_series.TensorboardTimeSeries(
- display_name=self._value.tag,
- value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.BLOB_SEQUENCE,
- plugin_name=self._metadata.plugin_data.plugin_name,
- plugin_data=self._metadata.plugin_data.content,
- ),
- )
m = re.match(
".*/tensorboards/(.*)/experiments/(.*)/runs/(.*)/timeSeries/(.*)",
- time_series_proto.name,
+ self._one_platform_resource_manager.get_time_series_resource_name(
+ run_name,
+ value.tag,
+ lambda: tensorboard_time_series.TensorboardTimeSeries(
+ display_name=value.tag,
+ value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.BLOB_SEQUENCE,
+ plugin_name=metadata.plugin_data.plugin_name,
+ plugin_data=metadata.plugin_data.content,
+ ),
+ ),
)
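+        # Groups 1-4 of the match are the tensorboard, experiment, run, and
+        # time series IDs; they form the GCS object path prefix below.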
blob_path_prefix = "tensorboard-{}/{}/{}/{}".format(m[1], m[2], m[3], m[4])
blob_path_prefix = (
@@ -1278,16 +1327,15 @@ def flush(self):
else blob_path_prefix
)
sent_blob_ids = []
- for blob in self._blobs:
- self._rpc_rate_limiter.tick()
+ for blob in blobs:
with self._tracker.blob_tracker(len(blob)) as blob_tracker:
blob_id = self._send_blob(blob, blob_path_prefix)
if blob_id is not None:
sent_blob_ids.append(str(blob_id))
- blob_tracker.mark_uploaded(blob_id is not None)
+ blob_tracker.mark_uploaded(blob_id is not None)
- data_point = tensorboard_data.TimeSeriesDataPoint(
- step=self._event.step,
+ return tensorboard_data.TimeSeriesDataPoint(
+ step=event.step,
blobs=tensorboard_data.TensorboardBlobSequence(
values=[
tensorboard_data.TensorboardBlob(id=blob_id)
@@ -1295,37 +1343,11 @@ def flush(self):
]
),
wall_time=timestamp.Timestamp(
- seconds=int(self._event.wall_time),
- nanos=int(round((self._event.wall_time % 1) * 10 ** 9)),
+ seconds=int(event.wall_time),
+ nanos=int(round((event.wall_time % 1) * 10**9)),
),
)
- time_series_data_proto = tensorboard_data.TimeSeriesData(
- tensorboard_time_series_id=time_series_proto.name.split("/")[-1],
- value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.BLOB_SEQUENCE,
- values=[data_point],
- )
- request = tensorboard_service.WriteTensorboardRunDataRequest(
- time_series_data=[time_series_data_proto]
- )
-
- _prune_empty_time_series(request)
- if not request.time_series_data:
- return
-
- with _request_logger(request):
- try:
- self._api.write_tensorboard_run_data(
- tensorboard_run=self._run_resource_id,
- time_series_data=request.time_series_data,
- )
- except grpc.RpcError as e:
- if e.code() == grpc.StatusCode.NOT_FOUND:
- raise ExperimentNotFoundError()
- logger.error("Upload call failed with error %s", e)
-
- self._new_request()
-
def _send_blob(self, blob, blob_path_prefix):
"""Sends a single blob to a GCS bucket in the consumer project.
@@ -1350,19 +1372,6 @@ def _send_blob(self, blob, blob_path_prefix):
return blob_id
-@contextlib.contextmanager
-def _request_logger(request: tensorboard_service.WriteTensorboardRunDataRequest):
- """Context manager to log request size and duration."""
- upload_start_time = time.time()
- request_bytes = request._pb.ByteSize() # pylint: disable=protected-access
- logger.info("Trying request of %d bytes", request_bytes)
- yield
- upload_duration_secs = time.time() - upload_start_time
- logger.info(
- "Upload of (%d bytes) took %.3f seconds", request_bytes, upload_duration_secs,
- )
-
-
def _varint_cost(n: int):
"""Computes the size of `n` encoded as an unsigned base-128 varint.
@@ -1433,7 +1442,8 @@ def _filtered_graph_bytes(graph_bytes: bytes):
# a combination of mysterious circumstances.
except (message.DecodeError, RuntimeWarning):
logger.warning(
- "Could not parse GraphDef of size %d. Skipping.", len(graph_bytes),
+ "Could not parse GraphDef of size %d. Skipping.",
+ len(graph_bytes),
)
return None
# Use the default filter parameters:
diff --git a/google/cloud/aiplatform/tensorboard/uploader_main.py b/google/cloud/aiplatform/tensorboard/uploader_main.py
index ebd4aa5147..e6adb9cbe2 100644
--- a/google/cloud/aiplatform/tensorboard/uploader_main.py
+++ b/google/cloud/aiplatform/tensorboard/uploader_main.py
@@ -28,8 +28,11 @@
from tensorboard.plugins.image import metadata as images_metadata
from tensorboard.plugins.graph import metadata as graphs_metadata
+from google.api_core import exceptions
from google.cloud import storage
from google.cloud import aiplatform
+from google.cloud.aiplatform.constants import base as constants
+from google.cloud.aiplatform import jobs
from google.cloud.aiplatform.tensorboard import uploader
from google.cloud.aiplatform.utils import TensorboardClientWithOverride
@@ -89,14 +92,15 @@ def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
- aiplatform.constants.API_BASE_PATH = FLAGS.api_uri
+ constants.API_BASE_PATH = FLAGS.api_uri
m = re.match(
"projects/(.*)/locations/(.*)/tensorboards/.*", FLAGS.tensorboard_resource_name
)
project_id = m[1]
region = m[2]
api_client = aiplatform.initializer.global_config.create_client(
- client_class=TensorboardClientWithOverride, location_override=region,
+ client_class=TensorboardClientWithOverride,
+ location_override=region,
)
try:
@@ -123,9 +127,14 @@ def main(argv):
exitcode=0,
)
+ experiment_name = FLAGS.experiment_name
+ experiment_display_name = get_experiment_display_name_with_override(
+ experiment_name, FLAGS.experiment_display_name, project_id, region
+ )
+
tb_uploader = uploader.TensorBoardUploader(
- experiment_name=FLAGS.experiment_name,
- experiment_display_name=FLAGS.experiment_display_name,
+ experiment_name=experiment_name,
+ experiment_display_name=experiment_display_name,
tensorboard_resource_name=tensorboard.name,
blob_storage_bucket=blob_storage_bucket,
blob_storage_folder=blob_storage_folder,
@@ -146,10 +155,22 @@ def main(argv):
tb_uploader.get_experiment_resource_name().replace("/", "+"),
)
)
- if FLAGS.one_shot:
- tb_uploader._upload_once() # pylint: disable=protected-access
- else:
- tb_uploader.start_uploading()
+ tb_uploader.start_uploading()
+
+
+def get_experiment_display_name_with_override(
+ experiment_name, experiment_display_name, project_id, region
+):
+    """Returns the CustomJob display name when the experiment name is a
+    numeric CustomJob ID and no display name override is given; otherwise
+    returns the provided display name unchanged."""
+    if experiment_name.isdecimal() and not experiment_display_name:
+ try:
+ return jobs.CustomJob.get(
+ resource_name=experiment_name,
+ project=project_id,
+ location=region,
+ ).display_name
+ except exceptions.NotFound:
+ return experiment_display_name
+ return experiment_display_name
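+
+
+# Illustrative behavior: a purely numeric experiment name such as
+# "1234567890" is treated as a CustomJob ID; if that job exists, its display
+# name is used as the experiment display name.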
def flags_parser(args):
diff --git a/google/cloud/aiplatform/tensorboard/uploader_utils.py b/google/cloud/aiplatform/tensorboard/uploader_utils.py
new file mode 100644
index 0000000000..849a12fef6
--- /dev/null
+++ b/google/cloud/aiplatform/tensorboard/uploader_utils.py
@@ -0,0 +1,480 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Shared utils for tensorboard log uploader."""
+import abc
+import contextlib
+import json
+import logging
+import re
+import time
+from typing import Callable, Dict, Generator, Optional, List, Tuple
+import uuid
+
+from tensorboard.util import tb_logging
+
+from google.api_core import exceptions
+from google.cloud import storage
+from google.cloud.aiplatform.compat.types import tensorboard_run
+from google.cloud.aiplatform.compat.types import tensorboard_service
+from google.cloud.aiplatform.compat.types import tensorboard_time_series
+from google.cloud.aiplatform.compat.services import tensorboard_service_client
+
+TensorboardServiceClient = tensorboard_service_client.TensorboardServiceClient
+
+logger = tb_logging.get_logger()
+logger.setLevel(logging.WARNING)
+
+
+class ExistingResourceNotFoundError(RuntimeError):
+ """Resource could not be created or retrieved."""
+
+
+class RequestSender(object):
+ """A base class for additional request sender objects.
+
+ Currently just used for typing.
+ """
+
+ @abc.abstractmethod
+    def send_requests(self, run_name: str):
+ """Sends any request for the run."""
+ pass
+
+
+class OnePlatformResourceManager(object):
+ """Helper class managing One Platform resources."""
+
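+    # Typical flow (illustrative): batch_create_runs and
+    # batch_create_time_series pre-populate the name caches; afterwards
+    # get_run_resource_name and get_time_series_resource_name resolve (and
+    # lazily create) individual resources.
+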
+ CREATE_RUN_BATCH_SIZE = 1000
+ CREATE_TIME_SERIES_BATCH_SIZE = 1000
+
+ def __init__(self, experiment_resource_name: str, api: TensorboardServiceClient):
+ """Constructor for OnePlatformResourceManager.
+
+ Args:
+ experiment_resource_name (str):
+                Required. The resource name for the experiment with the following format
+ projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}
+ api (TensorboardServiceClient):
+ Required. TensorboardServiceStub for calling various tensorboard services.
+ """
+ self._experiment_resource_name = experiment_resource_name
+ self._api = api
+ self._run_name_to_run_resource_name: Dict[str, str] = {}
+        self._run_tag_name_to_time_series_name: Dict[Tuple[str, str], str] = {}
+
+ def batch_create_runs(
+ self, run_names: List[str]
+ ) -> List[tensorboard_run.TensorboardRun]:
+ """Batch creates TensorboardRuns.
+
+ Args:
+ run_names: a list of run_names for creating the TensorboardRuns.
+ Returns:
+ the created TensorboardRuns
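+
+        Example (illustrative):
+            manager.batch_create_runs(["train", "eval"])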
+ """
+ batch_size = OnePlatformResourceManager.CREATE_RUN_BATCH_SIZE
+ created_runs = []
+ for i in range(0, len(run_names), batch_size):
+ one_batch_run_names = run_names[i : i + batch_size]
+ tb_run_requests = [
+ tensorboard_service.CreateTensorboardRunRequest(
+ parent=self._experiment_resource_name,
+ tensorboard_run=tensorboard_run.TensorboardRun(
+ display_name=run_name
+ ),
+ tensorboard_run_id=str(uuid.uuid4()),
+ )
+ for run_name in one_batch_run_names
+ ]
+
+ tb_runs = self._api.batch_create_tensorboard_runs(
+ parent=self._experiment_resource_name,
+ requests=tb_run_requests,
+ ).tensorboard_runs
+
+ self._run_name_to_run_resource_name.update(
+ {run.display_name: run.name for run in tb_runs}
+ )
+
+ created_runs.extend(tb_runs)
+
+ return created_runs
+
+ def batch_create_time_series(
+ self,
+ run_tag_name_to_time_series: Dict[
+ Tuple[str, str], tensorboard_time_series.TensorboardTimeSeries
+ ],
+ ) -> List[tensorboard_time_series.TensorboardTimeSeries]:
+ """Batch creates TensorboardTimeSeries.
+
+ Args:
+ run_tag_name_to_time_series: a dictionary of
+ (run_name, tag_name) to TensorboardTimeSeries proto, containing
+ the TensorboardTimeSeries to create.
+ Returns:
+ the created TensorboardTimeSeries
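+
+        Example (illustrative; assumes the "train" run was created via
+        batch_create_runs):
+            manager.batch_create_time_series({
+                ("train", "loss"): tensorboard_time_series.TensorboardTimeSeries(
+                    display_name="loss",
+                    value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR,
+                ),
+            })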
+ """
+ batch_size = OnePlatformResourceManager.CREATE_TIME_SERIES_BATCH_SIZE
+ run_tag_name_to_time_series_entries = list(run_tag_name_to_time_series.items())
+ run_resource_name_to_run_name = {
+ v: k for k, v in self._run_name_to_run_resource_name.items()
+ }
+ created_time_series = []
+ for i in range(0, len(run_tag_name_to_time_series_entries), batch_size):
+ requests = [
+ tensorboard_service.CreateTensorboardTimeSeriesRequest(
+ parent=self._run_name_to_run_resource_name[run_name],
+ tensorboard_time_series=time_series,
+ )
+ for (
+ (run_name, tag_name),
+ time_series,
+ ) in run_tag_name_to_time_series_entries[i : i + batch_size]
+ ]
+
+ time_series = self._api.batch_create_tensorboard_time_series(
+ parent=self._experiment_resource_name,
+ requests=requests,
+ ).tensorboard_time_series
+
+ self._run_tag_name_to_time_series_name.update(
+ {
+ (
+ run_resource_name_to_run_name[
+ ts.name[: ts.name.index("/timeSeries")]
+ ],
+ ts.display_name,
+ ): ts.name
+ for ts in time_series
+ }
+ )
+
+ created_time_series.extend(time_series)
+
+ return created_time_series
+
+ def get_run_resource_name(self, run_name: str) -> str:
+ """
+ Get the resource name of the run if it exists, otherwise creates the run
+ on One Platform before returning its resource name.
+
+ Args:
+ run_name (str):
+ Required. The name of the run.
+
+ Returns:
+ run_resource (str):
+ Resource name of the run.
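+
+        Example (illustrative):
+            manager.get_run_resource_name("train")  # creates the run if needed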
+ """
+ if run_name not in self._run_name_to_run_resource_name:
+ tb_run = self._create_or_get_run_resource(run_name)
+ self._run_name_to_run_resource_name[run_name] = tb_run.name
+ return self._run_name_to_run_resource_name[run_name]
+
+ def _create_or_get_run_resource(
+ self, run_name: str
+ ) -> tensorboard_run.TensorboardRun:
+ """Creates a new run resource in current tensorboard experiment resource.
+
+ Args:
+ run_name (str):
+ Required. The display name of this run.
+
+ Returns:
+ tb_run (tensorboard_run.TensorboardRun):
+ The TensorboardRun given the run_name.
+
+ Raises:
+ ExistingResourceNotFoundError:
+ Run name could not be found in resource list.
+ exceptions.InvalidArgument:
+ run_name argument is invalid.
+ """
+ tb_run = tensorboard_run.TensorboardRun()
+ tb_run.display_name = run_name
+ try:
+ tb_run = self._api.create_tensorboard_run(
+ parent=self._experiment_resource_name,
+ tensorboard_run=tb_run,
+ tensorboard_run_id=str(uuid.uuid4()),
+ )
+ except exceptions.InvalidArgument as e:
+ # If the run name already exists then retrieve it
+ if "already exist" in e.message:
+ runs_pages = self._api.list_tensorboard_runs(
+ parent=self._experiment_resource_name
+ )
+ for tb_run in runs_pages:
+ if tb_run.display_name == run_name:
+ break
+
+ if tb_run.display_name != run_name:
+ raise ExistingResourceNotFoundError(
+ "Run with name %s already exists but is not resource list."
+ % run_name
+ )
+ else:
+ raise
+ return tb_run
+
+ def get_time_series_resource_name(
+ self,
+ run_name: str,
+ tag_name: str,
+ time_series_resource_creator: Callable[
+ [], tensorboard_time_series.TensorboardTimeSeries
+ ],
+ ) -> str:
+ """
+ Get the resource name of the time series corresponding to the tag, if it
+ exists, otherwise creates the time series on One Platform before
+ returning its resource name.
+
+ Args:
+ run_name (str):
+ Required. The name of the run.
+ tag_name (str):
+ Required. The name of the tag.
+            time_series_resource_creator (Callable[[], tensorboard_time_series.TensorboardTimeSeries]):
+ Required. A constructor used for creating the time series on One Platform.
+
+ Returns:
+ time_series_name (str):
+ Resource name of the time series
+ """
+ if (run_name, tag_name) not in self._run_tag_name_to_time_series_name:
+ time_series = self._create_or_get_time_series(
+ self.get_run_resource_name(run_name),
+ tag_name,
+ time_series_resource_creator,
+ )
+ self._run_tag_name_to_time_series_name[
+ (run_name, tag_name)
+ ] = time_series.name
+ return self._run_tag_name_to_time_series_name[(run_name, tag_name)]
+
+ def _create_or_get_time_series(
+ self,
+ run_resource_name: str,
+ tag_name: str,
+ time_series_resource_creator: Callable[
+ [], tensorboard_time_series.TensorboardTimeSeries
+ ],
+ ) -> tensorboard_time_series.TensorboardTimeSeries:
+ """
+ Get a time series resource with given tag_name, and create a new one on
+ OnePlatform if not present.
+
+ Args:
+            run_resource_name (str):
+                Required. The resource name of the run that owns the time series.
+            tag_name (str):
+                Required. The tag name of the time series in the Tensorboard log dir.
+            time_series_resource_creator (Callable[[], tensorboard_time_series.TensorboardTimeSeries]):
+ Required. A callable that produces a TimeSeries for creation.
+
+ Returns:
+ time_series (tensorboard_time_series.TensorboardTimeSeries):
+ A created or existing tensorboard_time_series.TensorboardTimeSeries.
+
+ Raises:
+ exceptions.InvalidArgument:
+ Invalid run_resource_name, tag_name, or time_series_resource_creator.
+ ExistingResourceNotFoundError:
+ Could not find the resource given the tag name.
+ ValueError:
+ More than one time series with the resource name was found.
+ """
+ time_series = time_series_resource_creator()
+ time_series.display_name = tag_name
+ try:
+ time_series = self._api.create_tensorboard_time_series(
+ parent=run_resource_name, tensorboard_time_series=time_series
+ )
+ except exceptions.InvalidArgument as e:
+ # If the time series display name already exists then retrieve it
+ if "already exist" in e.message:
+ list_of_time_series = self._api.list_tensorboard_time_series(
+ request=tensorboard_service.ListTensorboardTimeSeriesRequest(
+ parent=run_resource_name,
+ filter="display_name = {}".format(json.dumps(str(tag_name))),
+ )
+ )
+ num = 0
+ time_series = None
+
+ for ts in list_of_time_series:
+ num += 1
+ if num > 1:
+ break
+ time_series = ts
+
+ if not time_series:
+ raise ExistingResourceNotFoundError(
+ "Could not find time series resource with display name: {}".format(
+ tag_name
+ )
+ )
+
+ if num != 1:
+ raise ValueError(
+ "More than one time series resource found with display_name: {}".format(
+ tag_name
+ )
+ )
+ else:
+ raise
+ return time_series
+
+
+class TimeSeriesResourceManager(object):
+ """Helper class managing Time Series resources."""
+
+ def __init__(self, run_resource_id: str, api: TensorboardServiceClient):
+ """Constructor for TimeSeriesResourceManager.
+
+ Args:
+ run_resource_id (str):
+ Required. The resource id for the run with the following format.
+ projects/{project}/locations/{location}/tensorboards/{tensorboard}/experiments/{experiment}/runs/{run}
+ api (TensorboardServiceClient):
+ Required. A TensorboardServiceStub.
+ """
+ self._run_resource_id = run_resource_id
+ self._api = api
+ self._tag_to_time_series_proto: Dict[
+ str, tensorboard_time_series.TensorboardTimeSeries
+ ] = {}
+
+ def get_or_create(
+ self,
+ tag_name: str,
+ time_series_resource_creator: Callable[
+ [], tensorboard_time_series.TensorboardTimeSeries
+ ],
+ ) -> tensorboard_time_series.TensorboardTimeSeries:
+ """
+ Get a time series resource with given tag_name, and create a new one on
+ OnePlatform if not present.
+
+ Args:
+ tag_name (str):
+ Required. The tag name of the time series in the Tensorboard log dir.
+ time_series_resource_creator (Callable[[], tensorboard_time_series.TensorboardTimeSeries]):
+ Required. A callable that produces a TimeSeries for creation.
+
+ Returns:
+ time_series (tensorboard_time_series.TensorboardTimeSeries):
+ A new or existing tensorboard_time_series.TensorboardTimeSeries.
+
+ Raises:
+ exceptions.InvalidArgument:
+ The tag_name or time_series_resource_creator is an invalid argument
+ to create_tensorboard_time_series api call.
+ ExistingResourceNotFoundError:
+ Could not find the resource given the tag name.
+ ValueError:
+ More than one time series with the resource name was found.
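+
+        Example (illustrative):
+            manager.get_or_create(
+                "loss",
+                lambda: tensorboard_time_series.TensorboardTimeSeries(
+                    display_name="loss",
+                    value_type=tensorboard_time_series.TensorboardTimeSeries.ValueType.SCALAR,
+                ),
+            )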
+ """
+ if tag_name in self._tag_to_time_series_proto:
+ return self._tag_to_time_series_proto[tag_name]
+
+ time_series = time_series_resource_creator()
+ time_series.display_name = tag_name
+ try:
+ time_series = self._api.create_tensorboard_time_series(
+ parent=self._run_resource_id, tensorboard_time_series=time_series
+ )
+ except exceptions.InvalidArgument as e:
+ # If the time series display name already exists then retrieve it
+ if "already exist" in e.message:
+ list_of_time_series = self._api.list_tensorboard_time_series(
+ request=tensorboard_service.ListTensorboardTimeSeriesRequest(
+ parent=self._run_resource_id,
+ filter="display_name = {}".format(json.dumps(str(tag_name))),
+ )
+ )
+ num = 0
+ time_series = None
+
+ for ts in list_of_time_series:
+ num += 1
+ if num > 1:
+ break
+ time_series = ts
+
+ if not time_series:
+ raise ExistingResourceNotFoundError(
+ "Could not find time series resource with display name: {}".format(
+ tag_name
+ )
+ )
+
+ if num != 1:
+ raise ValueError(
+ "More than one time series resource found with display_name: {}".format(
+ tag_name
+ )
+ )
+ else:
+ raise
+
+ self._tag_to_time_series_proto[tag_name] = time_series
+ return time_series
+
+
+def get_source_bucket(logdir: str) -> Optional[storage.Bucket]:
+ """Returns a storage bucket object given a log directory.
+
+ Args:
+ logdir (str):
+ Required. Path of the log directory.
+
+ Returns:
+ bucket (Optional[storage.Bucket]):
+ A bucket if the path is a gs bucket, None otherwise.
+ """
+    """
+    m = re.match(r"gs:\/\/(.*?)(?=\/|$)", logdir)
+ if not m:
+ return None
+ bucket = storage.Client().bucket(m[1])
+ return bucket
+
+
+@contextlib.contextmanager
+def request_logger(
+ request: tensorboard_service.WriteTensorboardRunDataRequest,
+) -> Generator[None, None, None]:
+ """Context manager to log request size and duration.
+
+ Args:
+ request (tensorboard_service.WriteTensorboardRunDataRequest):
+ Required. A request object that provides the size of the request.
+
+ Yields:
+ An empty response when the request logger has started.
+ """
+ upload_start_time = time.time()
+ request_bytes = request._pb.ByteSize() # pylint: disable=protected-access
+ logger.info("Trying request of %d bytes", request_bytes)
+ yield
+ upload_duration_secs = time.time() - upload_start_time
+ logger.info(
+ "Upload of (%d bytes) took %.3f seconds",
+ request_bytes,
+ upload_duration_secs,
+ )
diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py
index 91e061f4ba..b2a93d952c 100644
--- a/google/cloud/aiplatform/training_jobs.py
+++ b/google/cloud/aiplatform/training_jobs.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+import datetime
import time
from typing import Dict, List, Optional, Sequence, Tuple, Union
@@ -22,12 +23,14 @@
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base
-from google.cloud.aiplatform import constants
+from google.cloud.aiplatform.constants import base as constants
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import models
+from google.cloud.aiplatform import jobs
from google.cloud.aiplatform import schema
from google.cloud.aiplatform import utils
+from google.cloud.aiplatform.utils import console_utils
from google.cloud.aiplatform.compat.types import (
env_var as gca_env_var,
@@ -39,12 +42,14 @@
from google.cloud.aiplatform.utils import _timestamped_gcs_dir
from google.cloud.aiplatform.utils import source_utils
from google.cloud.aiplatform.utils import worker_spec_utils
+from google.cloud.aiplatform.utils import column_transformations_utils
from google.cloud.aiplatform.v1.schema.trainingjob import (
definition_v1 as training_job_inputs,
)
from google.rpc import code_pb2
+from google.rpc import status_pb2
import proto
@@ -60,22 +65,33 @@
]
)
+# _block_until_complete wait times
+_JOB_WAIT_TIME = 5 # start at five seconds
+_LOG_WAIT_TIME = 5
+_MAX_WAIT_TIME = 60 * 5 # 5 minute wait
+_WAIT_TIME_MULTIPLIER = 2 # scale wait by 2 every iteration
-class _TrainingJob(base.VertexAiResourceNounWithFutureManager):
+
+class _TrainingJob(base.VertexAiStatefulResource):
client_class = utils.PipelineClientWithOverride
- _is_client_prediction_client = False
_resource_noun = "trainingPipelines"
_getter_method = "get_training_pipeline"
_list_method = "list_training_pipelines"
_delete_method = "delete_training_pipeline"
+ _parse_resource_name_method = "parse_training_pipeline_path"
+ _format_resource_name_method = "training_pipeline_path"
+
+ # Required by the done() method
+ _valid_done_states = _PIPELINE_COMPLETE_STATES
def __init__(
self,
- display_name: str,
+ display_name: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
+ labels: Optional[Dict[str, str]] = None,
training_encryption_spec_key_name: Optional[str] = None,
model_encryption_spec_key_name: Optional[str] = None,
):
@@ -83,7 +99,7 @@ def __init__(
Args:
display_name (str):
- Required. The user-defined name of this TrainingPipeline.
+ Optional. The user-defined name of this TrainingPipeline.
project (str):
Optional project to retrieve model from. If not set, project set in
aiplatform.init will be used.
@@ -92,6 +108,16 @@ def __init__(
aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional credentials to use to retrieve the model.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize TrainingPipelines.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
training_encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the training pipeline. Has the
@@ -118,10 +144,15 @@ def __init__(
Overrides encryption_spec_key_name set in aiplatform.init.
"""
+ if not display_name:
+ display_name = self.__class__._generate_display_name()
utils.validate_display_name(display_name)
+ if labels:
+ utils.validate_labels(labels)
super().__init__(project=project, location=location, credentials=credentials)
self._display_name = display_name
+ self._labels = labels
self._training_encryption_spec = initializer.global_config.get_encryption_spec(
encryption_spec_key_name=training_encryption_spec_key_name
)
@@ -135,9 +166,30 @@ def __init__(
@abc.abstractmethod
def _supported_training_schemas(cls) -> Tuple[str]:
"""List of supported schemas for this training job."""
-
pass
+ @property
+ def start_time(self) -> Optional[datetime.datetime]:
+ """Time when the TrainingJob entered the `PIPELINE_STATE_RUNNING` for
+ the first time."""
+ self._sync_gca_resource()
+ return getattr(self._gca_resource, "start_time")
+
+ @property
+ def end_time(self) -> Optional[datetime.datetime]:
+ """Time when the TrainingJob resource entered the `PIPELINE_STATE_SUCCEEDED`,
+        `PIPELINE_STATE_FAILED`, or `PIPELINE_STATE_CANCELLED` state."""
+ self._sync_gca_resource()
+ return getattr(self._gca_resource, "end_time")
+
+ @property
+ def error(self) -> Optional[status_pb2.Status]:
+ """Detailed error info for this TrainingJob resource. Only populated when
+ the TrainingJob's state is `PIPELINE_STATE_FAILED` or
+ `PIPELINE_STATE_CANCELLED`."""
+ self._sync_gca_resource()
+ return getattr(self._gca_resource, "error")
+
@classmethod
def get(
cls,
@@ -152,10 +204,10 @@ def get(
resource_name (str):
Required. A fully-qualified resource name or ID.
project (str):
- Optional project to retrieve dataset from. If not set, project
+ Optional project to retrieve training job from. If not set, project
set in aiplatform.init will be used.
location (str):
- Optional location to retrieve dataset from. If not set, location
+ Optional location to retrieve training job from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Custom credentials to use to upload this model. Overrides
@@ -166,7 +218,7 @@ def get(
doesn't match the custom training task definition.
Returns:
- An Vertex AI Training Job
+ A Vertex AI Training Job
"""
# Create job with dummy parameters
@@ -193,11 +245,68 @@ def get(
return self
+ @classmethod
+ def _get_and_return_subclass(
+ cls,
+ resource_name: str,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+ ) -> "_TrainingJob":
+ """Retrieve Training Job subclass for the given resource_name without
+ knowing the training_task_definition.
+
+ Example usage:
+ ```
+ aiplatform.training_jobs._TrainingJob._get_and_return_subclass(
+ 'projects/.../locations/.../trainingPipelines/12345'
+ )
+        # Returns an instance of the matching subclass, e.g. a
+        # CustomTrainingJob or an AutoML training job.
+ ```
+
+ Args:
+ resource_name (str):
+ Required. A fully-qualified resource name or ID.
+            project (str):
+                Optional project to retrieve training job from. If not set, project
+                set in aiplatform.init will be used.
+            location (str):
+                Optional location to retrieve training job from. If not set, location
+                set in aiplatform.init will be used.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to upload this model. Overrides
+ credentials set in aiplatform.init.
+
+ Returns:
+ A Vertex AI Training Job
+ """
+
+ # Retrieve training pipeline resource before class construction
+ client = cls._instantiate_client(location=location, credentials=credentials)
+
+ gca_training_pipeline = getattr(client, cls._getter_method)(name=resource_name)
+
+ schema_uri = gca_training_pipeline.training_task_definition
+
+ # Collect all AutoML training job classes and CustomTrainingJob
+ class_list = [
+ c for c in cls.__subclasses__() if c.__name__.startswith("AutoML")
+ ] + [CustomTrainingJob]
+
+ # Identify correct training job subclass, construct and return object
+ for c in class_list:
+ if schema_uri in c._supported_training_schemas:
+ return c._empty_constructor(
+ project=project,
+ location=location,
+ credentials=credentials,
+ resource_name=resource_name,
+ )
+
@property
@abc.abstractmethod
def _model_upload_fail_string(self) -> str:
"""Helper property for model upload failure."""
-
pass
@abc.abstractmethod
@@ -212,10 +321,14 @@ def run(self) -> Optional[models.Model]:
def _create_input_data_config(
dataset: Optional[datasets._Dataset] = None,
annotation_schema_uri: Optional[str] = None,
- training_fraction_split: float = 0.8,
- validation_fraction_split: float = 0.1,
- test_fraction_split: float = 0.1,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ training_filter_split: Optional[str] = None,
+ validation_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
predefined_split_column_name: Optional[str] = None,
+ timestamp_split_column_name: Optional[str] = None,
gcs_destination_uri_prefix: Optional[str] = None,
bigquery_destination: Optional[str] = None,
) -> Optional[gca_training_pipeline.InputDataConfig]:
@@ -233,7 +346,7 @@ def _create_input_data_config(
annotation_schema_uri (str):
Google Cloud Storage URI points to a YAML file describing
annotation schema. The schema is defined as an OpenAPI 3.0.2
- [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#schema-object) The schema files
+ [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) The schema files
that can be used here are found in
gs://google-cloud-aiplatform/schema/dataset/annotation/,
note that the chosen schema must be consistent with
@@ -253,17 +366,35 @@ def _create_input_data_config(
and
``annotation_schema_uri``.
training_fraction_split (float):
- The fraction of the input data that is to be
- used to train the Model. This is ignored if Dataset is not provided.
- training_fraction_split (float):
- The fraction of the input data that is to be
- used to train the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
validation_fraction_split (float):
- The fraction of the input data that is to be
- used to validate the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to validate
+ the Model. This is ignored if Dataset is not provided.
test_fraction_split (float):
- The fraction of the input data that is to be
- used to evaluate the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ training_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to train the Model. A filter with same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ validation_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to validate the Model. A filter with same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ test_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to test the Model. A filter with same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
predefined_split_column_name (str):
Optional. The key is a name of one of the Dataset's data
columns. The value of the key (either the label's value or
@@ -274,6 +405,17 @@ def _create_input_data_config(
ignored by the pipeline.
Supported only for tabular and time series Datasets.
+ timestamp_split_column_name (str):
+ Optional. The key is a name of one of the Dataset's data
+                columns. The values of the key (the values in
+ the column) must be in RFC 3339 `date-time` format, where
+ `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
+ piece of data the key is not present or has an invalid value,
+ that piece is ignored by the pipeline.
+
+ Supported only for tabular and time series Datasets.
+ This parameter must be used with training_fraction_split,
+ validation_fraction_split, and test_fraction_split.
gcs_destination_uri_prefix (str):
Optional. The Google Cloud Storage location.
@@ -300,33 +442,97 @@ def _create_input_data_config(
- AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training"
- AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
- AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
+ Raises:
+ ValueError: When more than 1 type of split configuration is passed or when
+ the split configuration passed is incompatible with the dataset schema.
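+
+        Example (illustrative; ``ds`` stands for a hypothetical tabular Dataset):
+            _create_input_data_config(
+                dataset=ds,
+                training_fraction_split=0.8,
+                validation_fraction_split=0.1,
+                test_fraction_split=0.1,
+            )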
"""
input_data_config = None
if dataset:
- # Create fraction split spec
- fraction_split = gca_training_pipeline.FractionSplit(
- training_fraction=training_fraction_split,
- validation_fraction=validation_fraction_split,
- test_fraction=test_fraction_split,
- )
-
- # Create predefined split spec
+ # Initialize all possible splits
+ filter_split = None
predefined_split = None
- if predefined_split_column_name:
- if dataset._gca_resource.metadata_schema_uri not in (
- schema.dataset.metadata.tabular,
- schema.dataset.metadata.time_series,
+ timestamp_split = None
+ fraction_split = None
+
+ # Create filter split
+ if any(
+ [
+ training_filter_split is not None,
+ validation_filter_split is not None,
+ test_filter_split is not None,
+ ]
+ ):
+ if all(
+ [
+ training_filter_split is not None,
+ validation_filter_split is not None,
+ test_filter_split is not None,
+ ]
):
+ filter_split = gca_training_pipeline.FilterSplit(
+ training_filter=training_filter_split,
+ validation_filter=validation_filter_split,
+ test_filter=test_filter_split,
+ )
+ else:
raise ValueError(
- "A pre-defined split may only be used with a tabular or time series Dataset"
+ "All filter splits must be passed together or not at all"
)
+ # Create predefined split
+ if predefined_split_column_name:
predefined_split = gca_training_pipeline.PredefinedSplit(
key=predefined_split_column_name
)
- # Create GCS destination
+ # Create timestamp split or fraction split
+ if timestamp_split_column_name:
+ timestamp_split = gca_training_pipeline.TimestampSplit(
+ training_fraction=training_fraction_split,
+ validation_fraction=validation_fraction_split,
+ test_fraction=test_fraction_split,
+ key=timestamp_split_column_name,
+ )
+ elif any(
+ [
+ training_fraction_split is not None,
+ validation_fraction_split is not None,
+ test_fraction_split is not None,
+ ]
+ ):
+ fraction_split = gca_training_pipeline.FractionSplit(
+ training_fraction=training_fraction_split,
+ validation_fraction=validation_fraction_split,
+ test_fraction=test_fraction_split,
+ )
+
+ splits = [
+ split
+ for split in [
+ filter_split,
+ predefined_split,
+                    timestamp_split,
+ fraction_split,
+ ]
+ if split is not None
+ ]
+
+            # If no split is specified, the service will apply its default split
+ if len(splits) == 0:
+ _LOGGER.info(
+ "No dataset split provided. The service will use a default split."
+ )
+ elif len(splits) > 1:
+ raise ValueError(
+ """Can only specify one of:
+ 1. training_filter_split, validation_filter_split, test_filter_split
+ 2. predefined_split_column_name
+ 3. timestamp_split_column_name, training_fraction_split, validation_fraction_split, test_fraction_split
+ 4. training_fraction_split, validation_fraction_split, test_fraction_split"""
+ )
+
+        # Create GCS destination
gcs_destination = None
if gcs_destination_uri_prefix:
gcs_destination = gca_io.GcsDestination(
@@ -343,7 +549,9 @@ def _create_input_data_config(
# create input data config
input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=fraction_split,
+ filter_split=filter_split,
predefined_split=predefined_split,
+ timestamp_split=timestamp_split,
dataset_id=dataset.name,
annotation_schema_uri=annotation_schema_uri,
gcs_destination=gcs_destination,
@@ -357,14 +565,19 @@ def _run_job(
training_task_definition: str,
training_task_inputs: Union[dict, proto.Message],
dataset: Optional[datasets._Dataset],
- training_fraction_split: float,
- validation_fraction_split: float,
- test_fraction_split: float,
- annotation_schema_uri: Optional[str] = None,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ training_filter_split: Optional[str] = None,
+ validation_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
predefined_split_column_name: Optional[str] = None,
+ timestamp_split_column_name: Optional[str] = None,
+ annotation_schema_uri: Optional[str] = None,
model: Optional[gca_model.Model] = None,
gcs_destination_uri_prefix: Optional[str] = None,
bigquery_destination: Optional[str] = None,
+ create_request_timeout: Optional[float] = None,
) -> Optional[models.Model]:
"""Runs the training job.
@@ -392,19 +605,10 @@ def _run_job(
[google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
For tabular Datasets, all their data is exported to
training, to pick and choose from.
- training_fraction_split (float):
- The fraction of the input data that is to be
- used to train the Model. This is ignored if Dataset is not provided.
- validation_fraction_split (float):
- The fraction of the input data that is to be
- used to validate the Model. This is ignored if Dataset is not provided.
- test_fraction_split (float):
- The fraction of the input data that is to be
- used to evaluate the Model. This is ignored if Dataset is not provided.
annotation_schema_uri (str):
Google Cloud Storage URI points to a YAML file describing
annotation schema. The schema is defined as an OpenAPI 3.0.2
- [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#schema-object) The schema files
+ [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) The schema files
that can be used here are found in
gs://google-cloud-aiplatform/schema/dataset/annotation/,
note that the chosen schema must be consistent with
@@ -423,6 +627,36 @@ def _run_job(
``annotations_filter``
and
``annotation_schema_uri``.
+ training_fraction_split (float):
+ Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
+ validation_fraction_split (float):
+ Optional. The fraction of the input data that is to be used to validate
+ the Model. This is ignored if Dataset is not provided.
+ test_fraction_split (float):
+ Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ training_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to train the Model. A filter with same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ validation_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to validate the Model. A filter with same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ test_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to test the Model. A filter with same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
predefined_split_column_name (str):
Optional. The key is a name of one of the Dataset's data
columns. The value of the key (either the label's value or
@@ -433,6 +667,17 @@ def _run_job(
ignored by the pipeline.
Supported only for tabular and time series Datasets.
+ timestamp_split_column_name (str):
+ Optional. The key is a name of one of the Dataset's data
+                columns. The values of the key (the values in the
+                column) must be in RFC 3339 `date-time` format, where
+ `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
+ piece of data the key is not present or has an invalid value,
+ that piece is ignored by the pipeline.
+
+ Supported only for tabular and time series Datasets.
+ This parameter must be used with training_fraction_split,
+ validation_fraction_split, and test_fraction_split.
model (~.model.Model):
Optional. Describes the Model that may be uploaded (via
[ModelService.UploadMode][]) by this TrainingPipeline. The
@@ -479,6 +724,8 @@ def _run_job(
- AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training"
- AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
- AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
"""
input_data_config = self._create_input_data_config(
@@ -487,7 +734,11 @@ def _run_job(
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
+ training_filter_split=training_filter_split,
+ validation_filter_split=validation_filter_split,
+ test_filter_split=test_filter_split,
predefined_split_column_name=predefined_split_column_name,
+ timestamp_split_column_name=timestamp_split_column_name,
gcs_destination_uri_prefix=gcs_destination_uri_prefix,
bigquery_destination=bigquery_destination,
)
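
As context for reviewers, a minimal sketch of how the new split arguments flow in from a public `run()` call; the job class, script, container image, and dataset resource name below are illustrative placeholders, not part of this change:

```python
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")

ds = aiplatform.TabularDataset(
    "projects/my-project/locations/us-central1/datasets/123"  # placeholder
)
job = aiplatform.CustomTrainingJob(
    display_name="split-demo",
    script_path="task.py",
    container_uri="gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest",
)

# Fraction splits must sum to at most 1; any remainder is assigned by
# Vertex AI. A timestamp split points at a column of RFC 3339 date-times
# such as "1985-04-12T23:20:50.52Z".
model = job.run(
    dataset=ds,
    training_fraction_split=0.8,
    validation_fraction_split=0.1,
    test_fraction_split=0.1,
    # timestamp_split_column_name="ingest_time",  # used with the fractions above
)
```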
@@ -499,6 +750,7 @@ def _run_job(
training_task_inputs=training_task_inputs,
model_to_upload=model,
input_data_config=input_data_config,
+ labels=self._labels,
encryption_spec=self._training_encryption_spec,
)
@@ -507,6 +759,7 @@ def _run_job(
self.project, self.location
),
training_pipeline=training_pipeline,
+ timeout=create_request_timeout,
)
self._gca_resource = training_pipeline
@@ -594,7 +847,7 @@ def _get_model(self) -> Optional[models.Model]:
"""Helper method to get and instantiate the Model to Upload.
Returns:
- model: Vertex AI Model if training succeeded and produced an Vertex AI
+ model: Vertex AI Model if training succeeded and produced a Vertex AI
Model. None otherwise.
Raises:
@@ -611,24 +864,19 @@ def _get_model(self) -> Optional[models.Model]:
return None
if self._gca_resource.model_to_upload.name:
- fields = utils.extract_fields_from_resource_name(
- self._gca_resource.model_to_upload.name
- )
+ return models.Model(model_name=self._gca_resource.model_to_upload.name)
- return models.Model(
- fields.id, project=fields.project, location=fields.location,
- )
+ def _wait_callback(self):
+ """Callback performs custom logging during _block_until_complete. Override in subclass."""
+ pass
def _block_until_complete(self):
"""Helper method to block and check on job until complete."""
- # Used these numbers so failures surface fast
- wait = 5 # start at five seconds
- log_wait = 5
- max_wait = 60 * 5 # 5 minute wait
- multiplier = 2 # scale wait by 2 every iteration
+ log_wait = _LOG_WAIT_TIME
previous_time = time.time()
+
while self.state not in _PIPELINE_COMPLETE_STATES:
current_time = time.time()
if current_time - previous_time >= log_wait:
@@ -640,9 +888,10 @@ def _block_until_complete(self):
self._gca_resource.state,
)
)
- log_wait = min(log_wait * multiplier, max_wait)
+ log_wait = min(log_wait * _WAIT_TIME_MULTIPLIER, _MAX_WAIT_TIME)
previous_time = current_time
- time.sleep(wait)
+ self._wait_callback()
+ time.sleep(_JOB_WAIT_TIME)
self._raise_failure()
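
The loop above polls at a fixed interval while growing the logging interval geometrically up to a cap. A self-contained sketch of the same pattern; the constant values are assumptions, since `_JOB_WAIT_TIME`, `_LOG_WAIT_TIME`, `_MAX_WAIT_TIME`, and `_WAIT_TIME_MULTIPLIER` are defined outside this hunk:

```python
import time

JOB_WAIT_TIME = 5          # assumed: seconds between state polls
LOG_WAIT_TIME = 5          # assumed: initial seconds between log lines
MAX_WAIT_TIME = 60 * 5     # assumed: cap on the logging interval
WAIT_TIME_MULTIPLIER = 2   # assumed: geometric growth factor

def block_until_done(get_state, done_states):
    """Poll get_state() until it returns a terminal state."""
    log_wait = LOG_WAIT_TIME
    previous_time = time.time()
    while get_state() not in done_states:
        current_time = time.time()
        if current_time - previous_time >= log_wait:
            print("still running...")
            # Log less and less often, up to the cap.
            log_wait = min(log_wait * WAIT_TIME_MULTIPLIER, MAX_WAIT_TIME)
            previous_time = current_time
        time.sleep(JOB_WAIT_TIME)
```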
@@ -675,16 +924,10 @@ def has_failed(self) -> bool:
def _dashboard_uri(self) -> str:
"""Helper method to compose the dashboard uri where training can be
viewed."""
- fields = utils.extract_fields_from_resource_name(self.resource_name)
- url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/training/{fields.id}?project={fields.project}"
+ fields = self._parse_resource_name(self.resource_name)
+ url = f"https://console.cloud.google.com/ai/platform/locations/{fields['location']}/training/{fields['training_pipeline']}?project={fields['project']}"
return url
- def _sync_gca_resource(self):
- """Helper method to sync the local gca_source against the service."""
- self._gca_resource = self.api_client.get_training_pipeline(
- name=self.resource_name
- )
-
@property
def _has_run(self) -> bool:
"""Helper property to check if this training job has been run."""
@@ -771,6 +1014,10 @@ def cancel(self) -> None:
)
self.api_client.cancel_training_pipeline(name=self.resource_name)
+ def wait_for_resource_creation(self) -> None:
+ """Waits until resource has been created."""
+ self._wait_for_resource_creation()
+
class _CustomTrainingJob(_TrainingJob):
"""ABC for Custom Training Pipelines.."""
@@ -779,6 +1026,7 @@ class _CustomTrainingJob(_TrainingJob):
def __init__(
self,
+ # TODO(b/223262536): Make display_name parameter fully optional in next major release
display_name: str,
container_uri: str,
model_serving_container_image_uri: Optional[str] = None,
@@ -795,6 +1043,7 @@ def __init__(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
+ labels: Optional[Dict[str, str]] = None,
training_encryption_spec_key_name: Optional[str] = None,
model_encryption_spec_key_name: Optional[str] = None,
staging_bucket: Optional[str] = None,
@@ -899,6 +1148,16 @@ def __init__(
credentials (auth_credentials.Credentials):
Custom credentials to use to run call training service. Overrides
credentials set in aiplatform.init.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize TrainingPipelines.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
training_encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the training pipeline. Has the
@@ -928,11 +1187,14 @@ def __init__(
Bucket used to stage source and training artifacts. Overrides
staging_bucket set in aiplatform.init.
"""
+ if not display_name:
+ display_name = self.__class__._generate_display_name()
super().__init__(
display_name=display_name,
project=project,
location=location,
credentials=credentials,
+ labels=labels,
training_encryption_spec_key_name=training_encryption_spec_key_name,
model_encryption_spec_key_name=model_encryption_spec_key_name,
)
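
A short, hypothetical example of the labels constraint documented above (keys and values at most 64 characters; lowercase letters, digits, underscores, and dashes only):

```python
from google.cloud import aiplatform

job = aiplatform.CustomTrainingJob(
    display_name="labeled-job",  # placeholder
    script_path="task.py",       # placeholder
    container_uri="gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest",
    labels={"team": "forecasting", "env": "dev"},  # lowercase and short: valid
)
```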
@@ -997,13 +1259,40 @@ def __init__(
"set using aiplatform.init(staging_bucket='gs://my-bucket')"
)
+        # The backing Custom Job resource is not known until after data preprocessing.
+        # Once the Custom Job is known we log the console uri and the tensorboard uri;
+        # this flag keeps that state so we don't log them multiple times.
+ self._has_logged_custom_job = False
+ self._logged_web_access_uris = set()
+
+ @property
+ def network(self) -> Optional[str]:
+ """The full name of the Google Compute Engine
+ [network](https://cloud.google.com/vpc/docs/vpc#networks) to which this
+ CustomTrainingJob should be peered.
+
+ Takes the format `projects/{project}/global/networks/{network}`. Where
+ {project} is a project number, as in `12345`, and {network} is a network name.
+
+ Private services access must already be configured for the network. If left
+ unspecified, the CustomTrainingJob is not peered with any network.
+ """
+ # Return `network` value in training task inputs if set in Map
+ self._assert_gca_resource_is_available()
+ return self._gca_resource.training_task_inputs.get("network")
+
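
The stored value follows the Compute Engine full resource format described in the property docstring; a small sketch with a placeholder project number and network name:

```python
project_number = "12345"  # numeric project number, not the project id
network_name = "my-vpc"
network = f"projects/{project_number}/global/networks/{network_name}"

# After run(..., network=network) creates the resource, the peering can
# be read back via the property above:
# assert job.network == network
```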
def _prepare_and_validate_run(
self,
model_display_name: Optional[str] = None,
- replica_count: int = 0,
+ model_labels: Optional[Dict[str, str]] = None,
+ replica_count: int = 1,
machine_type: str = "n1-standard-4",
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
accelerator_count: int = 0,
+ boot_disk_type: str = "pd-ssd",
+ boot_disk_size_gb: int = 100,
+ reduction_server_replica_count: int = 0,
+ reduction_server_machine_type: Optional[str] = None,
) -> Tuple[worker_spec_utils._DistributedTrainingSpec, Optional[gca_model.Model]]:
"""Create worker pool specs and managed model as well validating the
run.
@@ -1015,6 +1304,16 @@ def _prepare_and_validate_run(
of any UTF-8 characters.
If not provided upon creation, the job's display_name is used.
+ model_labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
replica_count (int):
The number of worker replicas. If replica count = 1 then one chief
replica will be provisioned. If replica_count > 1 the remainder will be
@@ -1027,6 +1326,17 @@ def _prepare_and_validate_run(
NVIDIA_TESLA_T4
accelerator_count (int):
The number of accelerators to attach to a worker replica.
+ boot_disk_type (str):
+ Type of the boot disk, default is `pd-ssd`.
+ Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
+ `pd-standard` (Persistent Disk Hard Disk Drive).
+ boot_disk_size_gb (int):
+ Size in GB of the boot disk, default is 100GB.
+                    Boot disk size must be within the range of [100, 64000].
+ reduction_server_replica_count (int):
+ The number of reduction server replicas, default is 0.
+ reduction_server_machine_type (str):
+ Optional. The type of machine to use for reduction server.
Returns:
Worker pools specs and managed model for run.
@@ -1054,17 +1364,28 @@ def _prepare_and_validate_run(
model_display_name = model_display_name or self._display_name + "-model"
# validates args and will raise
- worker_pool_specs = worker_spec_utils._DistributedTrainingSpec.chief_worker_pool(
- replica_count=replica_count,
- machine_type=machine_type,
- accelerator_count=accelerator_count,
- accelerator_type=accelerator_type,
- ).pool_specs
+ worker_pool_specs = (
+ worker_spec_utils._DistributedTrainingSpec.chief_worker_pool(
+ replica_count=replica_count,
+ machine_type=machine_type,
+ accelerator_count=accelerator_count,
+ accelerator_type=accelerator_type,
+ boot_disk_type=boot_disk_type,
+ boot_disk_size_gb=boot_disk_size_gb,
+ reduction_server_replica_count=reduction_server_replica_count,
+ reduction_server_machine_type=reduction_server_machine_type,
+ ).pool_specs
+ )
managed_model = self._managed_model
if model_display_name:
utils.validate_display_name(model_display_name)
managed_model.display_name = model_display_name
+ if model_labels:
+ utils.validate_labels(model_labels)
+ managed_model.labels = model_labels
+ else:
+ managed_model.labels = self._labels
else:
managed_model = None
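
To make the chief/worker arithmetic concrete: with `replica_count=10` the helper builds one chief pool plus a nine-replica worker pool, all sharing the machine shape and boot disk settings, with an optional reduction-server pool appended. The dict layout below is an illustrative sketch, not the exact proto:

```python
replica_count = 10
chief_pool = {
    "machine_type": "n1-standard-4",
    "replica_count": 1,
    "boot_disk_type": "pd-ssd",
    "boot_disk_size_gb": 100,
}
worker_pool = dict(chief_pool, replica_count=replica_count - 1)  # 9 workers
worker_pool_specs = [chief_pool, worker_pool]
# With reduction_server_replica_count > 0, a third pool would be appended
# using reduction_server_machine_type.
```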
@@ -1076,6 +1397,10 @@ def _prepare_training_task_inputs_and_output_dir(
base_output_dir: Optional[str] = None,
service_account: Optional[str] = None,
network: Optional[str] = None,
+ timeout: Optional[int] = None,
+ restart_job_on_worker_restart: bool = False,
+ enable_web_access: bool = False,
+ tensorboard: Optional[str] = None,
) -> Tuple[Dict, str]:
"""Prepares training task inputs and output directory for custom job.
@@ -1093,6 +1418,32 @@ def _prepare_training_task_inputs_and_output_dir(
should be peered. For example, projects/12345/global/networks/myVPC.
Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
+ timeout (int):
+ The maximum job running time in seconds. The default is 7 days.
+ restart_job_on_worker_restart (bool):
+ Restarts the entire CustomJob if a worker
+ gets restarted. This feature can be used by
+ distributed training jobs that are not resilient
+ to workers leaving and joining a job.
+ enable_web_access (bool):
+ Whether you want Vertex AI to enable interactive shell access
+ to training containers.
+ https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
+ tensorboard (str):
+ Optional. The name of a Vertex AI
+ [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
+ resource to which this CustomJob will upload Tensorboard
+ logs. Format:
+ ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
+
+            The training script should write TensorBoard logs to the path set in the
+            following Vertex AI environment variable:
+
+ AIP_TENSORBOARD_LOG_DIR
+
+ `service_account` is required with provided `tensorboard`.
+ For more information on configuring your service account please visit:
+ https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
Returns:
Training task inputs and Output directory for custom job.
"""
@@ -1113,9 +1464,93 @@ def _prepare_training_task_inputs_and_output_dir(
training_task_inputs["service_account"] = service_account
if network:
training_task_inputs["network"] = network
+ if tensorboard:
+ training_task_inputs["tensorboard"] = tensorboard
+ if enable_web_access:
+ training_task_inputs["enable_web_access"] = enable_web_access
+
+ if timeout or restart_job_on_worker_restart:
+ timeout = f"{timeout}s" if timeout else None
+ scheduling = {
+ "timeout": timeout,
+ "restart_job_on_worker_restart": restart_job_on_worker_restart,
+ }
+ training_task_inputs["scheduling"] = scheduling
return training_task_inputs, base_output_dir
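
Concretely, passing `timeout=3600` and `restart_job_on_worker_restart=True` makes the block above emit the following entry (the timeout is serialized as a duration string):

```python
training_task_inputs["scheduling"] = {
    "timeout": "3600s",  # f"{timeout}s"
    "restart_job_on_worker_restart": True,
}
```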
+ @property
+ def web_access_uris(self) -> Dict[str, str]:
+ """Get the web access uris of the backing custom job.
+
+ Returns:
+ (Dict[str, str]):
+ Web access uris of the backing custom job.
+ """
+ web_access_uris = dict()
+ if (
+ self._gca_resource.training_task_metadata
+ and self._gca_resource.training_task_metadata.get("backingCustomJob")
+ ):
+ custom_job_resource_name = self._gca_resource.training_task_metadata.get(
+ "backingCustomJob"
+ )
+ custom_job = jobs.CustomJob.get(resource_name=custom_job_resource_name)
+
+ web_access_uris = dict(custom_job.web_access_uris)
+
+ return web_access_uris
+
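
A usage sketch for the property above: keys are worker identifiers and values are interactive-shell URIs, populated only once the backing custom job exists and the job was started with `enable_web_access=True`:

```python
# job was started with enable_web_access=True and sync=False
job.wait_for_resource_creation()
for worker, uri in job.web_access_uris.items():
    print(f"{worker}: {uri}")
```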
+ def _log_web_access_uris(self):
+ """Helper method to log the web access uris of the backing custom job"""
+ for worker, uri in self.web_access_uris.items():
+ if uri not in self._logged_web_access_uris:
+ _LOGGER.info(
+ "%s %s access the interactive shell terminals for the backing custom job:\n%s:\n%s"
+ % (
+ self.__class__.__name__,
+ self._gca_resource.name,
+ worker,
+ uri,
+ ),
+ )
+ self._logged_web_access_uris.add(uri)
+
+ def _wait_callback(self):
+ if (
+ self._gca_resource.training_task_metadata
+ and self._gca_resource.training_task_metadata.get("backingCustomJob")
+ and not self._has_logged_custom_job
+ ):
+ _LOGGER.info(f"View backing custom job:\n{self._custom_job_console_uri()}")
+
+ if self._gca_resource.training_task_inputs.get("tensorboard"):
+ _LOGGER.info(f"View tensorboard:\n{self._tensorboard_console_uri()}")
+
+ self._has_logged_custom_job = True
+
+ if self._gca_resource.training_task_inputs.get("enable_web_access"):
+ self._log_web_access_uris()
+
+ def _custom_job_console_uri(self) -> str:
+ """Helper method to compose the dashboard uri where custom job can be viewed."""
+ custom_job_resource_name = self._gca_resource.training_task_metadata.get(
+ "backingCustomJob"
+ )
+ return console_utils.custom_job_console_uri(custom_job_resource_name)
+
+ def _tensorboard_console_uri(self) -> str:
+ """Helper method to compose dashboard uri where tensorboard can be viewed."""
+ tensorboard_resource_name = self._gca_resource.training_task_inputs.get(
+ "tensorboard"
+ )
+ custom_job_resource_name = self._gca_resource.training_task_metadata.get(
+ "backingCustomJob"
+ )
+ return console_utils.custom_job_tensorboard_console_uri(
+ tensorboard_resource_name, custom_job_resource_name
+ )
+
@property
def _model_upload_fail_string(self) -> str:
"""Helper property for model upload failure."""
@@ -1128,533 +1563,560 @@ def _model_upload_fail_string(self) -> str:
)
-# TODO(b/172368325) add scheduling, custom_job.Scheduling
-class CustomTrainingJob(_CustomTrainingJob):
- """Class to launch a Custom Training Job in Vertex AI using a script.
+class _ForecastingTrainingJob(_TrainingJob):
+ """ABC for Forecasting Training Pipelines."""
- Takes a training implementation as a python script and executes that
- script in Cloud Vertex AI Training.
- """
+ _supported_training_schemas = tuple()
def __init__(
self,
- display_name: str,
- script_path: str,
- container_uri: str,
- requirements: Optional[Sequence[str]] = None,
- model_serving_container_image_uri: Optional[str] = None,
- model_serving_container_predict_route: Optional[str] = None,
- model_serving_container_health_route: Optional[str] = None,
- model_serving_container_command: Optional[Sequence[str]] = None,
- model_serving_container_args: Optional[Sequence[str]] = None,
- model_serving_container_environment_variables: Optional[Dict[str, str]] = None,
- model_serving_container_ports: Optional[Sequence[int]] = None,
- model_description: Optional[str] = None,
- model_instance_schema_uri: Optional[str] = None,
- model_parameters_schema_uri: Optional[str] = None,
- model_prediction_schema_uri: Optional[str] = None,
+ display_name: Optional[str] = None,
+ optimization_objective: Optional[str] = None,
+ column_specs: Optional[Dict[str, str]] = None,
+ column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
+ labels: Optional[Dict[str, str]] = None,
training_encryption_spec_key_name: Optional[str] = None,
model_encryption_spec_key_name: Optional[str] = None,
- staging_bucket: Optional[str] = None,
):
- """Constructs a Custom Training Job from a Python script.
-
- job = aiplatform.CustomTrainingJob(
- display_name='test-train',
- script_path='test_script.py',
- requirements=['pandas', 'numpy'],
- container_uri='gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest',
- model_serving_container_image_uri='gcr.io/my-trainer/serving:1',
- model_serving_container_predict_route='predict',
- model_serving_container_health_route='metadata)
-
- Usage with Dataset:
+ """Constructs a Forecasting Training Job.
- ds = aiplatform.TabularDataset(
- 'projects/my-project/locations/us-central1/datasets/12345')
-
- job.run(ds, replica_count=1, model_display_name='my-trained-model')
+ Args:
+ display_name (str):
+ Optional. The user-defined name of this TrainingPipeline.
+ optimization_objective (str):
+ Optional. Objective function the model is to be optimized towards.
+ The training process creates a Model that optimizes the value of the objective
+ function over the validation set. The supported optimization objectives:
+ "minimize-rmse" (default) - Minimize root-mean-squared error (RMSE).
+ "minimize-mae" - Minimize mean-absolute error (MAE).
+ "minimize-rmsle" - Minimize root-mean-squared log error (RMSLE).
+ "minimize-rmspe" - Minimize root-mean-squared percentage error (RMSPE).
+ "minimize-wape-mae" - Minimize the combination of weighted absolute percentage error (WAPE)
+ and mean-absolute-error (MAE).
+ "minimize-quantile-loss" - Minimize the quantile loss at the defined quantiles.
+ (Set this objective to build quantile forecasts.)
+ column_specs (Dict[str, str]):
+ Optional. Alternative to column_transformations where the keys of the dict
+ are column names and their respective values are one of
+ AutoMLTabularTrainingJob.column_data_types.
+ When creating transformation for BigQuery Struct column, the column
+ should be flattened using "." as the delimiter. Only columns with no child
+ should have a transformation.
+ If an input column has no transformations on it, such a column is
+ ignored by the training, except for the targetColumn, which should have
+                no transformations defined on it.
+ Only one of column_transformations or column_specs should be passed.
+ column_transformations (List[Dict[str, Dict[str, str]]]):
+ Optional. Transformations to apply to the input columns (i.e. columns other
+ than the targetColumn). Each transformation may produce multiple
+ result values from the column's value, and all are used for training.
+ When creating transformation for BigQuery Struct column, the column
+ should be flattened using "." as the delimiter. Only columns with no child
+ should have a transformation.
+ If an input column has no transformations on it, such a column is
+ ignored by the training, except for the targetColumn, which should have
+                no transformations defined on it.
+ Only one of column_transformations or column_specs should be passed.
+ Consider using column_specs as column_transformations will be deprecated eventually.
+ project (str):
+ Optional. Project to run training in. Overrides project set in aiplatform.init.
+ location (str):
+ Optional. Location to run training in. Overrides location set in aiplatform.init.
+ credentials (auth_credentials.Credentials):
+ Optional. Custom credentials to use to run call training service. Overrides
+ credentials set in aiplatform.init.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize TrainingPipelines.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
+ training_encryption_spec_key_name (Optional[str]):
+ Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the training pipeline. Has the
+ form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+ The key needs to be in the same region as where the compute
+ resource is created.
+ If set, this TrainingPipeline will be secured by this key.
+ Note: Model trained by this TrainingPipeline is also secured
+ by this key if ``model_to_upload`` is not set separately.
+ Overrides encryption_spec_key_name set in aiplatform.init.
+ model_encryption_spec_key_name (Optional[str]):
+ Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the model. Has the
+ form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+ The key needs to be in the same region as where the compute
+ resource is created.
+ If set, the trained Model will be secured by this key.
+ Overrides encryption_spec_key_name set in aiplatform.init.
+ Raises:
+ ValueError: If both column_transformations and column_specs were provided.
+ """
+ super().__init__(
+ display_name=display_name,
+ project=project,
+ location=location,
+ credentials=credentials,
+ labels=labels,
+ training_encryption_spec_key_name=training_encryption_spec_key_name,
+ model_encryption_spec_key_name=model_encryption_spec_key_name,
+ )
- Usage without Dataset:
+ self._column_transformations = (
+ column_transformations_utils.validate_and_get_column_transformations(
+ column_specs,
+ column_transformations,
+ )
+ )
- job.run(replica_count=1, model_display_name='my-trained-model)
+ self._optimization_objective = optimization_objective
+ self._additional_experiments = []
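
A hedged construction example for the objective and `column_specs` options documented above, using the concrete `AutoMLForecastingTrainingJob` subclass; the column names and data types are placeholders:

```python
from google.cloud import aiplatform

job = aiplatform.AutoMLForecastingTrainingJob(
    display_name="demand-forecast",
    optimization_objective="minimize-quantile-loss",  # enables quantile forecasts
    column_specs={
        "date": "timestamp",
        "sales": "numeric",
        "store_id": "categorical",
    },
    labels={"env": "dev"},
)
```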
+ @property
+ @classmethod
+ @abc.abstractmethod
+ def _model_type(cls) -> str:
+ """The type of forecasting model."""
+ pass
- TODO(b/169782082) add documentation about traning utilities
- To ensure your model gets saved in Vertex AI, write your saved model to
- os.environ["AIP_MODEL_DIR"] in your provided training script.
+ @property
+ @classmethod
+ @abc.abstractmethod
+ def _training_task_definition(cls) -> str:
+ """A GCS path to the YAML file that defines the training task.
+ The definition files that can be used here are found in
+ gs://google-cloud-aiplatform/schema/trainingjob/definition/.
+ """
+ pass
- Args:
- display_name (str):
- Required. The user-defined name of this TrainingPipeline.
- script_path (str): Required. Local path to training script.
- container_uri (str):
- Required: Uri of the training container image in the GCR.
- requirements (Sequence[str]):
- List of python packages dependencies of script.
- model_serving_container_image_uri (str):
- If the training produces a managed Vertex AI Model, the URI of the
- Model serving container suitable for serving the model produced by the
- training script.
- model_serving_container_predict_route (str):
- If the training produces a managed Vertex AI Model, An HTTP path to
- send prediction requests to the container, and which must be supported
- by it. If not specified a default HTTP path will be used by Vertex AI.
- model_serving_container_health_route (str):
- If the training produces a managed Vertex AI Model, an HTTP path to
- send health check requests to the container, and which must be supported
- by it. If not specified a standard HTTP path will be used by AI
- Platform.
- model_serving_container_command (Sequence[str]):
- The command with which the container is run. Not executed within a
- shell. The Docker image's ENTRYPOINT is used if this is not provided.
- Variable references $(VAR_NAME) are expanded using the container's
- environment. If a variable cannot be resolved, the reference in the
- input string will be unchanged. The $(VAR_NAME) syntax can be escaped
- with a double $$, ie: $$(VAR_NAME). Escaped references will never be
- expanded, regardless of whether the variable exists or not.
- model_serving_container_args (Sequence[str]):
- The arguments to the command. The Docker image's CMD is used if this is
- not provided. Variable references $(VAR_NAME) are expanded using the
- container's environment. If a variable cannot be resolved, the reference
- in the input string will be unchanged. The $(VAR_NAME) syntax can be
- escaped with a double $$, ie: $$(VAR_NAME). Escaped references will
- never be expanded, regardless of whether the variable exists or not.
- model_serving_container_environment_variables (Dict[str, str]):
- The environment variables that are to be present in the container.
- Should be a dictionary where keys are environment variable names
- and values are environment variable values for those names.
- model_serving_container_ports (Sequence[int]):
- Declaration of ports that are exposed by the container. This field is
- primarily informational, it gives Vertex AI information about the
- network connections the container uses. Listing or not a port here has
- no impact on whether the port is actually exposed, any port listening on
- the default "0.0.0.0" address inside a container will be accessible from
- the network.
- model_description (str):
- The description of the Model.
- model_instance_schema_uri (str):
- Optional. Points to a YAML file stored on Google Cloud
- Storage describing the format of a single instance, which
- are used in
- ``PredictRequest.instances``,
- ``ExplainRequest.instances``
- and
- ``BatchPredictionJob.input_config``.
- The schema is defined as an OpenAPI 3.0.2 `Schema
- Object `__.
- AutoML Models always have this field populated by AI
- Platform. Note: The URI given on output will be immutable
- and probably different, including the URI scheme, than the
- one given on input. The output URI will point to a location
- where the user only has a read access.
- model_parameters_schema_uri (str):
- Optional. Points to a YAML file stored on Google Cloud
- Storage describing the parameters of prediction and
- explanation via
- ``PredictRequest.parameters``,
- ``ExplainRequest.parameters``
- and
- ``BatchPredictionJob.model_parameters``.
- The schema is defined as an OpenAPI 3.0.2 `Schema
- Object `__.
- AutoML Models always have this field populated by AI
- Platform, if no parameters are supported it is set to an
- empty string. Note: The URI given on output will be
- immutable and probably different, including the URI scheme,
- than the one given on input. The output URI will point to a
- location where the user only has a read access.
- model_prediction_schema_uri (str):
- Optional. Points to a YAML file stored on Google Cloud
- Storage describing the format of a single prediction
- produced by this Model, which are returned via
- ``PredictResponse.predictions``,
- ``ExplainResponse.explanations``,
- and
- ``BatchPredictionJob.output_config``.
- The schema is defined as an OpenAPI 3.0.2 `Schema
- Object `__.
- AutoML Models always have this field populated by AI
- Platform. Note: The URI given on output will be immutable
- and probably different, including the URI scheme, than the
- one given on input. The output URI will point to a location
- where the user only has a read access.
- project (str):
- Project to run training in. Overrides project set in aiplatform.init.
- location (str):
- Location to run training in. Overrides location set in aiplatform.init.
- credentials (auth_credentials.Credentials):
- Custom credentials to use to run call training service. Overrides
- credentials set in aiplatform.init.
- training_encryption_spec_key_name (Optional[str]):
- Optional. The Cloud KMS resource identifier of the customer
- managed encryption key used to protect the training pipeline. Has the
- form:
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
- The key needs to be in the same region as where the compute
- resource is created.
-
- If set, this TrainingPipeline will be secured by this key.
-
- Note: Model trained by this TrainingPipeline is also secured
- by this key if ``model_to_upload`` is not set separately.
-
- Overrides encryption_spec_key_name set in aiplatform.init.
- model_encryption_spec_key_name (Optional[str]):
- Optional. The Cloud KMS resource identifier of the customer
- managed encryption key used to protect the model. Has the
- form:
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
- The key needs to be in the same region as where the compute
- resource is created.
-
- If set, the trained Model will be secured by this key.
-
- Overrides encryption_spec_key_name set in aiplatform.init.
- staging_bucket (str):
- Bucket used to stage source and training artifacts. Overrides
- staging_bucket set in aiplatform.init.
- """
- super().__init__(
- display_name=display_name,
- project=project,
- location=location,
- credentials=credentials,
- training_encryption_spec_key_name=training_encryption_spec_key_name,
- model_encryption_spec_key_name=model_encryption_spec_key_name,
- container_uri=container_uri,
- model_instance_schema_uri=model_instance_schema_uri,
- model_parameters_schema_uri=model_parameters_schema_uri,
- model_prediction_schema_uri=model_prediction_schema_uri,
- model_serving_container_environment_variables=model_serving_container_environment_variables,
- model_serving_container_ports=model_serving_container_ports,
- model_serving_container_image_uri=model_serving_container_image_uri,
- model_serving_container_command=model_serving_container_command,
- model_serving_container_args=model_serving_container_args,
- model_serving_container_predict_route=model_serving_container_predict_route,
- model_serving_container_health_route=model_serving_container_health_route,
- model_description=model_description,
- staging_bucket=staging_bucket,
- )
-
- self._requirements = requirements
- self._script_path = script_path
-
- # TODO(b/172365904) add filter split, training_pipeline.FilterSplit
- # TODO(b/172368070) add timestamp split, training_pipeline.TimestampSplit
def run(
self,
- dataset: Optional[
- Union[
- datasets.ImageDataset,
- datasets.TabularDataset,
- datasets.TextDataset,
- datasets.VideoDataset,
- ]
- ] = None,
- annotation_schema_uri: Optional[str] = None,
- model_display_name: Optional[str] = None,
- base_output_dir: Optional[str] = None,
- service_account: Optional[str] = None,
- network: Optional[str] = None,
- bigquery_destination: Optional[str] = None,
- args: Optional[List[Union[str, float, int]]] = None,
- environment_variables: Optional[Dict[str, str]] = None,
- replica_count: int = 0,
- machine_type: str = "n1-standard-4",
- accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
- accelerator_count: int = 0,
- training_fraction_split: float = 0.8,
- validation_fraction_split: float = 0.1,
- test_fraction_split: float = 0.1,
+ dataset: datasets.TimeSeriesDataset,
+ target_column: str,
+ time_column: str,
+ time_series_identifier_column: str,
+ unavailable_at_forecast_columns: List[str],
+ available_at_forecast_columns: List[str],
+ forecast_horizon: int,
+ data_granularity_unit: str,
+ data_granularity_count: int,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
predefined_split_column_name: Optional[str] = None,
- sync=True,
- ) -> Optional[models.Model]:
- """Runs the custom training job.
-
- Distributed Training Support:
- If replica count = 1 then one chief replica will be provisioned. If
- replica_count > 1 the remainder will be provisioned as a worker replica pool.
- ie: replica_count = 10 will result in 1 chief and 9 workers
- All replicas have same machine_type, accelerator_type, and accelerator_count
+ timestamp_split_column_name: Optional[str] = None,
+ weight_column: Optional[str] = None,
+ time_series_attribute_columns: Optional[List[str]] = None,
+ context_window: Optional[int] = None,
+ export_evaluated_data_items: bool = False,
+ export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
+ export_evaluated_data_items_override_destination: bool = False,
+ quantiles: Optional[List[float]] = None,
+ validation_options: Optional[str] = None,
+ budget_milli_node_hours: int = 1000,
+ model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
+ additional_experiments: Optional[List[str]] = None,
+ hierarchy_group_columns: Optional[List[str]] = None,
+ hierarchy_group_total_weight: Optional[float] = None,
+ hierarchy_temporal_total_weight: Optional[float] = None,
+ hierarchy_group_temporal_total_weight: Optional[float] = None,
+ window_column: Optional[str] = None,
+ window_stride_length: Optional[int] = None,
+ window_max_count: Optional[int] = None,
+ holiday_regions: Optional[List[str]] = None,
+ sync: bool = True,
+ create_request_timeout: Optional[float] = None,
+ ) -> models.Model:
+ """Runs the training job and returns a model.
- Data fraction splits:
- Any of ``training_fraction_split``, ``validation_fraction_split`` and
- ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If
- the provided ones sum to less than 1, the remainder is assigned to sets as
- decided by Vertex AI.If none of the fractions are set, by default roughly 80%
- of data will be used for training, 10% for validation, and 10% for test.
+ If training on a Vertex AI dataset, you can use one of the following split configurations:
+ Data fraction splits:
+ Any of ``training_fraction_split``, ``validation_fraction_split`` and
+            ``test_fraction_split`` may optionally be provided; together they must sum to at most 1. If
+ the provided ones sum to less than 1, the remainder is assigned to sets as
+ decided by Vertex AI. If none of the fractions are set, by default roughly 80%
+ of data will be used for training, 10% for validation, and 10% for test.
+
+ Predefined splits:
+ Assigns input data to training, validation, and test sets based on the value of a provided key.
+ If using predefined splits, ``predefined_split_column_name`` must be provided.
+ Supported only for tabular Datasets.
+
+ Timestamp splits:
+ Assigns input data to training, validation, and test sets
+            based on the provided timestamps. The youngest data pieces are
+            assigned to the training set, the next to the validation set, and the oldest
+            to the test set.
+ Supported only for tabular Datasets.
Args:
- dataset (
- Union[
- datasets.ImageDataset,
- datasets.TabularDataset,
- datasets.TextDataset,
- datasets.VideoDataset,
- ]
- ):
- Vertex AI to fit this training against. Custom training script should
- retrieve datasets through passed in environment variables uris:
+ dataset (datasets.TimeSeriesDataset):
+ Required. The dataset within the same Project from which data will be used to train the Model. The
+ Dataset must use schema compatible with Model being trained,
+ and what is compatible should be described in the used
+ TrainingPipeline's [training_task_definition]
+ [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
+ For time series Datasets, all their data is exported to
+ training, to pick and choose from.
+ target_column (str):
+ Required. Name of the column that the Model is to predict values for. This
+ column must be unavailable at forecast.
+ time_column (str):
+ Required. Name of the column that identifies time order in the time series.
+ This column must be available at forecast.
+ time_series_identifier_column (str):
+ Required. Name of the column that identifies the time series.
+ unavailable_at_forecast_columns (List[str]):
+ Required. Column names of columns that are unavailable at forecast.
+ Each column contains information for the given entity (identified by the
+ [time_series_identifier_column]) that is unknown before the forecast
+ (e.g. population of a city in a given year, or weather on a given day).
+ available_at_forecast_columns (List[str]):
+ Required. Column names of columns that are available at forecast.
+ Each column contains information for the given entity (identified by the
+ [time_series_identifier_column]) that is known at forecast.
+ forecast_horizon: (int):
+ Required. The amount of time into the future for which forecasted values for the target are
+ returned. Expressed in number of units defined by the [data_granularity_unit] and
+                [data_granularity_count] fields. Inclusive.
+ data_granularity_unit (str):
+ Required. The data granularity unit. Accepted values are ``minute``,
+ ``hour``, ``day``, ``week``, ``month``, ``year``.
+ data_granularity_count (int):
+ Required. The number of data granularity units between data points in the training
+ data. If [data_granularity_unit] is `minute`, can be 1, 5, 10, 15, or 30. For all other
+ values of [data_granularity_unit], must be 1.
+ predefined_split_column_name (str):
+ Optional. The key is a name of one of the Dataset's data
+ columns. The value of the key (either the label's value or
+ value in the column) must be one of {``TRAIN``,
+ ``VALIDATE``, ``TEST``}, and it defines to which set the
+ given piece of data is assigned. If for a piece of data the
+ key is not present or has an invalid value, that piece is
+ ignored by the pipeline.
- os.environ["AIP_TRAINING_DATA_URI"]
- os.environ["AIP_VALIDATION_DATA_URI"]
- os.environ["AIP_TEST_DATA_URI"]
+ Supported only for tabular and time series Datasets.
+ timestamp_split_column_name (str):
+ Optional. The key is a name of one of the Dataset's data
+                columns. The values of the key (the values in the
+                column) must be in RFC 3339 `date-time` format, where
+ `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
+ piece of data the key is not present or has an invalid value,
+ that piece is ignored by the pipeline.
+ Supported only for tabular and time series Datasets.
+ This parameter must be used with training_fraction_split,
+ validation_fraction_split, and test_fraction_split.
+ weight_column (str):
+ Optional. Name of the column that should be used as the weight column.
+ Higher values in this column give more importance to the row
+ during Model training. The column must have numeric values between 0 and
+ 10000 inclusively, and 0 value means that the row is ignored.
+ If the weight column field is not set, then all rows are assumed to have
+ equal weight of 1. This column must be available at forecast.
+ time_series_attribute_columns (List[str]):
+ Optional. Column names that should be used as attribute columns.
+ Each column is constant within a time series.
+ context_window (int):
+ Optional. The amount of time into the past training and prediction data is used for
+ model training and prediction respectively. Expressed in number of units defined by the
+ [data_granularity_unit] and [data_granularity_count] fields. When not provided uses the
+ default value of 0 which means the model sets each series context window to be 0 (also
+ known as "cold start"). Inclusive.
+ export_evaluated_data_items (bool):
+ Whether to export the test set predictions to a BigQuery table.
+ If False, then the export is not performed.
+ export_evaluated_data_items_bigquery_destination_uri (string):
+ Optional. URI of desired destination BigQuery table for exported test set predictions.
- Additionally the dataset format is passed in as:
+ Expected format:
+            ``bq://<project_id>:<dataset_id>:<table>``
- os.environ["AIP_DATA_FORMAT"]
- annotation_schema_uri (str):
- Google Cloud Storage URI points to a YAML file describing
- annotation schema. The schema is defined as an OpenAPI 3.0.2
- [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#schema-object) The schema files
- that can be used here are found in
- gs://google-cloud-aiplatform/schema/dataset/annotation/,
- note that the chosen schema must be consistent with
- ``metadata``
- of the Dataset specified by
- ``dataset_id``.
+ If not specified, then results are exported to the following auto-created BigQuery
+ table:
+            ``<project_id>:export_evaluated_examples_<model_name>_<yyyy_MM_dd'T'HH_mm_ss_SSS'Z'>.evaluated_examples``
- Only Annotations that both match this schema and belong to
- DataItems not ignored by the split method are used in
- respectively training, validation or test role, depending on
- the role of the DataItem they are on.
+ Applies only if [export_evaluated_data_items] is True.
+ export_evaluated_data_items_override_destination (bool):
+ Whether to override the contents of [export_evaluated_data_items_bigquery_destination_uri],
+ if the table exists, for exported test set predictions. If False, and the
+ table exists, then the training job will fail.
- When used in conjunction with
- ``annotations_filter``,
- the Annotations used for training are filtered by both
- ``annotations_filter``
- and
- ``annotation_schema_uri``.
+ Applies only if [export_evaluated_data_items] is True and
+ [export_evaluated_data_items_bigquery_destination_uri] is specified.
+ quantiles (List[float]):
+ Quantiles to use for the `minimize-quantile-loss`
+ [AutoMLForecastingTrainingJob.optimization_objective]. This argument is required in
+ this case.
+
+ Accepts up to 5 quantiles in the form of a double from 0 to 1, exclusive.
+ Each quantile must be unique.
+ validation_options (str):
+ Validation options for the data validation component. The available options are:
+ "fail-pipeline" - (default), will validate against the validation and fail the pipeline
+ if it fails.
+ "ignore-validation" - ignore the results of the validation and continue the pipeline
+ budget_milli_node_hours (int):
+ Optional. The train budget of creating this Model, expressed in milli node
+ hours i.e. 1,000 value in this field means 1 node hour.
+ The training cost of the model will not exceed this budget. The final
+                cost will be kept close to the budget where possible, though it may end
+                up noticeably smaller, at the backend's discretion. This
+ especially may happen when further model training ceases to provide
+ any improvements.
+ If the budget is set to a value known to be insufficient to train a
+ Model for the given training set, the training won't be attempted and
+ will error.
+ The minimum value is 1000 and the maximum is 72000.
model_display_name (str):
- If the script produces a managed Vertex AI Model. The display name of
+ Optional. If the script produces a managed Vertex AI Model. The display name of
                the Model. The name can be up to 128 characters long and can consist
of any UTF-8 characters.
If not provided upon creation, the job's display_name is used.
- base_output_dir (str):
- GCS output directory of job. If not provided a
- timestamped directory in the staging directory will be used.
-
- Vertex AI sets the following environment variables when it runs your training code:
-
- - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, i.e. /model/
- - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, i.e. /checkpoints/
- - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard logs, i.e. /logs/
-
- service_account (str):
- Specifies the service account for workload run-as account.
- Users submitting jobs must have act-as permission on this run-as account.
- network (str):
- The full name of the Compute Engine network to which the job
- should be peered. For example, projects/12345/global/networks/myVPC.
- Private services access must already be configured for the network.
- If left unspecified, the job is not peered with any network.
- bigquery_destination (str):
- Provide this field if `dataset` is a BiqQuery dataset.
- The BigQuery project location where the training data is to
- be written to. In the given project a new dataset is created
- with name
-                ``dataset_<dataset-id>_<annotation-type>_<timestamp>``
- where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
- training input data will be written into that dataset. In
- the dataset three tables will be created, ``training``,
- ``validation`` and ``test``.
-
- - AIP_DATA_FORMAT = "bigquery".
- - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training"
- - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
- - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
- args (List[Unions[str, int, float]]):
- Command line arguments to be passed to the Python script.
- environment_variables (Dict[str, str]):
- Environment variables to be passed to the container.
- Should be a dictionary where keys are environment variable names
- and values are environment variable values for those names.
- At most 10 environment variables can be specified.
- The Name of the environment variable must be unique.
-
- environment_variables = {
- 'MY_KEY': 'MY_VALUE'
- }
- replica_count (int):
- The number of worker replicas. If replica count = 1 then one chief
- replica will be provisioned. If replica_count > 1 the remainder will be
- provisioned as a worker replica pool.
- machine_type (str):
- The type of machine to use for training.
- accelerator_type (str):
- Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
- NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
- NVIDIA_TESLA_T4
- accelerator_count (int):
- The number of accelerators to attach to a worker replica.
- training_fraction_split (float):
- The fraction of the input data that is to be
- used to train the Model. This is ignored if Dataset is not provided.
- validation_fraction_split (float):
- The fraction of the input data that is to be
- used to validate the Model. This is ignored if Dataset is not provided.
- test_fraction_split (float):
- The fraction of the input data that is to be
- used to evaluate the Model. This is ignored if Dataset is not provided.
- predefined_split_column_name (str):
- Optional. The key is a name of one of the Dataset's data
- columns. The value of the key (either the label's value or
- value in the column) must be one of {``training``,
- ``validation``, ``test``}, and it defines to which set the
- given piece of data is assigned. If for a piece of data the
- key is not present or has an invalid value, that piece is
- ignored by the pipeline.
-
- Supported only for tabular and time series Datasets.
+ model_labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
+ additional_experiments (List[str]):
+                Optional. Additional experiment flags for the time series forecasting training.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+ hierarchy_group_columns (List[str]):
+ Optional. A list of time series attribute column names that
+ define the time series hierarchy. Only one level of hierarchy is
+ supported, ex. ``region`` for a hierarchy of stores or
+ ``department`` for a hierarchy of products. If multiple columns
+ are specified, time series will be grouped by their combined
+ values, ex. (``blue``, ``large``) for ``color`` and ``size``, up
+ to 5 columns are accepted. If no group columns are specified,
+ all time series are considered to be part of the same group.
+ hierarchy_group_total_weight (float):
+ Optional. The weight of the loss for predictions aggregated over
+ time series in the same hierarchy group.
+ hierarchy_temporal_total_weight (float):
+ Optional. The weight of the loss for predictions aggregated over
+ the horizon for a single time series.
+ hierarchy_group_temporal_total_weight (float):
+ Optional. The weight of the loss for predictions aggregated over
+ both the horizon and time series in the same hierarchy group.
+ window_column (str):
+ Optional. Name of the column that should be used to filter input
+ rows. The column should contain either booleans or string
+ booleans; if the value of the row is True, generate a sliding
+ window from that row.
+ window_stride_length (int):
+ Optional. Step length used to generate input examples. Every
+ ``window_stride_length`` rows will be used to generate a sliding
+ window.
+ window_max_count (int):
+ Optional. Number of rows that should be used to generate input
+ examples. If the total row count is larger than this number, the
+ input data will be randomly sampled to hit the count.
+ holiday_regions (List[str]):
+ Optional. The geographical regions to use when creating holiday
+ features. This option is only allowed when data_granularity_unit
+ is ``day``. Acceptable values can come from any of the following
+ levels:
+ Top level: GLOBAL
+ Second level: continental regions
+ NA: North America
+ JAPAC: Japan and Asia Pacific
+ EMEA: Europe, the Middle East and Africa
+ LAC: Latin America and the Caribbean
+ Third level: countries from ISO 3166-1 Country codes.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
-
Returns:
model: The trained Vertex AI Model resource or None if training did not
- produce an Vertex AI Model.
- """
- worker_pool_specs, managed_model = self._prepare_and_validate_run(
- model_display_name=model_display_name,
- replica_count=replica_count,
- machine_type=machine_type,
- accelerator_count=accelerator_count,
- accelerator_type=accelerator_type,
- )
+ produce a Vertex AI Model.
- # make and copy package
- python_packager = source_utils._TrainingScriptPythonPackager(
- script_path=self._script_path, requirements=self._requirements
- )
+ Raises:
+ RuntimeError: If Training job has already been run or is waiting to run.
+ """
- return self._run(
- python_packager=python_packager,
- dataset=dataset,
- annotation_schema_uri=annotation_schema_uri,
- worker_pool_specs=worker_pool_specs,
- managed_model=managed_model,
- args=args,
- environment_variables=environment_variables,
- base_output_dir=base_output_dir,
- service_account=service_account,
- network=network,
- bigquery_destination=bigquery_destination,
- training_fraction_split=training_fraction_split,
- validation_fraction_split=validation_fraction_split,
- test_fraction_split=test_fraction_split,
- predefined_split_column_name=predefined_split_column_name,
- sync=sync,
- )
+ if model_display_name:
+ utils.validate_display_name(model_display_name)
+ if model_labels:
+ utils.validate_labels(model_labels)
- @base.optional_sync(construct_object_on_arg="managed_model")
- def _run(
- self,
- python_packager: source_utils._TrainingScriptPythonPackager,
- dataset: Optional[
- Union[
- datasets.ImageDataset,
- datasets.TabularDataset,
- datasets.TextDataset,
- datasets.VideoDataset,
- ]
- ],
- annotation_schema_uri: Optional[str],
- worker_pool_specs: worker_spec_utils._DistributedTrainingSpec,
- managed_model: Optional[gca_model.Model] = None,
- args: Optional[List[Union[str, float, int]]] = None,
- environment_variables: Optional[Dict[str, str]] = None,
- base_output_dir: Optional[str] = None,
- service_account: Optional[str] = None,
- network: Optional[str] = None,
- bigquery_destination: Optional[str] = None,
- training_fraction_split: float = 0.8,
- validation_fraction_split: float = 0.1,
- test_fraction_split: float = 0.1,
- predefined_split_column_name: Optional[str] = None,
- sync=True,
- ) -> Optional[models.Model]:
- """Packages local script and launches training_job.
+ if self._is_waiting_to_run():
+ raise RuntimeError(
+ f"{self._model_type} Forecasting Training is already scheduled "
+ "to run."
+ )
- Args:
- python_packager (source_utils._TrainingScriptPythonPackager):
- Required. Python Packager pointing to training script locally.
- dataset (
- Union[
- datasets.ImageDataset,
- datasets.TabularDataset,
- datasets.TextDataset,
- datasets.VideoDataset,
- ]
- ):
- Vertex AI to fit this training against.
- annotation_schema_uri (str):
- Google Cloud Storage URI points to a YAML file describing
- annotation schema.
- worker_pools_spec (worker_spec_utils._DistributedTrainingSpec):
- Worker pools pecs required to run job.
- managed_model (gca_model.Model):
- Model proto if this script produces a Managed Model.
- args (List[Unions[str, int, float]]):
- Command line arguments to be passed to the Python script.
- environment_variables (Dict[str, str]):
- Environment variables to be passed to the container.
- Should be a dictionary where keys are environment variable names
- and values are environment variable values for those names.
- At most 10 environment variables can be specified.
- The Name of the environment variable must be unique.
+ if self._has_run:
+ raise RuntimeError(
+ f"{self._model_type} Forecasting Training has already run."
+ )
- environment_variables = {
- 'MY_KEY': 'MY_VALUE'
- }
- base_output_dir (str):
- GCS output directory of job. If not provided a
- timestamped directory in the staging directory will be used.
+ if additional_experiments:
+ self._add_additional_experiments(additional_experiments)
- Vertex AI sets the following environment variables when it runs your training code:
+ return self._run(
+ dataset=dataset,
+ target_column=target_column,
+ time_column=time_column,
+ time_series_identifier_column=time_series_identifier_column,
+ unavailable_at_forecast_columns=unavailable_at_forecast_columns,
+ available_at_forecast_columns=available_at_forecast_columns,
+ forecast_horizon=forecast_horizon,
+ data_granularity_unit=data_granularity_unit,
+ data_granularity_count=data_granularity_count,
+ training_fraction_split=training_fraction_split,
+ validation_fraction_split=validation_fraction_split,
+ test_fraction_split=test_fraction_split,
+ predefined_split_column_name=predefined_split_column_name,
+ timestamp_split_column_name=timestamp_split_column_name,
+ weight_column=weight_column,
+ time_series_attribute_columns=time_series_attribute_columns,
+ context_window=context_window,
+ budget_milli_node_hours=budget_milli_node_hours,
+ export_evaluated_data_items=export_evaluated_data_items,
+ export_evaluated_data_items_bigquery_destination_uri=export_evaluated_data_items_bigquery_destination_uri,
+ export_evaluated_data_items_override_destination=export_evaluated_data_items_override_destination,
+ quantiles=quantiles,
+ validation_options=validation_options,
+ model_display_name=model_display_name,
+ model_labels=model_labels,
+ hierarchy_group_columns=hierarchy_group_columns,
+ hierarchy_group_total_weight=hierarchy_group_total_weight,
+ hierarchy_temporal_total_weight=hierarchy_temporal_total_weight,
+ hierarchy_group_temporal_total_weight=hierarchy_group_temporal_total_weight,
+ window_column=window_column,
+ window_stride_length=window_stride_length,
+ window_max_count=window_max_count,
+ holiday_regions=holiday_regions,
+ sync=sync,
+ create_request_timeout=create_request_timeout,
+ )
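
Pulling the required arguments together, a sketch of an end-to-end call continuing the constructor example above; the dataset resource name and column names are placeholders, and note that `budget_milli_node_hours=1000` corresponds to exactly one node hour:

```python
ds = aiplatform.TimeSeriesDataset(
    "projects/my-project/locations/us-central1/datasets/456"  # placeholder
)
model = job.run(
    dataset=ds,
    target_column="sales",
    time_column="date",
    time_series_identifier_column="store_id",
    unavailable_at_forecast_columns=["sales"],
    available_at_forecast_columns=["date"],
    forecast_horizon=30,           # predict 30 units ahead...
    data_granularity_unit="day",   # ...where one unit is one day
    data_granularity_count=1,
    quantiles=[0.1, 0.5, 0.9],     # required for minimize-quantile-loss
    budget_milli_node_hours=1000,  # 1,000 milli node hours == 1 node hour
    holiday_regions=["GLOBAL"],
)
```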
- - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, i.e. /model/
- - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, i.e. /checkpoints/
- - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard logs, i.e. /logs/
+ @base.optional_sync()
+ def _run(
+ self,
+ dataset: datasets.TimeSeriesDataset,
+ target_column: str,
+ time_column: str,
+ time_series_identifier_column: str,
+ unavailable_at_forecast_columns: List[str],
+ available_at_forecast_columns: List[str],
+ forecast_horizon: int,
+ data_granularity_unit: str,
+ data_granularity_count: int,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ predefined_split_column_name: Optional[str] = None,
+ timestamp_split_column_name: Optional[str] = None,
+ weight_column: Optional[str] = None,
+ time_series_attribute_columns: Optional[List[str]] = None,
+ context_window: Optional[int] = None,
+ export_evaluated_data_items: bool = False,
+ export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
+ export_evaluated_data_items_override_destination: bool = False,
+ quantiles: Optional[List[float]] = None,
+ validation_options: Optional[str] = None,
+ budget_milli_node_hours: int = 1000,
+ model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
+ hierarchy_group_columns: Optional[List[str]] = None,
+ hierarchy_group_total_weight: Optional[float] = None,
+ hierarchy_temporal_total_weight: Optional[float] = None,
+ hierarchy_group_temporal_total_weight: Optional[float] = None,
+ window_column: Optional[str] = None,
+ window_stride_length: Optional[int] = None,
+ window_max_count: Optional[int] = None,
+ holiday_regions: Optional[List[str]] = None,
+ sync: bool = True,
+ create_request_timeout: Optional[float] = None,
+ ) -> models.Model:
+ """Runs the training job and returns a model.
- service_account (str):
- Specifies the service account for workload run-as account.
- Users submitting jobs must have act-as permission on this run-as account.
- network (str):
- The full name of the Compute Engine network to which the job
- should be peered. For example, projects/12345/global/networks/myVPC.
- Private services access must already be configured for the network.
- If left unspecified, the job is not peered with any network.
- bigquery_destination (str):
- Provide this field if `dataset` is a BiqQuery dataset.
- The BigQuery project location where the training data is to
- be written to. In the given project a new dataset is created
- with name
- ``dataset___``
- where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
- training input data will be written into that dataset. In
- the dataset three tables will be created, ``training``,
- ``validation`` and ``test``.
+ If training on a Vertex AI dataset, you can use one of the following split configurations:
+ Data fraction splits:
+ Any of ``training_fraction_split``, ``validation_fraction_split`` and
+ ``test_fraction_split`` may optionally be provided; they must sum to at most 1. If
+ the provided ones sum to less than 1, the remainder is assigned to sets as
+ decided by Vertex AI. If none of the fractions are set, by default roughly 80%
+ of data will be used for training, 10% for validation, and 10% for test.
+
+ Predefined splits:
+ Assigns input data to training, validation, and test sets based on the value of a provided key.
+ If using predefined splits, ``predefined_split_column_name`` must be provided.
+ Supported only for tabular Datasets.
+
+ Timestamp splits:
+ Assigns input data to training, validation, and test sets
+ based on the provided timestamps. The youngest data pieces are
+ assigned to the training set, the next to the validation set, and
+ the oldest to the test set.
+ Supported only for tabular Datasets.
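
As a concrete sketch of these mutually exclusive configurations (column names are hypothetical):

    # Fraction split: fractions must sum to at most 1.
    job.run(..., training_fraction_split=0.8,
            validation_fraction_split=0.1, test_fraction_split=0.1)

    # Predefined split: a column labels each row "training"/"validation"/"test".
    job.run(..., predefined_split_column_name="data_split")

    # Timestamp split: an RFC 3339 timestamp column plus the three fractions.
    job.run(..., timestamp_split_column_name="ingested_at",
            training_fraction_split=0.8,
            validation_fraction_split=0.1, test_fraction_split=0.1)
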
- - AIP_DATA_FORMAT = "bigquery".
- - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training"
- - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
- - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
+ Args:
+ dataset (datasets.TimeSeriesDataset):
+ Required. The dataset within the same Project from which data will be used to train the Model. The
+ Dataset must use schema compatible with Model being trained,
+ and what is compatible should be described in the used
+ TrainingPipeline's [training_task_definition]
+ [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
+ For time series Datasets, all their data is exported to
+ training, to pick and choose from.
+ target_column (str):
+ Required. Name of the column that the Model is to predict values for. This
+ column must be unavailable at forecast.
+ time_column (str):
+ Required. Name of the column that identifies time order in the time series.
+ This column must be available at forecast.
+ time_series_identifier_column (str):
+ Required. Name of the column that identifies the time series.
+ unavailable_at_forecast_columns (List[str]):
+ Required. Column names of columns that are unavailable at forecast.
+ Each column contains information for the given entity (identified by the
+ [time_series_identifier_column]) that is unknown before the forecast
+ (e.g. population of a city in a given year, or weather on a given day).
+ available_at_forecast_columns (List[str]):
+ Required. Column names of columns that are available at forecast.
+ Each column contains information for the given entity (identified by the
+ [time_series_identifier_column]) that is known at forecast.
+ forecast_horizon (int):
+ Required. The amount of time into the future for which forecasted values for the target are
+ returned. Expressed in number of units defined by the [data_granularity_unit] and
+ [data_granularity_count] field. Inclusive.
+ data_granularity_unit (str):
+ Required. The data granularity unit. Accepted values are ``minute``,
+ ``hour``, ``day``, ``week``, ``month``, ``year``.
+ data_granularity_count (int):
+ Required. The number of data granularity units between data points in the training
+ data. If [data_granularity_unit] is `minute`, can be 1, 5, 10, 15, or 30. For all other
+ values of [data_granularity_unit], must be 1.
training_fraction_split (float):
- The fraction of the input data that is to be
- used to train the Model.
+ Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
validation_fraction_split (float):
- The fraction of the input data that is to be
- used to validate the Model.
+ Optional. The fraction of the input data that is to be used to validate
+ the Model. This is ignored if Dataset is not provided.
test_fraction_split (float):
- The fraction of the input data that is to be
- used to evaluate the Model.
+ Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
predefined_split_column_name (str):
Optional. The key is a name of one of the Dataset's data
columns. The value of the key (either the label's value or
@@ -1665,73 +2127,326 @@ def _run(
ignored by the pipeline.
Supported only for tabular and time series Datasets.
+ timestamp_split_column_name (str):
+ Optional. The key is a name of one of the Dataset's data
+ columns. The values of the key (the values in
+ the column) must be in RFC 3339 `date-time` format, where
+ `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
+ piece of data the key is not present or has an invalid value,
+ that piece is ignored by the pipeline.
+ Supported only for tabular and time series Datasets.
+ This parameter must be used with training_fraction_split,
+ validation_fraction_split, and test_fraction_split.
+ weight_column (str):
+ Optional. Name of the column that should be used as the weight column.
+ Higher values in this column give more importance to the row
+ during Model training. The column must have numeric values between 0 and
+ 10000 inclusively, and 0 value means that the row is ignored.
+ If the weight column field is not set, then all rows are assumed to have
+ equal weight of 1. This column must be available at forecast.
+ time_series_attribute_columns (List[str]):
+ Optional. Column names that should be used as attribute columns.
+ Each column is constant within a time series.
+ context_window (int):
+ Optional. The amount of time into the past that training and prediction data is used for
+ model training and prediction respectively. Expressed in number of units defined by the
+ [data_granularity_unit] and [data_granularity_count] fields. When not provided uses the
+ default value of 0 which means the model sets each series context window to be 0 (also
+ known as "cold start"). Inclusive.
+ export_evaluated_data_items (bool):
+ Whether to export the test set predictions to a BigQuery table.
+ If False, then the export is not performed.
+ export_evaluated_data_items_bigquery_destination_uri (string):
+ Optional. URI of desired destination BigQuery table for exported test set predictions.
+
+ Expected format:
+ ``bq://<project_id>:<dataset_id>:<table>``
+
+ If not specified, then results are exported to the following auto-created BigQuery
+ table:
+ ``<project_id>:export_evaluated_examples_<model_name>_<yyyy_MM_dd'T'HH_mm_ss_SSS'Z'>.evaluated_examples``
+
+ Applies only if [export_evaluated_data_items] is True.
+ export_evaluated_data_items_override_destination (bool):
+ Whether to override the contents of [export_evaluated_data_items_bigquery_destination_uri],
+ if the table exists, for exported test set predictions. If False, and the
+ table exists, then the training job will fail.
+
+ Applies only if [export_evaluated_data_items] is True and
+ [export_evaluated_data_items_bigquery_destination_uri] is specified.
+ quantiles (List[float]):
+ Quantiles to use for the `minimize-quantile-loss`
+ [AutoMLForecastingTrainingJob.optimization_objective]. This argument is required in
+ this case.
+
+ Accepts up to 5 quantiles in the form of a double from 0 to 1, exclusive.
+ Each quantile must be unique.
+ validation_options (str):
+ Validation options for the data validation component. The available options are:
+ "fail-pipeline" - (default), will validate against the validation and fail the pipeline
+ if it fails.
+ "ignore-validation" - ignore the results of the validation and continue the pipeline
+ budget_milli_node_hours (int):
+ Optional. The train budget of creating this Model, expressed in milli node
+ hours, i.e. a value of 1,000 in this field means 1 node hour.
+ The training cost of the model will not exceed this budget. The final
+ cost is attempted to be close to the budget, though it may end up
+ being noticeably smaller, at the backend's discretion. This may
+ especially happen when further model training ceases to provide
+ any improvements.
+ If the budget is set to a value known to be insufficient to train a
+ Model for the given training set, the training won't be attempted and
+ will error.
+ The minimum value is 1000 and the maximum is 72000.
+ model_display_name (str):
+ Optional. If the script produces a managed Vertex AI Model, the display name of
+ the Model. The name can be up to 128 characters long and can consist
+ of any UTF-8 characters.
+
+ If not provided upon creation, the job's display_name is used.
+ model_labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
+ hierarchy_group_columns (List[str]):
+ Optional. A list of time series attribute column names that
+ define the time series hierarchy. Only one level of hierarchy is
+ supported, ex. ``region`` for a hierarchy of stores or
+ ``department`` for a hierarchy of products. If multiple columns
+ are specified, time series will be grouped by their combined
+ values, ex. (``blue``, ``large``) for ``color`` and ``size``, up
+ to 5 columns are accepted. If no group columns are specified,
+ all time series are considered to be part of the same group.
+ hierarchy_group_total_weight (float):
+ Optional. The weight of the loss for predictions aggregated over
+ time series in the same hierarchy group.
+ hierarchy_temporal_total_weight (float):
+ Optional. The weight of the loss for predictions aggregated over
+ the horizon for a single time series.
+ hierarchy_group_temporal_total_weight (float):
+ Optional. The weight of the loss for predictions aggregated over
+ both the horizon and time series in the same hierarchy group.
+ window_column (str):
+ Optional. Name of the column that should be used to filter input
+ rows. The column should contain either booleans or string
+ booleans; if the value of the row is True, generate a sliding
+ window from that row.
+ window_stride_length (int):
+ Optional. Step length used to generate input examples. Every
+ ``window_stride_length`` rows will be used to generate a sliding
+ window.
+ window_max_count (int):
+ Optional. Number of rows that should be used to generate input
+ examples. If the total row count is larger than this number, the
+ input data will be randomly sampled to hit the count.
+ holiday_regions (List[str]):
+ Optional. The geographical regions to use when creating holiday
+ features. This option is only allowed when data_granularity_unit
+ is ``day``. Acceptable values can come from any of the following
+ levels:
+ Top level: GLOBAL
+ Second level: continental regions
+ NA: North America
+ JAPAC: Japan and Asia Pacific
+ EMEA: Europe, the Middle East and Africa
+ LAC: Latin America and the Caribbean
+ Third level: countries from ISO 3166-1 Country codes.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
-
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
Returns:
model: The trained Vertex AI Model resource or None if training did not
- produce an Vertex AI Model.
+ produce a Vertex AI Model.
"""
- package_gcs_uri = python_packager.package_and_copy_to_gcs(
- gcs_staging_dir=self._staging_bucket,
- project=self.project,
- credentials=self.credentials,
+ # auto-populate transformations
+ if self._column_transformations is None:
+ _LOGGER.info(
+ "No column transformations provided, so now retrieving columns from dataset in order to set default column transformations."
+ )
+
+ (
+ self._column_transformations,
+ column_names,
+ ) = dataset._get_default_column_transformations(target_column)
+
+ _LOGGER.info(
+ "The column transformation of type 'auto' was set for the following columns: %s."
+ % column_names
+ )
+
+ window_config = self._create_window_config(
+ column=window_column,
+ stride_length=window_stride_length,
+ max_count=window_max_count,
)
- for spec in worker_pool_specs:
- spec["python_package_spec"] = {
- "executor_image_uri": self._container_uri,
- "python_module": python_packager.module_name,
- "package_uris": [package_gcs_uri],
+ training_task_inputs_dict = {
+ # required inputs
+ "targetColumn": target_column,
+ "timeColumn": time_column,
+ "timeSeriesIdentifierColumn": time_series_identifier_column,
+ "timeSeriesAttributeColumns": time_series_attribute_columns,
+ "unavailableAtForecastColumns": unavailable_at_forecast_columns,
+ "availableAtForecastColumns": available_at_forecast_columns,
+ "forecastHorizon": forecast_horizon,
+ "dataGranularity": {
+ "unit": data_granularity_unit,
+ "quantity": data_granularity_count,
+ },
+ "transformations": self._column_transformations,
+ "trainBudgetMilliNodeHours": budget_milli_node_hours,
+ # optional inputs
+ "weightColumn": weight_column,
+ "contextWindow": context_window,
+ "quantiles": quantiles,
+ "validationOptions": validation_options,
+ "optimizationObjective": self._optimization_objective,
+ "holidayRegions": holiday_regions,
+ }
+
+ # TODO(TheMichaelHu): Remove the ifs once the API supports these inputs.
+ if any(
+ [
+ hierarchy_group_columns,
+ hierarchy_group_total_weight,
+ hierarchy_temporal_total_weight,
+ hierarchy_group_temporal_total_weight,
+ ]
+ ):
+ training_task_inputs_dict["hierarchyConfig"] = {
+ "groupColumns": hierarchy_group_columns,
+ "groupTotalWeight": hierarchy_group_total_weight,
+ "temporalTotalWeight": hierarchy_temporal_total_weight,
+ "groupTemporalTotalWeight": hierarchy_group_temporal_total_weight,
}
+ if window_config:
+ training_task_inputs_dict["windowConfig"] = window_config
- if args:
- spec["python_package_spec"]["args"] = args
+ final_export_eval_bq_uri = export_evaluated_data_items_bigquery_destination_uri
+ if final_export_eval_bq_uri and not final_export_eval_bq_uri.startswith(
+ "bq://"
+ ):
+ final_export_eval_bq_uri = f"bq://{final_export_eval_bq_uri}"
- if environment_variables:
- spec["python_package_spec"]["env"] = [
- {"name": key, "value": value}
- for key, value in environment_variables.items()
- ]
+ if export_evaluated_data_items:
+ training_task_inputs_dict["exportEvaluatedDataItemsConfig"] = {
+ "destinationBigqueryUri": final_export_eval_bq_uri,
+ "overrideExistingTable": export_evaluated_data_items_override_destination,
+ }
- (
- training_task_inputs,
- base_output_dir,
- ) = self._prepare_training_task_inputs_and_output_dir(
- worker_pool_specs=worker_pool_specs,
- base_output_dir=base_output_dir,
- service_account=service_account,
- network=network,
+ if self._additional_experiments:
+ training_task_inputs_dict[
+ "additionalExperiments"
+ ] = self._additional_experiments
+
+ model = gca_model.Model(
+ display_name=model_display_name or self._display_name,
+ labels=model_labels or self._labels,
+ encryption_spec=self._model_encryption_spec,
)
- model = self._run_job(
- training_task_definition=schema.training_job.definition.custom_task,
- training_task_inputs=training_task_inputs,
+ new_model = self._run_job(
+ training_task_definition=self._training_task_definition,
+ training_task_inputs=training_task_inputs_dict,
dataset=dataset,
- annotation_schema_uri=annotation_schema_uri,
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
predefined_split_column_name=predefined_split_column_name,
- model=managed_model,
- gcs_destination_uri_prefix=base_output_dir,
- bigquery_destination=bigquery_destination,
+ timestamp_split_column_name=timestamp_split_column_name,
+ model=model,
+ create_request_timeout=create_request_timeout,
)
- return model
+ if export_evaluated_data_items:
+ _LOGGER.info(
+ "Exported examples available at:\n%s"
+ % self.evaluated_data_items_bigquery_uri
+ )
+ return new_model
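
For orientation, the camelCase payload assembled above as `training_task_inputs_dict` would look roughly like this for a daily job (a hand-written sketch; values are illustrative, not captured output):

    training_task_inputs = {
        "targetColumn": "sales",
        "timeColumn": "date",
        "timeSeriesIdentifierColumn": "store_id",
        "forecastHorizon": 30,
        "dataGranularity": {"unit": "day", "quantity": 1},
        # 1,000 milli node hours == 1 node hour
        "trainBudgetMilliNodeHours": 1000,
    }
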
-class CustomContainerTrainingJob(_CustomTrainingJob):
- """Class to launch a Custom Training Job in Vertex AI using a
- Container."""
+ @property
+ def _model_upload_fail_string(self) -> str:
+ """Helper property for model upload failure."""
+ return (
+ f"Training Pipeline {self.resource_name} is not configured to upload a "
+ "Model."
+ )
+
+ @property
+ def evaluated_data_items_bigquery_uri(self) -> Optional[str]:
+ """BigQuery location of exported evaluated examples from the Training Job
+ Returns:
+ str: BigQuery uri for the exported evaluated examples if the export
+ feature is enabled for training.
+ None: If the export feature was not enabled for training.
+ """
+
+ self._assert_gca_resource_is_available()
+
+ metadata = self._gca_resource.training_task_metadata
+ if metadata and "evaluatedDataItemsBigqueryUri" in metadata:
+ return metadata["evaluatedDataItemsBigqueryUri"]
+
+ return None
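
A sketch of how a caller might use this property after an export-enabled run (assumes `job` was run with export_evaluated_data_items=True):

    model = job.run(
        ...,  # required forecasting arguments as above
        export_evaluated_data_items=True,
    )
    uri = job.evaluated_data_items_bigquery_uri
    if uri is not None:
        print(f"Evaluated examples exported to {uri}")  # a bq:// table URI
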
+
+ def _add_additional_experiments(self, additional_experiments: List[str]):
+ """Add experiment flags to the training job.
+ Args:
+ additional_experiments (List[str]):
+ Experiment flags that can enable some experimental training features.
+ """
+ self._additional_experiments.extend(additional_experiments)
+
+ @staticmethod
+ def _create_window_config(
+ column: Optional[str] = None,
+ stride_length: Optional[int] = None,
+ max_count: Optional[int] = None,
+ ) -> Optional[Dict[str, Union[int, str]]]:
+ """Creates a window config from training job arguments."""
+ configs = {
+ "column": column,
+ "strideLength": stride_length,
+ "maxCount": max_count,
+ }
+ present_configs = {k: v for k, v in configs.items() if v is not None}
+ if not present_configs:
+ return None
+ if len(present_configs) > 1:
+ raise ValueError(
+ "More than one windowing strategy provided. Make sure only one "
+ "of window_column, window_stride_length, or window_max_count "
+ "is specified."
+ )
+ return present_configs
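
The helper above enforces that at most one windowing strategy is supplied; a quick sketch of the resulting behavior at the run() level (values are hypothetical):

    # Valid: a single strategy yields e.g. {"strideLength": 7}.
    job.run(..., window_stride_length=7)

    # Invalid: mixing strategies raises the ValueError above.
    job.run(..., window_column="use_row", window_max_count=10_000)
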
+
+
+# TODO(b/172368325) add scheduling, custom_job.Scheduling
+class CustomTrainingJob(_CustomTrainingJob):
+ """Class to launch a Custom Training Job in Vertex AI using a script.
+
+ Takes a training implementation as a python script and executes that
+ script in Cloud Vertex AI Training.
+ """
def __init__(
self,
+ # TODO(b/223262536): Make display_name parameter fully optional in next major release
display_name: str,
+ script_path: str,
container_uri: str,
- command: Sequence[str] = None,
+ requirements: Optional[Sequence[str]] = None,
model_serving_container_image_uri: Optional[str] = None,
model_serving_container_predict_route: Optional[str] = None,
model_serving_container_health_route: Optional[str] = None,
@@ -1746,26 +2461,35 @@ def __init__(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
+ labels: Optional[Dict[str, str]] = None,
training_encryption_spec_key_name: Optional[str] = None,
model_encryption_spec_key_name: Optional[str] = None,
staging_bucket: Optional[str] = None,
):
- """Constructs a Custom Container Training Job.
+ """Constructs a Custom Training Job from a Python script.
job = aiplatform.CustomTrainingJob(
display_name='test-train',
+ script_path='test_script.py',
+ requirements=['pandas', 'numpy'],
container_uri='gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest',
- command=['python3', 'run_script.py']
model_serving_container_image_uri='gcr.io/my-trainer/serving:1',
model_serving_container_predict_route='predict',
- model_serving_container_health_route='metadata)
+ model_serving_container_health_route='metadata',
+ labels={'key': 'value'},
+ )
Usage with Dataset:
ds = aiplatform.TabularDataset(
'projects/my-project/locations/us-central1/datasets/12345')
- job.run(ds, replica_count=1, model_display_name='my-trained-model')
+ job.run(
+ ds,
+ replica_count=1,
+ model_display_name='my-trained-model',
+ model_labels={'key': 'value'},
+ )
Usage without Dataset:
@@ -1780,11 +2504,11 @@ def __init__(
Args:
display_name (str):
Required. The user-defined name of this TrainingPipeline.
+ script_path (str): Required. Local path to training script.
container_uri (str):
Required: URI of the training container image in GCR.
- command (Sequence[str]):
- The command to be invoked when the container is started.
- It overrides the entrypoint instruction in Dockerfile when provided
+ requirements (Sequence[str]):
+ List of python package dependencies of the script.
model_serving_container_image_uri (str):
If the training produces a managed Vertex AI Model, the URI of the
Model serving container suitable for serving the model produced by the
@@ -1879,6 +2603,16 @@ def __init__(
credentials (auth_credentials.Credentials):
Custom credentials to use to run call training service. Overrides
credentials set in aiplatform.init.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize TrainingPipelines.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
training_encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the training pipeline. Has the
@@ -1908,11 +2642,14 @@ def __init__(
Bucket used to stage source and training artifacts. Overrides
staging_bucket set in aiplatform.init.
"""
+ if not display_name:
+ display_name = self.__class__._generate_display_name()
super().__init__(
display_name=display_name,
project=project,
location=location,
credentials=credentials,
+ labels=labels,
training_encryption_spec_key_name=training_encryption_spec_key_name,
model_encryption_spec_key_name=model_encryption_spec_key_name,
container_uri=container_uri,
@@ -1930,10 +2667,9 @@ def __init__(
staging_bucket=staging_bucket,
)
- self._command = command
+ self._requirements = requirements
+ self._script_path = script_path
- # TODO(b/172365904) add filter split, training_pipeline.FilterSplit
- # TODO(b/172368070) add timestamp split, training_pipeline.TimestampSplit
def run(
self,
dataset: Optional[
@@ -1946,21 +2682,36 @@ def run(
] = None,
annotation_schema_uri: Optional[str] = None,
model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
base_output_dir: Optional[str] = None,
service_account: Optional[str] = None,
network: Optional[str] = None,
bigquery_destination: Optional[str] = None,
args: Optional[List[Union[str, float, int]]] = None,
environment_variables: Optional[Dict[str, str]] = None,
- replica_count: int = 0,
+ replica_count: int = 1,
machine_type: str = "n1-standard-4",
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
accelerator_count: int = 0,
- training_fraction_split: float = 0.8,
- validation_fraction_split: float = 0.1,
- test_fraction_split: float = 0.1,
+ boot_disk_type: str = "pd-ssd",
+ boot_disk_size_gb: int = 100,
+ reduction_server_replica_count: int = 0,
+ reduction_server_machine_type: Optional[str] = None,
+ reduction_server_container_uri: Optional[str] = None,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ training_filter_split: Optional[str] = None,
+ validation_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
predefined_split_column_name: Optional[str] = None,
+ timestamp_split_column_name: Optional[str] = None,
+ timeout: Optional[int] = None,
+ restart_job_on_worker_restart: bool = False,
+ enable_web_access: bool = False,
+ tensorboard: Optional[str] = None,
sync=True,
+ create_request_timeout: Optional[float] = None,
) -> Optional[models.Model]:
"""Runs the custom training job.
@@ -1970,15 +2721,46 @@ def run(
i.e. replica_count = 10 will result in 1 chief and 9 workers
All replicas have same machine_type, accelerator_type, and accelerator_count
- Data fraction splits:
- Any of ``training_fraction_split``, ``validation_fraction_split`` and
- ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If
- the provided ones sum to less than 1, the remainder is assigned to sets as
- decided by Vertex AI. If none of the fractions are set, by default roughly 80%
- of data will be used for training, 10% for validation, and 10% for test.
+ If training on a Vertex AI dataset, you can use one of the following split configurations:
+ Data fraction splits:
+ Any of ``training_fraction_split``, ``validation_fraction_split`` and
+ ``test_fraction_split`` may optionally be provided; they must sum to at most 1. If
+ the provided ones sum to less than 1, the remainder is assigned to sets as
+ decided by Vertex AI. If none of the fractions are set, by default roughly 80%
+ of data will be used for training, 10% for validation, and 10% for test.
+
+ Data filter splits:
+ Assigns input data to training, validation, and test sets
+ based on the given filters, data pieces not matched by any
+ filter are ignored. Currently only supported for Datasets
+ containing DataItems.
+ If any of the filters in this message are to match nothing, then
+ they can be set as '-' (the minus sign).
+ If using filter splits, all of ``training_filter_split``, ``validation_filter_split`` and
+ ``test_filter_split`` must be provided.
+ Supported only for unstructured Datasets.
+
+ Predefined splits:
+ Assigns input data to training, validation, and test sets based on the value of a provided key.
+ If using predefined splits, ``predefined_split_column_name`` must be provided.
+ Supported only for tabular Datasets.
+
+ Timestamp splits:
+ Assigns input data to training, validation, and test sets
+ based on the provided timestamps. The youngest data pieces are
+ assigned to the training set, the next to the validation set, and
+ the oldest to the test set.
+ Supported only for tabular Datasets.
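
A sketch of a filter split on an unstructured dataset (the filter strings are hypothetical ListDataItems-style filters; all three must be provided, and '-' matches nothing):

    model = job.run(
        dataset=image_ds,  # placeholder ImageDataset
        training_filter_split='labels.split.value="train"',
        validation_filter_split='labels.split.value="val"',
        test_filter_split="-",  # assign no items to the test set
        model_display_name="my-trained-model",
    )
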
Args:
- dataset (Union[datasets.ImageDataset,datasets.TabularDataset,datasets.TextDataset,datasets.VideoDataset]):
+ dataset (
+ Union[
+ datasets.ImageDataset,
+ datasets.TabularDataset,
+ datasets.TextDataset,
+ datasets.VideoDataset,
+ ]
+ ):
Vertex AI dataset to fit this training against. The custom training script should
retrieve datasets through the environment variable URIs passed in:
@@ -1992,7 +2774,7 @@ def run(
annotation_schema_uri (str):
Google Cloud Storage URI points to a YAML file describing
annotation schema. The schema is defined as an OpenAPI 3.0.2
- [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#schema-object) The schema files
+ [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) The schema files
that can be used here are found in
gs://google-cloud-aiplatform/schema/dataset/annotation/,
note that the chosen schema must be consistent with
@@ -2017,6 +2799,16 @@ def run(
of any UTF-8 characters.
If not provided upon creation, the job's display_name is used.
+ model_labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
base_output_dir (str):
GCS output directory of job. If not provided a
timestamped directory in the staging directory will be used.
@@ -2074,15 +2866,50 @@ def run(
NVIDIA_TESLA_T4
accelerator_count (int):
The number of accelerators to attach to a worker replica.
+ boot_disk_type (str):
+ Type of the boot disk, default is `pd-ssd`.
+ Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
+ `pd-standard` (Persistent Disk Hard Disk Drive).
+ boot_disk_size_gb (int):
+ Size in GB of the boot disk, default is 100GB.
+ Boot disk size must be within the range of [100, 64000].
+ reduction_server_replica_count (int):
+ The number of reduction server replicas, default is 0.
+ reduction_server_machine_type (str):
+ Optional. The type of machine to use for the reduction server.
+ reduction_server_container_uri (str):
+ Optional. The URI of the reduction server container image.
+ See details: https://cloud.google.com/vertex-ai/docs/training/distributed-training#reduce_training_time_with_reduction_server
training_fraction_split (float):
- The fraction of the input data that is to be
- used to train the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
validation_fraction_split (float):
- The fraction of the input data that is to be
- used to validate the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to validate
+ the Model. This is ignored if Dataset is not provided.
test_fraction_split (float):
- The fraction of the input data that is to be
- used to evaluate the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ training_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to train the Model. A filter with same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ validation_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to validate the Model. A filter with same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ test_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to test the Model. A filter with same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
predefined_split_column_name (str):
Optional. The key is a name of one of the Dataset's data
columns. The value of the key (either the label's value or
@@ -2093,6 +2920,43 @@ def run(
ignored by the pipeline.
Supported only for tabular and time series Datasets.
+ timestamp_split_column_name (str):
+ Optional. The key is a name of one of the Dataset's data
+ columns. The values of the key (the values in
+ the column) must be in RFC 3339 `date-time` format, where
+ `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
+ piece of data the key is not present or has an invalid value,
+ that piece is ignored by the pipeline.
+
+ Supported only for tabular and time series Datasets.
+ timeout (int):
+ The maximum job running time in seconds. The default is 7 days.
+ restart_job_on_worker_restart (bool):
+ Restarts the entire CustomJob if a worker
+ gets restarted. This feature can be used by
+ distributed training jobs that are not resilient
+ to workers leaving and joining a job.
+ enable_web_access (bool):
+ Whether you want Vertex AI to enable interactive shell access
+ to training containers.
+ https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
+ tensorboard (str):
+ Optional. The name of a Vertex AI
+ [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
+ resource to which this CustomJob will upload Tensorboard
+ logs. Format:
+ ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
+
+ The training script should write TensorBoard logs to the following Vertex AI environment
+ variable:
+
+ AIP_TENSORBOARD_LOG_DIR
+
+ A `service_account` is required when `tensorboard` is provided.
+ For more information on configuring your service account please visit:
+ https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
@@ -2100,22 +2964,28 @@ def run(
Returns:
model: The trained Vertex AI Model resource or None if training did not
- produce an Vertex AI Model.
-
- Raises:
- RuntimeError: If Training job has already been run, staging_bucket has not
- been set, or model_display_name was provided but required arguments
- were not provided in constructor.
+ produce a Vertex AI Model.
"""
worker_pool_specs, managed_model = self._prepare_and_validate_run(
model_display_name=model_display_name,
+ model_labels=model_labels,
replica_count=replica_count,
machine_type=machine_type,
accelerator_count=accelerator_count,
accelerator_type=accelerator_type,
+ boot_disk_type=boot_disk_type,
+ boot_disk_size_gb=boot_disk_size_gb,
+ reduction_server_replica_count=reduction_server_replica_count,
+ reduction_server_machine_type=reduction_server_machine_type,
+ )
+
+ # make and copy package
+ python_packager = source_utils._TrainingScriptPythonPackager(
+ script_path=self._script_path, requirements=self._requirements
)
return self._run(
+ python_packager=python_packager,
dataset=dataset,
annotation_schema_uri=annotation_schema_uri,
worker_pool_specs=worker_pool_specs,
@@ -2129,13 +2999,26 @@ def run(
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
+ training_filter_split=training_filter_split,
+ validation_filter_split=validation_filter_split,
+ test_filter_split=test_filter_split,
predefined_split_column_name=predefined_split_column_name,
+ timestamp_split_column_name=timestamp_split_column_name,
+ timeout=timeout,
+ restart_job_on_worker_restart=restart_job_on_worker_restart,
+ enable_web_access=enable_web_access,
+ tensorboard=tensorboard,
+ reduction_server_container_uri=reduction_server_container_uri
+ if reduction_server_replica_count > 0
+ else None,
sync=sync,
+ create_request_timeout=create_request_timeout,
)
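
A sketch wiring in the newly added job-level options (resource names are placeholders; note that `service_account` must accompany `tensorboard`):

    model = job.run(
        dataset=ds,  # placeholder dataset
        replica_count=2,
        machine_type="n1-standard-8",
        timeout=24 * 60 * 60,            # one day, in seconds
        restart_job_on_worker_restart=True,
        enable_web_access=True,          # interactive shell debugging
        service_account="trainer@my-project.iam.gserviceaccount.com",
        tensorboard="projects/123/locations/us-central1/tensorboards/456",
    )
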
@base.optional_sync(construct_object_on_arg="managed_model")
def _run(
self,
+ python_packager: source_utils._TrainingScriptPythonPackager,
dataset: Optional[
Union[
datasets.ImageDataset,
@@ -2153,14 +3036,27 @@ def _run(
service_account: Optional[str] = None,
network: Optional[str] = None,
bigquery_destination: Optional[str] = None,
- training_fraction_split: float = 0.8,
- validation_fraction_split: float = 0.1,
- test_fraction_split: float = 0.1,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ training_filter_split: Optional[str] = None,
+ validation_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
predefined_split_column_name: Optional[str] = None,
+ timestamp_split_column_name: Optional[str] = None,
+ timeout: Optional[int] = None,
+ restart_job_on_worker_restart: bool = False,
+ enable_web_access: bool = False,
+ tensorboard: Optional[str] = None,
+ reduction_server_container_uri: Optional[str] = None,
sync=True,
+ create_request_timeout: Optional[float] = None,
) -> Optional[models.Model]:
"""Packages local script and launches training_job.
+
Args:
+ python_packager (source_utils._TrainingScriptPythonPackager):
+ Required. Python Packager pointing to training script locally.
dataset (
Union[
datasets.ImageDataset,
@@ -2208,6 +3104,7 @@ def _run(
Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
bigquery_destination (str):
+ Provide this field if `dataset` is a BigQuery dataset.
The BigQuery project location where the training data is to
be written to. In the given project a new dataset is created
with name
@@ -2222,14 +3119,35 @@ def _run(
- AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
- AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
training_fraction_split (float):
- The fraction of the input data that is to be
- used to train the Model.
+ Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
validation_fraction_split (float):
- The fraction of the input data that is to be
- used to validate the Model.
+ Optional. The fraction of the input data that is to be used to validate
+ the Model. This is ignored if Dataset is not provided.
test_fraction_split (float):
- The fraction of the input data that is to be
- used to evaluate the Model.
+ Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ training_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to train the Model. A filter with same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ validation_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to validate the Model. A filter with same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ test_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to test the Model. A filter with same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
predefined_split_column_name (str):
Optional. The key is a name of one of the Dataset's data
columns. The value of the key (either the label's value or
@@ -2240,30 +3158,87 @@ def _run(
ignored by the pipeline.
Supported only for tabular and time series Datasets.
+ timestamp_split_column_name (str):
+ Optional. The key is a name of one of the Dataset's data
+ columns. The values of the key (the values in
+ the column) must be in RFC 3339 `date-time` format, where
+ `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
+ piece of data the key is not present or has an invalid value,
+ that piece is ignored by the pipeline.
+
+ Supported only for tabular and time series Datasets.
+ timeout (int):
+ The maximum job running time in seconds. The default is 7 days.
+ restart_job_on_worker_restart (bool):
+ Restarts the entire CustomJob if a worker
+ gets restarted. This feature can be used by
+ distributed training jobs that are not resilient
+ to workers leaving and joining a job.
+ enable_web_access (bool):
+ Whether you want Vertex AI to enable interactive shell access
+ to training containers.
+ https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
+ tensorboard (str):
+ Optional. The name of a Vertex AI
+ [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
+ resource to which this CustomJob will upload Tensorboard
+ logs. Format:
+ ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
+
+ The training script should write TensorBoard logs to the following Vertex AI environment
+ variable:
+
+ AIP_TENSORBOARD_LOG_DIR
+
+ A `service_account` is required when `tensorboard` is provided.
+ For more information on configuring your service account please visit:
+ https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
+ reduction_server_container_uri (str):
+ Optional. The URI of the reduction server container image.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
Returns:
model: The trained Vertex AI Model resource or None if training did not
- produce an Vertex AI Model.
+ produce a Vertex AI Model.
"""
+ package_gcs_uri = python_packager.package_and_copy_to_gcs(
+ gcs_staging_dir=self._staging_bucket,
+ project=self.project,
+ credentials=self.credentials,
+ )
- for spec in worker_pool_specs:
- spec["containerSpec"] = {"imageUri": self._container_uri}
+ for spec_order, spec in enumerate(worker_pool_specs):
- if self._command:
- spec["containerSpec"]["command"] = self._command
+ if not spec:
+ continue
- if args:
- spec["containerSpec"]["args"] = args
+ if (
+ spec_order == worker_spec_utils._SPEC_ORDERS["server_spec"]
+ and reduction_server_container_uri
+ ):
+ spec["container_spec"] = {
+ "image_uri": reduction_server_container_uri,
+ }
+ else:
+ spec["python_package_spec"] = {
+ "executor_image_uri": self._container_uri,
+ "python_module": python_packager.module_name,
+ "package_uris": [package_gcs_uri],
+ }
- if environment_variables:
- spec["containerSpec"]["env"] = [
- {"name": key, "value": value}
- for key, value in environment_variables.items()
- ]
+ if args:
+ spec["python_package_spec"]["args"] = args
+
+ if environment_variables:
+ spec["python_package_spec"]["env"] = [
+ {"name": key, "value": value}
+ for key, value in environment_variables.items()
+ ]
(
training_task_inputs,
@@ -2273,6 +3248,10 @@ def _run(
base_output_dir=base_output_dir,
service_account=service_account,
network=network,
+ timeout=timeout,
+ restart_job_on_worker_restart=restart_job_on_worker_restart,
+ enable_web_access=enable_web_access,
+ tensorboard=tensorboard,
)
model = self._run_job(
@@ -2283,101 +3262,195 @@ def _run(
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
+ training_filter_split=training_filter_split,
+ validation_filter_split=validation_filter_split,
+ test_filter_split=test_filter_split,
predefined_split_column_name=predefined_split_column_name,
+ timestamp_split_column_name=timestamp_split_column_name,
model=managed_model,
gcs_destination_uri_prefix=base_output_dir,
bigquery_destination=bigquery_destination,
+ create_request_timeout=create_request_timeout,
)
return model
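
To make the loop above concrete, a sketch of what the finished worker_pool_specs might contain for a job with one reduction server pool (all values illustrative):

    worker_pool_specs = [
        {   # chief: runs the packaged training script
            "machine_spec": {"machine_type": "n1-standard-4"},
            "replica_count": 1,
            "python_package_spec": {
                "executor_image_uri": "gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest",
                "python_module": "aiplatform_custom_trainer_script.task",
                "package_uris": ["gs://my-bucket/trainer-0.1.tar.gz"],
            },
        },
        {   # reduction server pool: container_spec instead of python_package_spec
            "machine_spec": {"machine_type": "n1-highcpu-16"},
            "replica_count": 1,
            "container_spec": {"image_uri": "<reduction-server-image>"},
        },
    ]
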
-class AutoMLTabularTrainingJob(_TrainingJob):
- _supported_training_schemas = (schema.training_job.definition.automl_tabular,)
+class CustomContainerTrainingJob(_CustomTrainingJob):
+ """Class to launch a Custom Training Job in Vertex AI using a
+ Container."""
def __init__(
self,
+ # TODO(b/223262536): Make display_name parameter fully optional in next major release
display_name: str,
- optimization_prediction_type: str,
- optimization_objective: Optional[str] = None,
- column_transformations: Optional[Union[Dict, List[Dict]]] = None,
- optimization_objective_recall_value: Optional[float] = None,
- optimization_objective_precision_value: Optional[float] = None,
+ container_uri: str,
+ command: Sequence[str] = None,
+ model_serving_container_image_uri: Optional[str] = None,
+ model_serving_container_predict_route: Optional[str] = None,
+ model_serving_container_health_route: Optional[str] = None,
+ model_serving_container_command: Optional[Sequence[str]] = None,
+ model_serving_container_args: Optional[Sequence[str]] = None,
+ model_serving_container_environment_variables: Optional[Dict[str, str]] = None,
+ model_serving_container_ports: Optional[Sequence[int]] = None,
+ model_description: Optional[str] = None,
+ model_instance_schema_uri: Optional[str] = None,
+ model_parameters_schema_uri: Optional[str] = None,
+ model_prediction_schema_uri: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
+ labels: Optional[Dict[str, str]] = None,
training_encryption_spec_key_name: Optional[str] = None,
model_encryption_spec_key_name: Optional[str] = None,
+ staging_bucket: Optional[str] = None,
):
- """Constructs a AutoML Tabular Training Job.
+ """Constructs a Custom Container Training Job.
- Args:
- display_name (str):
- Required. The user-defined name of this TrainingPipeline.
- optimization_prediction_type (str):
- The type of prediction the Model is to produce.
- "classification" - Predict one out of multiple target values is
- picked for each row.
- "regression" - Predict a value based on its relation to other values.
- This type is available only to columns that contain
- semantically numeric values, i.e. integers or floating
- point number, even if stored as e.g. strings.
+ job = aiplatform.CustomContainerTrainingJob(
+ display_name='test-train',
+ container_uri='gcr.io/my_project_id/my_image_name:tag',
+ command=['python3', 'run_script.py'],
+ model_serving_container_image_uri='gcr.io/my-trainer/serving:1',
+ model_serving_container_predict_route='predict',
+ model_serving_container_health_route='metadata',
+ labels={'key': 'value'},
+ )
- optimization_objective (str):
- Optional. Objective function the Model is to be optimized towards. The training
- task creates a Model that maximizes/minimizes the value of the objective
- function over the validation set.
+ Usage with Dataset:
- The supported optimization objectives depend on the prediction type, and
- in the case of classification also the number of distinct values in the
- target column (two distint values -> binary, 3 or more distinct values
- -> multi class).
- If the field is not set, the default objective function is used.
+ ds = aiplatform.TabularDataset(
+ 'projects/my-project/locations/us-central1/datasets/12345')
- Classification (binary):
- "maximize-au-roc" (default) - Maximize the area under the receiver
- operating characteristic (ROC) curve.
- "minimize-log-loss" - Minimize log loss.
- "maximize-au-prc" - Maximize the area under the precision-recall curve.
- "maximize-precision-at-recall" - Maximize precision for a specified
- recall value.
- "maximize-recall-at-precision" - Maximize recall for a specified
- precision value.
+ job.run(
+ ds,
+ replica_count=1,
+ model_display_name='my-trained-model',
+ model_labels={'key': 'value'},
+ )
- Classification (multi class):
- "minimize-log-loss" (default) - Minimize log loss.
+ Usage without Dataset:
- Regression:
- "minimize-rmse" (default) - Minimize root-mean-squared error (RMSE).
- "minimize-mae" - Minimize mean-absolute error (MAE).
- "minimize-rmsle" - Minimize root-mean-squared log error (RMSLE).
- column_transformations (Optional[Union[Dict, List[Dict]]]):
- Optional. Transformations to apply to the input columns (i.e. columns other
- than the targetColumn). Each transformation may produce multiple
- result values from the column's value, and all are used for training.
- When creating transformation for BigQuery Struct column, the column
- should be flattened using "." as the delimiter.
- If an input column has no transformations on it, such a column is
- ignored by the training, except for the targetColumn, which should have
- no transformations defined on.
- optimization_objective_recall_value (float):
- Optional. Required when maximize-precision-at-recall optimizationObjective was
- picked, represents the recall value at which the optimization is done.
+ job.run(replica_count=1, model_display_name='my-trained-model')
- The minimum value is 0 and the maximum is 1.0.
- optimization_objective_precision_value (float):
- Optional. Required when maximize-recall-at-precision optimizationObjective was
- picked, represents the precision value at which the optimization is
- done.
- The minimum value is 0 and the maximum is 1.0.
+ TODO(b/169782082) add documentation about training utilities
+ To ensure your model gets saved in Vertex AI, write your saved model to
+ os.environ["AIP_MODEL_DIR"] in your provided training script.
+
+
+ Args:
+ display_name (str):
+ Required. The user-defined name of this TrainingPipeline.
+ container_uri (str):
+ Required: URI of the training container image in GCR.
+ command (Sequence[str]):
+ The command to be invoked when the container is started.
+ It overrides the entrypoint instruction in Dockerfile when provided
+ model_serving_container_image_uri (str):
+ If the training produces a managed Vertex AI Model, the URI of the
+ Model serving container suitable for serving the model produced by the
+ training script.
+ model_serving_container_predict_route (str):
+ If the training produces a managed Vertex AI Model, an HTTP path to
+ send prediction requests to the container, and which must be supported
+ by it. If not specified a default HTTP path will be used by Vertex AI.
+ model_serving_container_health_route (str):
+ If the training produces a managed Vertex AI Model, an HTTP path to
+ send health check requests to the container, and which must be supported
+ by it. If not specified a standard HTTP path will be used by AI
+ Platform.
+ model_serving_container_command (Sequence[str]):
+ The command with which the container is run. Not executed within a
+ shell. The Docker image's ENTRYPOINT is used if this is not provided.
+ Variable references $(VAR_NAME) are expanded using the container's
+ environment. If a variable cannot be resolved, the reference in the
+ input string will be unchanged. The $(VAR_NAME) syntax can be escaped
+ with a double $$, ie: $$(VAR_NAME). Escaped references will never be
+ expanded, regardless of whether the variable exists or not.
+ model_serving_container_args (Sequence[str]):
+ The arguments to the command. The Docker image's CMD is used if this is
+ not provided. Variable references $(VAR_NAME) are expanded using the
+ container's environment. If a variable cannot be resolved, the reference
+ in the input string will be unchanged. The $(VAR_NAME) syntax can be
+ escaped with a double $$, ie: $$(VAR_NAME). Escaped references will
+ never be expanded, regardless of whether the variable exists or not.
+ model_serving_container_environment_variables (Dict[str, str]):
+ The environment variables that are to be present in the container.
+ Should be a dictionary where keys are environment variable names
+ and values are environment variable values for those names.
+ model_serving_container_ports (Sequence[int]):
+ Declaration of ports that are exposed by the container. This field is
+ primarily informational, it gives Vertex AI information about the
+ network connections the container uses. Listing or not a port here has
+ no impact on whether the port is actually exposed, any port listening on
+ the default "0.0.0.0" address inside a container will be accessible from
+ the network.
+ model_description (str):
+ The description of the Model.
+ model_instance_schema_uri (str):
+ Optional. Points to a YAML file stored on Google Cloud
+ Storage describing the format of a single instance, which
+ are used in
+ ``PredictRequest.instances``,
+ ``ExplainRequest.instances``
+ and
+ ``BatchPredictionJob.input_config``.
+ The schema is defined as an OpenAPI 3.0.2 `Schema
+ Object <https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object>`__.
+ AutoML Models always have this field populated by AI
+ Platform. Note: The URI given on output will be immutable
+ and probably different, including the URI scheme, than the
+ one given on input. The output URI will point to a location
+ where the user only has a read access.
+ model_parameters_schema_uri (str):
+ Optional. Points to a YAML file stored on Google Cloud
+ Storage describing the parameters of prediction and
+ explanation via
+ ``PredictRequest.parameters``,
+ ``ExplainRequest.parameters``
+ and
+ ``BatchPredictionJob.model_parameters``.
+ The schema is defined as an OpenAPI 3.0.2 `Schema
+ Object <https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object>`__.
+ AutoML Models always have this field populated by AI
+ Platform, if no parameters are supported it is set to an
+ empty string. Note: The URI given on output will be
+ immutable and probably different, including the URI scheme,
+ than the one given on input. The output URI will point to a
+ location where the user only has a read access.
+ model_prediction_schema_uri (str):
+ Optional. Points to a YAML file stored on Google Cloud
+ Storage describing the format of a single prediction
+ produced by this Model, which are returned via
+ ``PredictResponse.predictions``,
+ ``ExplainResponse.explanations``,
+ and
+ ``BatchPredictionJob.output_config``.
+ The schema is defined as an OpenAPI 3.0.2 `Schema
+ Object <https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object>`__.
+ AutoML Models always have this field populated by AI
+ Platform. Note: The URI given on output will be immutable
+ and probably different, including the URI scheme, than the
+ one given on input. The output URI will point to a location
+ where the user only has a read access.
project (str):
- Optional. Project to run training in. Overrides project set in aiplatform.init.
+ Project to run training in. Overrides project set in aiplatform.init.
location (str):
- Optional. Location to run training in. Overrides location set in aiplatform.init.
+ Location to run training in. Overrides location set in aiplatform.init.
credentials (auth_credentials.Credentials):
- Optional. Custom credentials to use to run call training service. Overrides
+ Custom credentials to use to run call training service. Overrides
credentials set in aiplatform.init.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize TrainingPipelines.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
training_encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the training pipeline. Has the
@@ -2403,64 +3476,270 @@ def __init__(
If set, the trained Model will be secured by this key.
Overrides encryption_spec_key_name set in aiplatform.init.
+ staging_bucket (str):
+ Bucket used to stage source and training artifacts. Overrides
+ staging_bucket set in aiplatform.init.
"""
+ if not display_name:
+ display_name = self.__class__._generate_display_name()
super().__init__(
display_name=display_name,
project=project,
location=location,
credentials=credentials,
+ labels=labels,
training_encryption_spec_key_name=training_encryption_spec_key_name,
model_encryption_spec_key_name=model_encryption_spec_key_name,
- )
- self._column_transformations = column_transformations
- self._optimization_objective = optimization_objective
- self._optimization_prediction_type = optimization_prediction_type
- self._optimization_objective_recall_value = optimization_objective_recall_value
- self._optimization_objective_precision_value = (
- optimization_objective_precision_value
+ container_uri=container_uri,
+ model_instance_schema_uri=model_instance_schema_uri,
+ model_parameters_schema_uri=model_parameters_schema_uri,
+ model_prediction_schema_uri=model_prediction_schema_uri,
+ model_serving_container_environment_variables=model_serving_container_environment_variables,
+ model_serving_container_ports=model_serving_container_ports,
+ model_serving_container_image_uri=model_serving_container_image_uri,
+ model_serving_container_command=model_serving_container_command,
+ model_serving_container_args=model_serving_container_args,
+ model_serving_container_predict_route=model_serving_container_predict_route,
+ model_serving_container_health_route=model_serving_container_health_route,
+ model_description=model_description,
+ staging_bucket=staging_bucket,
)
+ self._command = command
+
def run(
self,
- dataset: datasets.TabularDataset,
- target_column: str,
- training_fraction_split: float = 0.8,
- validation_fraction_split: float = 0.1,
- test_fraction_split: float = 0.1,
- predefined_split_column_name: Optional[str] = None,
- weight_column: Optional[str] = None,
- budget_milli_node_hours: int = 1000,
+ dataset: Optional[
+ Union[
+ datasets.ImageDataset,
+ datasets.TabularDataset,
+ datasets.TextDataset,
+ datasets.VideoDataset,
+ ]
+ ] = None,
+ annotation_schema_uri: Optional[str] = None,
model_display_name: Optional[str] = None,
- disable_early_stopping: bool = False,
- sync: bool = True,
- ) -> models.Model:
- """Runs the training job and returns a model.
+ model_labels: Optional[Dict[str, str]] = None,
+ base_output_dir: Optional[str] = None,
+ service_account: Optional[str] = None,
+ network: Optional[str] = None,
+ bigquery_destination: Optional[str] = None,
+ args: Optional[List[Union[str, float, int]]] = None,
+ environment_variables: Optional[Dict[str, str]] = None,
+ replica_count: int = 1,
+ machine_type: str = "n1-standard-4",
+ accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
+ accelerator_count: int = 0,
+ boot_disk_type: str = "pd-ssd",
+ boot_disk_size_gb: int = 100,
+ reduction_server_replica_count: int = 0,
+ reduction_server_machine_type: Optional[str] = None,
+ reduction_server_container_uri: Optional[str] = None,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ training_filter_split: Optional[str] = None,
+ validation_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
+ predefined_split_column_name: Optional[str] = None,
+ timestamp_split_column_name: Optional[str] = None,
+ timeout: Optional[int] = None,
+ restart_job_on_worker_restart: bool = False,
+ enable_web_access: bool = False,
+ tensorboard: Optional[str] = None,
+ sync=True,
+ create_request_timeout: Optional[float] = None,
+ ) -> Optional[models.Model]:
+ """Runs the custom training job.
+
+ Distributed Training Support:
+ If replica_count = 1, then one chief replica will be provisioned. If
+ replica_count > 1, the remainder will be provisioned as a worker replica pool,
+ i.e. replica_count = 10 will result in 1 chief and 9 workers.
+ All replicas have the same machine_type, accelerator_type, and accelerator_count.
- Data fraction splits:
- Any of ``training_fraction_split``, ``validation_fraction_split`` and
- ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If
- the provided ones sum to less than 1, the remainder is assigned to sets as
- decided by Vertex AI. If none of the fractions are set, by default roughly 80%
- of data will be used for training, 10% for validation, and 10% for test.
+ If training on a Vertex AI dataset, you can use one of the following split configurations:
+ Data fraction splits:
+ Any of ``training_fraction_split``, ``validation_fraction_split`` and
+ ``test_fraction_split`` may optionally be provided; together they must sum to at most 1. If
+ the provided ones sum to less than 1, the remainder is assigned to sets as
+ decided by Vertex AI. If none of the fractions are set, by default roughly 80%
+ of data will be used for training, 10% for validation, and 10% for test.
+
+ Data filter splits:
+ Assigns input data to training, validation, and test sets
+ based on the given filters; data pieces not matched by any
+ filter are ignored. Currently only supported for Datasets
+ containing DataItems.
+ A filter that is meant to match nothing can be set to '-'
+ (the minus sign).
+ If using filter splits, all of ``training_filter_split``, ``validation_filter_split`` and
+ ``test_filter_split`` must be provided.
+ Supported only for unstructured Datasets.
+
+ Predefined splits:
+ Assigns input data to training, validation, and test sets based on the value of a provided key.
+ If using predefined splits, ``predefined_split_column_name`` must be provided.
+ Supported only for tabular Datasets.
+
+ Timestamp splits:
+ Assigns input data to training, validation, and test sets
+ based on provided timestamps. The youngest data pieces are
+ assigned to the training set, the next to the validation set,
+ and the oldest to the test set.
+ Supported only for tabular Datasets.
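+
+ For example, a filter-based split for an unstructured dataset can be
+ requested as below (a minimal sketch; ``job``, ``my_image_dataset`` and
+ the ``ml_use`` label filters are placeholder values):
+
+ model = job.run(
+ dataset=my_image_dataset,
+ training_filter_split="labels.aiplatform.googleapis.com/ml_use=training",
+ validation_filter_split="labels.aiplatform.googleapis.com/ml_use=validation",
+ test_filter_split="labels.aiplatform.googleapis.com/ml_use=test",
+ )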
Args:
- dataset (datasets.TabularDataset):
- Required. The dataset within the same Project from which data will be used to train the Model. The
- Dataset must use schema compatible with Model being trained,
- and what is compatible should be described in the used
- TrainingPipeline's [training_task_definition]
- [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
- For tabular Datasets, all their data is exported to
- training, to pick and choose from.
+ dataset (Union[datasets.ImageDataset,datasets.TabularDataset,datasets.TextDataset,datasets.VideoDataset]):
+ Optional. The Vertex AI dataset to fit this training against. The custom
+ training script should retrieve dataset locations through the following
+ environment variable URIs:
+
+ os.environ["AIP_TRAINING_DATA_URI"]
+ os.environ["AIP_VALIDATION_DATA_URI"]
+ os.environ["AIP_TEST_DATA_URI"]
+
+ Additionally the dataset format is passed in as:
+
+ os.environ["AIP_DATA_FORMAT"]
+ annotation_schema_uri (str):
+ Google Cloud Storage URI that points to a YAML file describing the
+ annotation schema. The schema is defined as an OpenAPI 3.0.2
+ [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object). The schema files
+ that can be used here are found in
+ gs://google-cloud-aiplatform/schema/dataset/annotation/,
+ note that the chosen schema must be consistent with
+ ``metadata``
+ of the Dataset specified by
+ ``dataset_id``.
+
+ Only Annotations that both match this schema and belong to
+ DataItems not ignored by the split method are used in
+ respectively training, validation or test role, depending on
+ the role of the DataItem they are on.
+
+ When used in conjunction with
+ ``annotations_filter``,
+ the Annotations used for training are filtered by both
+ ``annotations_filter``
+ and
+ ``annotation_schema_uri``.
+ model_display_name (str):
+ Optional. If the script produces a managed Vertex AI Model, the display
+ name of the Model. The name can be up to 128 characters long and can
+ consist of any UTF-8 characters.
+
+ If not provided upon creation, the job's display_name is used.
+ model_labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
+ base_output_dir (str):
+ GCS output directory of the job. If not provided, a
+ timestamped directory in the staging directory will be used.
+
+ Vertex AI sets the following environment variables when it runs your training code:
+
+ - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, i.e. <base_output_dir>/model/
+ - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, i.e. <base_output_dir>/checkpoints/
+ - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard logs, i.e. <base_output_dir>/logs/
+
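+ A training script would then typically export its artifacts there, for
+ example (a sketch assuming a Keras-style model object named ``model``):
+
+ import os
+
+ model.save(os.environ["AIP_MODEL_DIR"])
+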
+ service_account (str):
+ Specifies the service account for workload run-as account.
+ Users submitting jobs must have act-as permission on this run-as account.
+ network (str):
+ The full name of the Compute Engine network to which the job
+ should be peered. For example, projects/12345/global/networks/myVPC.
+ Private services access must already be configured for the network.
+ If left unspecified, the job is not peered with any network.
+ bigquery_destination (str):
+ Provide this field if `dataset` is a BigQuery dataset.
+ The BigQuery project location where the training data is to
+ be written to. In the given project a new dataset is created
+ with name
+ ``dataset_<dataset-id>_<annotation-type>_<timestamp>``
+ where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
+ training input data will be written into that dataset. In
+ the dataset three tables will be created, ``training``,
+ ``validation`` and ``test``.
+
+ - AIP_DATA_FORMAT = "bigquery".
+ - AIP_TRAINING_DATA_URI = "bigquery_destination.dataset_*.training"
+ - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
+ - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
+ args (List[Union[str, int, float]]):
+ Command line arguments to be passed to the Python script.
+ environment_variables (Dict[str, str]):
+ Environment variables to be passed to the container.
+ Should be a dictionary where keys are environment variable names
+ and values are environment variable values for those names.
+ At most 10 environment variables can be specified.
+ The name of each environment variable must be unique.
+
+ environment_variables = {
+ 'MY_KEY': 'MY_VALUE'
+ }
+ replica_count (int):
+ The number of worker replicas. If replica_count = 1, then one chief
+ replica will be provisioned. If replica_count > 1, the remainder will be
+ provisioned as a worker replica pool.
+ machine_type (str):
+ The type of machine to use for training.
+ accelerator_type (str):
+ Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED,
+ NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4,
+ NVIDIA_TESLA_T4
+ accelerator_count (int):
+ The number of accelerators to attach to a worker replica.
+ boot_disk_type (str):
+ Type of the boot disk, default is `pd-ssd`.
+ Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
+ `pd-standard` (Persistent Disk Hard Disk Drive).
+ boot_disk_size_gb (int):
+ Size in GB of the boot disk, default is 100GB.
+ Boot disk size must be within the range of [100, 64000].
+ reduction_server_replica_count (int):
+ The number of reduction server replicas, default is 0.
+ reduction_server_machine_type (str):
+ Optional. The type of machine to use for reduction server.
+ reduction_server_container_uri (str):
+ Optional. The URI of the reduction server container image.
+ See details: https://cloud.google.com/vertex-ai/docs/training/distributed-training#reduce_training_time_with_reduction_server
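+
+ For example, Reduction Server can be enabled on a multi-replica GPU job
+ as below (a sketch; the machine shapes are placeholders and the image
+ URI is the one documented at the link above at the time of writing):
+
+ job.run(
+ replica_count=4,
+ machine_type="n1-standard-16",
+ accelerator_type="NVIDIA_TESLA_V100",
+ accelerator_count=2,
+ reduction_server_replica_count=1,
+ reduction_server_machine_type="n1-highcpu-16",
+ reduction_server_container_uri="us-docker.pkg.dev/vertex-ai-restricted/training/reductionserver:latest",
+ )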
training_fraction_split (float):
- Required. The fraction of the input data that is to be
- used to train the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
validation_fraction_split (float):
- Required. The fraction of the input data that is to be
- used to validate the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to validate
+ the Model. This is ignored if Dataset is not provided.
test_fraction_split (float):
- Required. The fraction of the input data that is to be
- used to evaluate the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ training_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to train the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ validation_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to validate the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ test_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to test the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
predefined_split_column_name (str):
Optional. The key is a name of one of the Dataset's data
columns. The value of the key (either the label's value or
@@ -2471,111 +3750,235 @@ def run(
ignored by the pipeline.
Supported only for tabular and time series Datasets.
- weight_column (str):
- Optional. Name of the column that should be used as the weight column.
- Higher values in this column give more importance to the row
- during Model training. The column must have numeric values between 0 and
- 10000 inclusively, and 0 value means that the row is ignored.
- If the weight column field is not set, then all rows are assumed to have
- equal weight of 1.
- budget_milli_node_hours (int):
- Optional. The train budget of creating this Model, expressed in milli node
- hours i.e. 1,000 value in this field means 1 node hour.
- The training cost of the model will not exceed this budget. The final
- cost will be attempted to be close to the budget, though may end up
- being (even) noticeably smaller - at the backend's discretion. This
- especially may happen when further model training ceases to provide
- any improvements.
- If the budget is set to a value known to be insufficient to train a
- Model for the given training set, the training won't be attempted and
- will error.
- The minimum value is 1000 and the maximum is 72000.
- model_display_name (str):
- Optional. If the script produces a managed Vertex AI Model. The display name of
- the Model. The name can be up to 128 characters long and can be consist
- of any UTF-8 characters.
+ timestamp_split_column_name (str):
+ Optional. The key is a name of one of the Dataset's data
+ columns. The values of the key (the values in
+ the column) must be in RFC 3339 `date-time` format, where
+ `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
+ piece of data the key is not present or has an invalid value,
+ that piece is ignored by the pipeline.
- If not provided upon creation, the job's display_name is used.
- disable_early_stopping (bool):
- Required. If true, the entire budget is used. This disables the early stopping
- feature. By default, the early stopping feature is enabled, which means
- that training might stop before the entire training budget has been
- used, if further training does no longer brings significant improvement
- to the model.
+ Supported only for tabular and time series Datasets.
+ timeout (int):
+ The maximum job running time in seconds. The default is 7 days.
+ restart_job_on_worker_restart (bool):
+ Restarts the entire CustomJob if a worker
+ gets restarted. This feature can be used by
+ distributed training jobs that are not resilient
+ to workers leaving and joining a job.
+ enable_web_access (bool):
+ Whether you want Vertex AI to enable interactive shell access
+ to training containers.
+ https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
+ tensorboard (str):
+ Optional. The name of a Vertex AI
+ [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
+ resource to which this CustomJob will upload Tensorboard
+ logs. Format:
+ ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
+
+ The training script should write TensorBoard logs to the following
+ Vertex AI environment variable:
+
+ AIP_TENSORBOARD_LOG_DIR
+
+ `service_account` is required when `tensorboard` is provided.
+ For more information on configuring your service account please visit:
+ https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
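+
+ For example (a sketch; the project, location, service account and
+ Tensorboard resource ID below are placeholders):
+
+ job.run(
+ service_account="trainer@my-project.iam.gserviceaccount.com",
+ tensorboard="projects/my-project/locations/us-central1/tensorboards/1234567890",
+ )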
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+
Returns:
model: The trained Vertex AI Model resource or None if training did not
- produce an Vertex AI Model.
+ produce a Vertex AI Model.
Raises:
- RuntimeError: If Training job has already been run or is waiting to run.
+ RuntimeError: If Training job has already been run, staging_bucket has not
+ been set, or model_display_name was provided but required arguments
+ were not provided in constructor.
"""
-
- if self._is_waiting_to_run():
- raise RuntimeError("AutoML Tabular Training is already scheduled to run.")
-
- if self._has_run:
- raise RuntimeError("AutoML Tabular Training has already run.")
+ worker_pool_specs, managed_model = self._prepare_and_validate_run(
+ model_display_name=model_display_name,
+ model_labels=model_labels,
+ replica_count=replica_count,
+ machine_type=machine_type,
+ accelerator_count=accelerator_count,
+ accelerator_type=accelerator_type,
+ boot_disk_type=boot_disk_type,
+ boot_disk_size_gb=boot_disk_size_gb,
+ reduction_server_replica_count=reduction_server_replica_count,
+ reduction_server_machine_type=reduction_server_machine_type,
+ )
return self._run(
dataset=dataset,
- target_column=target_column,
+ annotation_schema_uri=annotation_schema_uri,
+ worker_pool_specs=worker_pool_specs,
+ managed_model=managed_model,
+ args=args,
+ environment_variables=environment_variables,
+ base_output_dir=base_output_dir,
+ service_account=service_account,
+ network=network,
+ bigquery_destination=bigquery_destination,
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
+ training_filter_split=training_filter_split,
+ validation_filter_split=validation_filter_split,
+ test_filter_split=test_filter_split,
predefined_split_column_name=predefined_split_column_name,
- weight_column=weight_column,
- budget_milli_node_hours=budget_milli_node_hours,
- model_display_name=model_display_name,
- disable_early_stopping=disable_early_stopping,
+ timestamp_split_column_name=timestamp_split_column_name,
+ timeout=timeout,
+ restart_job_on_worker_restart=restart_job_on_worker_restart,
+ enable_web_access=enable_web_access,
+ tensorboard=tensorboard,
+ reduction_server_container_uri=reduction_server_container_uri
+ if reduction_server_replica_count > 0
+ else None,
sync=sync,
+ create_request_timeout=create_request_timeout,
)
- @base.optional_sync()
+ @base.optional_sync(construct_object_on_arg="managed_model")
def _run(
self,
- dataset: datasets.TabularDataset,
- target_column: str,
- training_fraction_split: float = 0.8,
- validation_fraction_split: float = 0.1,
- test_fraction_split: float = 0.1,
+ dataset: Optional[
+ Union[
+ datasets.ImageDataset,
+ datasets.TabularDataset,
+ datasets.TextDataset,
+ datasets.VideoDataset,
+ ]
+ ],
+ annotation_schema_uri: Optional[str],
+ worker_pool_specs: worker_spec_utils._DistributedTrainingSpec,
+ managed_model: Optional[gca_model.Model] = None,
+ args: Optional[List[Union[str, float, int]]] = None,
+ environment_variables: Optional[Dict[str, str]] = None,
+ base_output_dir: Optional[str] = None,
+ service_account: Optional[str] = None,
+ network: Optional[str] = None,
+ bigquery_destination: Optional[str] = None,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ training_filter_split: Optional[str] = None,
+ validation_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
predefined_split_column_name: Optional[str] = None,
- weight_column: Optional[str] = None,
- budget_milli_node_hours: int = 1000,
- model_display_name: Optional[str] = None,
- disable_early_stopping: bool = False,
- sync: bool = True,
- ) -> models.Model:
- """Runs the training job and returns a model.
+ timestamp_split_column_name: Optional[str] = None,
+ timeout: Optional[int] = None,
+ restart_job_on_worker_restart: bool = False,
+ enable_web_access: bool = False,
+ tensorboard: Optional[str] = None,
+ reduction_server_container_uri: Optional[str] = None,
+ sync=True,
+ create_request_timeout: Optional[float] = None,
+ ) -> Optional[models.Model]:
+ """Packages local script and launches training_job.
+ Args:
+ dataset (
+ Union[
+ datasets.ImageDataset,
+ datasets.TabularDataset,
+ datasets.TextDataset,
+ datasets.VideoDataset,
+ ]
+ ):
+ The Vertex AI dataset to fit this training against.
+ annotation_schema_uri (str):
+ Google Cloud Storage URI that points to a YAML file describing
+ the annotation schema.
+ worker_pool_specs (worker_spec_utils._DistributedTrainingSpec):
+ Worker pool specs required to run the job.
+ managed_model (gca_model.Model):
+ Model proto if this script produces a Managed Model.
+ args (List[Union[str, int, float]]):
+ Command line arguments to be passed to the Python script.
+ environment_variables (Dict[str, str]):
+ Environment variables to be passed to the container.
+ Should be a dictionary where keys are environment variable names
+ and values are environment variable values for those names.
+ At most 10 environment variables can be specified.
+ The name of each environment variable must be unique.
- Data fraction splits:
- Any of ``training_fraction_split``, ``validation_fraction_split`` and
- ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If
- the provided ones sum to less than 1, the remainder is assigned to sets as
- decided by Vertex AI. If none of the fractions are set, by default roughly 80%
- of data will be used for training, 10% for validation, and 10% for test.
+ environment_variables = {
+ 'MY_KEY': 'MY_VALUE'
+ }
+ base_output_dir (str):
+ GCS output directory of the job. If not provided, a
+ timestamped directory in the staging directory will be used.
- Args:
- dataset (datasets.TabularDataset):
- Required. The dataset within the same Project from which data will be used to train the Model. The
- Dataset must use schema compatible with Model being trained,
- and what is compatible should be described in the used
- TrainingPipeline's [training_task_definition]
- [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
- For tabular Datasets, all their data is exported to
- training, to pick and choose from.
+ Vertex AI sets the following environment variables when it runs your training code:
+
+ - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, i.e. <base_output_dir>/model/
+ - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, i.e. <base_output_dir>/checkpoints/
+ - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard logs, i.e. <base_output_dir>/logs/
+
+ service_account (str):
+ Specifies the service account for workload run-as account.
+ Users submitting jobs must have act-as permission on this run-as account.
+ network (str):
+ The full name of the Compute Engine network to which the job
+ should be peered. For example, projects/12345/global/networks/myVPC.
+ Private services access must already be configured for the network.
+ If left unspecified, the job is not peered with any network.
+ timeout (int):
+ The maximum job running time in seconds. The default is 7 days.
+ restart_job_on_worker_restart (bool):
+ Restarts the entire CustomJob if a worker
+ gets restarted. This feature can be used by
+ distributed training jobs that are not resilient
+ to workers leaving and joining a job.
+ bigquery_destination (str):
+ The BigQuery project location where the training data is to
+ be written to. In the given project a new dataset is created
+ with name
+ ``dataset_<dataset-id>_<annotation-type>_<timestamp>``
+ where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
+ training input data will be written into that dataset. In
+ the dataset three tables will be created, ``training``,
+ ``validation`` and ``test``.
+
+ - AIP_DATA_FORMAT = "bigquery".
+ - AIP_TRAINING_DATA_URI = "bigquery_destination.dataset_*.training"
+ - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
+ - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
training_fraction_split (float):
- Required. The fraction of the input data that is to be
- used to train the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
validation_fraction_split (float):
- Required. The fraction of the input data that is to be
- used to validate the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to validate
+ the Model. This is ignored if Dataset is not provided.
test_fraction_split (float):
- Required. The fraction of the input data that is to be
- used to evaluate the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ training_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to train the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ validation_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to validate the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ test_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to test the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
predefined_split_column_name (str):
Optional. The key is a name of one of the Dataset's data
columns. The value of the key (either the label's value or
@@ -2586,150 +3989,224 @@ def _run(
ignored by the pipeline.
Supported only for tabular and time series Datasets.
- weight_column (str):
- Optional. Name of the column that should be used as the weight column.
- Higher values in this column give more importance to the row
- during Model training. The column must have numeric values between 0 and
- 10000 inclusively, and 0 value means that the row is ignored.
- If the weight column field is not set, then all rows are assumed to have
- equal weight of 1.
- budget_milli_node_hours (int):
- Optional. The train budget of creating this Model, expressed in milli node
- hours i.e. 1,000 value in this field means 1 node hour.
- The training cost of the model will not exceed this budget. The final
- cost will be attempted to be close to the budget, though may end up
- being (even) noticeably smaller - at the backend's discretion. This
- especially may happen when further model training ceases to provide
- any improvements.
- If the budget is set to a value known to be insufficient to train a
- Model for the given training set, the training won't be attempted and
- will error.
- The minimum value is 1000 and the maximum is 72000.
- model_display_name (str):
- Optional. If the script produces a managed Vertex AI Model. The display name of
- the Model. The name can be up to 128 characters long and can be consist
- of any UTF-8 characters.
+ timestamp_split_column_name (str):
+ Optional. The key is a name of one of the Dataset's data
+ columns. The values of the key (the values in
+ the column) must be in RFC 3339 `date-time` format, where
+ `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
+ piece of data the key is not present or has an invalid value,
+ that piece is ignored by the pipeline.
- If not provided upon creation, the job's display_name is used.
- disable_early_stopping (bool):
- Required. If true, the entire budget is used. This disables the early stopping
- feature. By default, the early stopping feature is enabled, which means
- that training might stop before the entire training budget has been
- used, if further training does no longer brings significant improvement
- to the model.
+ Supported only for tabular and time series Datasets.
+ enable_web_access (bool):
+ Whether you want Vertex AI to enable interactive shell access
+ to training containers.
+ https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
+ tensorboard (str):
+ Optional. The name of a Vertex AI
+ [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
+ resource to which this CustomJob will upload Tensorboard
+ logs. Format:
+ ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
+
+ The training script should write TensorBoard logs to the following
+ Vertex AI environment variable:
+
+ AIP_TENSORBOARD_LOG_DIR
+
+ `service_account` is required when `tensorboard` is provided.
+ For more information on configuring your service account please visit:
+ https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
+ reduction_server_container_uri (str):
+ Optional. The URI of the reduction server container image.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
Returns:
model: The trained Vertex AI Model resource or None if training did not
- produce an Vertex AI Model.
+ produce a Vertex AI Model.
"""
- training_task_definition = schema.training_job.definition.automl_tabular
+ for spec_order, spec in enumerate(worker_pool_specs):
- if self._column_transformations is None:
- _LOGGER.info(
- "No column transformations provided, so now retrieving columns from dataset in order to set default column transformations."
- )
+ if not spec:
+ continue
- column_names = [
- column_name
- for column_name in dataset.column_names
- if column_name != target_column
- ]
- column_transformations = [
- {"auto": {"column_name": column_name}} for column_name in column_names
- ]
+ if (
+ spec_order == worker_spec_utils._SPEC_ORDERS["server_spec"]
+ and reduction_server_container_uri
+ ):
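+ # The reduction server replica pool (if enabled) runs the dedicated
+ # reduction server image; all other pools run this job's container.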
+ spec["container_spec"] = {
+ "image_uri": reduction_server_container_uri,
+ }
+ else:
+ spec["containerSpec"] = {"imageUri": self._container_uri}
- _LOGGER.info(
- "The column transformation of type 'auto' was set for the following columns: %s."
- % column_names
- )
- else:
- column_transformations = self._column_transformations
+ if self._command:
+ spec["containerSpec"]["command"] = self._command
- training_task_inputs_dict = {
- # required inputs
- "targetColumn": target_column,
- "transformations": column_transformations,
- "trainBudgetMilliNodeHours": budget_milli_node_hours,
- # optional inputs
- "weightColumnName": weight_column,
- "disableEarlyStopping": disable_early_stopping,
- "optimizationObjective": self._optimization_objective,
- "predictionType": self._optimization_prediction_type,
- "optimizationObjectiveRecallValue": self._optimization_objective_recall_value,
- "optimizationObjectivePrecisionValue": self._optimization_objective_precision_value,
- }
+ if args:
+ spec["containerSpec"]["args"] = args
- if model_display_name is None:
- model_display_name = self._display_name
+ if environment_variables:
+ spec["containerSpec"]["env"] = [
+ {"name": key, "value": value}
+ for key, value in environment_variables.items()
+ ]
- model = gca_model.Model(
- display_name=model_display_name,
- encryption_spec=self._model_encryption_spec,
+ (
+ training_task_inputs,
+ base_output_dir,
+ ) = self._prepare_training_task_inputs_and_output_dir(
+ worker_pool_specs=worker_pool_specs,
+ base_output_dir=base_output_dir,
+ service_account=service_account,
+ network=network,
+ timeout=timeout,
+ restart_job_on_worker_restart=restart_job_on_worker_restart,
+ enable_web_access=enable_web_access,
+ tensorboard=tensorboard,
)
- return self._run_job(
- training_task_definition=training_task_definition,
- training_task_inputs=training_task_inputs_dict,
+ model = self._run_job(
+ training_task_definition=schema.training_job.definition.custom_task,
+ training_task_inputs=training_task_inputs,
dataset=dataset,
+ annotation_schema_uri=annotation_schema_uri,
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
+ training_filter_split=training_filter_split,
+ validation_filter_split=validation_filter_split,
+ test_filter_split=test_filter_split,
predefined_split_column_name=predefined_split_column_name,
- model=model,
+ timestamp_split_column_name=timestamp_split_column_name,
+ model=managed_model,
+ gcs_destination_uri_prefix=base_output_dir,
+ bigquery_destination=bigquery_destination,
+ create_request_timeout=create_request_timeout,
)
- @property
- def _model_upload_fail_string(self) -> str:
- """Helper property for model upload failure."""
- return (
- f"Training Pipeline {self.resource_name} is not configured to upload a "
- "Model."
- )
+ return model
-class AutoMLForecastingTrainingJob(_TrainingJob):
- _supported_training_schemas = (schema.training_job.definition.automl_forecasting,)
+class AutoMLTabularTrainingJob(_TrainingJob):
+ _supported_training_schemas = (schema.training_job.definition.automl_tabular,)
def __init__(
self,
+ # TODO(b/223262536): Make display_name parameter fully optional in next major release
display_name: str,
+ optimization_prediction_type: str,
optimization_objective: Optional[str] = None,
- column_transformations: Optional[Union[Dict, List[Dict]]] = None,
+ column_specs: Optional[Dict[str, str]] = None,
+ column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None,
+ optimization_objective_recall_value: Optional[float] = None,
+ optimization_objective_precision_value: Optional[float] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
+ labels: Optional[Dict[str, str]] = None,
+ training_encryption_spec_key_name: Optional[str] = None,
+ model_encryption_spec_key_name: Optional[str] = None,
):
- """Constructs a AutoML Forecasting Training Job.
+ """Constructs a AutoML Tabular Training Job.
+
+ Example usage:
+
+ job = training_jobs.AutoMLTabularTrainingJob(
+ display_name="my_display_name",
+ optimization_prediction_type="classification",
+ optimization_objective="minimize-log-loss",
+ column_specs={"column_1": "auto", "column_2": "numeric"},
+ labels={'key': 'value'},
+ )
Args:
display_name (str):
Required. The user-defined name of this TrainingPipeline.
+ optimization_prediction_type (str):
+ The type of prediction the Model is to produce.
+ "classification" - Predict one out of multiple target values is
+ picked for each row.
+ "regression" - Predict a value based on its relation to other values.
+ This type is available only to columns that contain
+ semantically numeric values, i.e. integers or floating
+ point number, even if stored as e.g. strings.
+
optimization_objective (str):
- Optional. Objective function the model is to be optimized towards.
- The training process creates a Model that optimizes the value of the objective
- function over the validation set. The supported optimization objectives:
+ Optional. Objective function the Model is to be optimized towards. The training
+ task creates a Model that maximizes/minimizes the value of the objective
+ function over the validation set.
+
+ The supported optimization objectives depend on the prediction type, and
+ in the case of classification also the number of distinct values in the
+ target column (two distinct values -> binary, 3 or more distinct values
+ -> multi class).
+ If the field is not set, the default objective function is used.
+
+ Classification (binary):
+ "maximize-au-roc" (default) - Maximize the area under the receiver
+ operating characteristic (ROC) curve.
+ "minimize-log-loss" - Minimize log loss.
+ "maximize-au-prc" - Maximize the area under the precision-recall curve.
+ "maximize-precision-at-recall" - Maximize precision for a specified
+ recall value.
+ "maximize-recall-at-precision" - Maximize recall for a specified
+ precision value.
+
+ Classification (multi class):
+ "minimize-log-loss" (default) - Minimize log loss.
+
+ Regression:
"minimize-rmse" (default) - Minimize root-mean-squared error (RMSE).
"minimize-mae" - Minimize mean-absolute error (MAE).
"minimize-rmsle" - Minimize root-mean-squared log error (RMSLE).
- "minimize-rmspe" - Minimize root-mean-squared percentage error (RMSPE).
- "minimize-wape-mae" - Minimize the combination of weighted absolute percentage error (WAPE)
- and mean-absolute-error (MAE).
- "minimize-quantile-loss" - Minimize the quantile loss at the defined quantiles.
- (Set this objective to build quantile forecasts.)
- column_transformations (Optional[Union[Dict, List[Dict]]]):
+ column_specs (Dict[str, str]):
+ Optional. Alternative to column_transformations where the keys of the dict
+ are column names and their respective values are one of
+ AutoMLTabularTrainingJob.column_data_types.
+ When creating transformation for BigQuery Struct column, the column
+ should be flattened using "." as the delimiter. Only columns with no child
+ should have a transformation.
+ If an input column has no transformations on it, such a column is
+ ignored by the training, except for the targetColumn, which should have
+ no transformations defined on it.
+ Only one of column_transformations or column_specs should be passed. If neither
+ column_transformations nor column_specs is passed, the local credentials
+ being used will try setting column_specs to "auto". To do this, the local
+ credentials require read access to the GCS or BigQuery training data source.
+ column_transformations (List[Dict[str, Dict[str, str]]]):
Optional. Transformations to apply to the input columns (i.e. columns other
than the targetColumn). Each transformation may produce multiple
result values from the column's value, and all are used for training.
When creating transformation for BigQuery Struct column, the column
- should be flattened using "." as the delimiter.
+ should be flattened using "." as the delimiter. Only columns with no child
+ should have a transformation.
If an input column has no transformations on it, such a column is
ignored by the training, except for the targetColumn, which should have
no transformations defined on it.
+ Only one of column_transformations or column_specs should be passed.
+ Consider using column_specs, as column_transformations will eventually
+ be deprecated. If neither column_transformations nor column_specs is passed,
+ the local credentials being used will try setting column_transformations to
+ "auto". To do this, the local credentials require read access to the GCS or
+ BigQuery training data source.
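+
+ For illustration, the same columns expressed both ways (a sketch with
+ placeholder column names; "numeric" and "categorical" are values from
+ AutoMLTabularTrainingJob.column_data_types):
+
+ column_specs = {"age": "numeric", "city": "categorical"}
+ column_transformations = [
+ {"numeric": {"column_name": "age"}},
+ {"categorical": {"column_name": "city"}},
+ ]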
+ optimization_objective_recall_value (float):
+ Optional. Required when the maximize-precision-at-recall optimizationObjective is
+ picked; represents the recall value at which the optimization is done.
+
+ The minimum value is 0 and the maximum is 1.0.
+ optimization_objective_precision_value (float):
+ Optional. Required when the maximize-recall-at-precision optimizationObjective is
+ picked; represents the precision value at which the optimization is
+ done.
+
+ The minimum value is 0 and the maximum is 1.0.
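+
+ For example (a sketch; the display name is a placeholder):
+
+ job = AutoMLTabularTrainingJob(
+ display_name="my_job",
+ optimization_prediction_type="classification",
+ optimization_objective="maximize-precision-at-recall",
+ optimization_objective_recall_value=0.9,
+ )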
project (str):
Optional. Project to run training in. Overrides project set in aiplatform.init.
location (str):
@@ -2737,90 +4214,155 @@ def __init__(
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to call the training service. Overrides
credentials set in aiplatform.init.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize TrainingPipelines.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
+ training_encryption_spec_key_name (Optional[str]):
+ Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the training pipeline. Has the
+ form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+ The key needs to be in the same region as where the compute
+ resource is created.
+
+ If set, this TrainingPipeline will be secured by this key.
+
+ Note: Model trained by this TrainingPipeline is also secured
+ by this key if ``model_to_upload`` is not set separately.
+
+ Overrides encryption_spec_key_name set in aiplatform.init.
+ model_encryption_spec_key_name (Optional[str]):
+ Optional. The Cloud KMS resource identifier of the customer
+ managed encryption key used to protect the model. Has the
+ form:
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+ The key needs to be in the same region as where the compute
+ resource is created.
+
+ If set, the trained Model will be secured by this key.
+
+ Overrides encryption_spec_key_name set in aiplatform.init.
+
+ Raises:
+ ValueError: If both column_transformations and column_specs were provided.
"""
+ if not display_name:
+ display_name = self.__class__._generate_display_name()
super().__init__(
display_name=display_name,
project=project,
location=location,
credentials=credentials,
+ labels=labels,
+ training_encryption_spec_key_name=training_encryption_spec_key_name,
+ model_encryption_spec_key_name=model_encryption_spec_key_name,
+ )
+
+ self._column_transformations = (
+ column_transformations_utils.validate_and_get_column_transformations(
+ column_specs, column_transformations
+ )
)
- self._column_transformations = column_transformations
+
self._optimization_objective = optimization_objective
+ self._optimization_prediction_type = optimization_prediction_type
+ self._optimization_objective_recall_value = optimization_objective_recall_value
+ self._optimization_objective_precision_value = (
+ optimization_objective_precision_value
+ )
+
+ self._additional_experiments = []
def run(
self,
- dataset: datasets.TimeSeriesDataset,
+ dataset: datasets.TabularDataset,
target_column: str,
- time_column: str,
- time_series_identifier_column: str,
- unavailable_at_forecast_columns: List[str],
- available_at_forecast_columns: List[str],
- forecast_horizon: int,
- data_granularity_unit: str,
- data_granularity_count: int,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
predefined_split_column_name: Optional[str] = None,
+ timestamp_split_column_name: Optional[str] = None,
weight_column: Optional[str] = None,
- time_series_attribute_columns: Optional[List[str]] = None,
- context_window: Optional[int] = None,
+ budget_milli_node_hours: int = 1000,
+ model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
+ disable_early_stopping: bool = False,
export_evaluated_data_items: bool = False,
export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
export_evaluated_data_items_override_destination: bool = False,
- quantiles: Optional[List[float]] = None,
- validation_options: Optional[str] = None,
- budget_milli_node_hours: int = 1000,
- model_display_name: Optional[str] = None,
+ additional_experiments: Optional[List[str]] = None,
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
) -> models.Model:
"""Runs the training job and returns a model.
- The training data splits are set by default: Roughly 80% will be used for training,
- 10% for validation, and 10% for test.
+ If training on a Vertex AI dataset, you can use one of the following split configurations:
+ Data fraction splits:
+ Any of ``training_fraction_split``, ``validation_fraction_split`` and
+ ``test_fraction_split`` may optionally be provided; together they must sum to at most 1. If
+ the provided ones sum to less than 1, the remainder is assigned to sets as
+ decided by Vertex AI. If none of the fractions are set, by default roughly 80%
+ of data will be used for training, 10% for validation, and 10% for test.
+
+ Predefined splits:
+ Assigns input data to training, validation, and test sets based on the value of a provided key.
+ If using predefined splits, ``predefined_split_column_name`` must be provided.
+ Supported only for tabular Datasets.
+
+ Timestamp splits:
+ Assigns input data to training, validation, and test sets
+ based on provided timestamps. The youngest data pieces are
+ assigned to the training set, the next to the validation set,
+ and the oldest to the test set.
+ Supported only for tabular Datasets.
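+
+ For example, a timestamp-based split (a sketch; the dataset and column
+ names are placeholders):
+
+ model = job.run(
+ dataset=my_tabular_dataset,
+ target_column="label",
+ timestamp_split_column_name="event_time",
+ training_fraction_split=0.8,
+ validation_fraction_split=0.1,
+ test_fraction_split=0.1,
+ )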
Args:
- dataset (datasets.Dataset):
+ dataset (datasets.TabularDataset):
Required. The dataset within the same Project from which data will be used to train the Model. The
Dataset must use schema compatible with Model being trained,
and what is compatible should be described in the used
TrainingPipeline's [training_task_definition]
[google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
- For time series Datasets, all their data is exported to
+ For tabular Datasets, all their data is exported to
training, to pick and choose from.
target_column (str):
- Required. Name of the column that the Model is to predict values for.
- time_column (str):
- Required. Name of the column that identifies time order in the time series.
- time_series_identifier_column (str):
- Required. Name of the column that identifies the time series.
- unavailable_at_forecast_columns (List[str]):
- Required. Column names of columns that are unavailable at forecast.
- Each column contains information for the given entity (identified by the
- [time_series_identifier_column]) that is unknown before the forecast
- (e.g. population of a city in a given year, or weather on a given day).
- available_at_forecast_columns (List[str]):
- Required. Column names of columns that are available at forecast.
- Each column contains information for the given entity (identified by the
- [time_series_identifier_column]) that is known at forecast.
- forecast_horizon: (int):
- Required. The amount of time into the future for which forecasted values for the target are
- returned. Expressed in number of units defined by the [data_granularity_unit] and
- [data_granularity_count] field. Inclusive.
- data_granularity_unit (str):
- Required. The data granularity unit. Accepted values are ``minute``,
- ``hour``, ``day``, ``week``, ``month``, ``year``.
- data_granularity_count (int):
- Required. The number of data granularity units between data points in the training
- data. If [data_granularity_unit] is `minute`, can be 1, 5, 10, 15, or 30. For all other
- values of [data_granularity_unit], must be 1.
+ Required. The name of the column whose values the Model is to predict.
+ training_fraction_split (float):
+ Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
+ validation_fraction_split (float):
+ Optional. The fraction of the input data that is to be used to validate
+ the Model. This is ignored if Dataset is not provided.
+ test_fraction_split (float):
+ Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
predefined_split_column_name (str):
Optional. The key is a name of one of the Dataset's data
columns. The value of the key (either the label's value or
- value in the column) must be one of {``TRAIN``,
- ``VALIDATE``, ``TEST``}, and it defines to which set the
+ value in the column) must be one of {``training``,
+ ``validation``, ``test``}, and it defines to which set the
given piece of data is assigned. If for a piece of data the
key is not present or has an invalid value, that piece is
ignored by the pipeline.
Supported only for tabular and time series Datasets.
+ timestamp_split_column_name (str):
+ Optional. The key is a name of one of the Dataset's data
+ columns. The values of the key (the values in
+ the column) must be in RFC 3339 `date-time` format, where
+ `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
+ piece of data the key is not present or has an invalid value,
+ that piece is ignored by the pipeline.
+ Supported only for tabular and time series Datasets.
+ This parameter must be used with training_fraction_split,
+ validation_fraction_split, and test_fraction_split.
weight_column (str):
Optional. Name of the column that should be used as the weight column.
Higher values in this column give more importance to the row
@@ -2828,15 +4370,40 @@ def run(
10000 inclusively, and 0 value means that the row is ignored.
If the weight column field is not set, then all rows are assumed to have
equal weight of 1.
- time_series_attribute_columns (List[str]):
- Optional. Column names that should be used as attribute columns.
- Each column is constant within a time series.
- context_window (int):
- Optional. The amount of time into the past training and prediction data is used for
- model training and prediction respectively. Expressed in number of units defined by the
- [data_granularity_unit] and [data_granularity_count] fields. When not provided uses the
- default value of 0 which means the model sets each series context window to be 0 (also
- known as "cold start"). Inclusive.
+ budget_milli_node_hours (int):
+ Optional. The train budget of creating this Model, expressed in milli node
+ hours, i.e. a value of 1,000 in this field means 1 node hour.
+ The training cost of the model will not exceed this budget. The final
+ cost will be attempted to be close to the budget, though may end up
+ being (even) noticeably smaller - at the backend's discretion. This
+ especially may happen when further model training ceases to provide
+ any improvements.
+ If the budget is set to a value known to be insufficient to train a
+ Model for the given training set, the training won't be attempted and
+ will error.
+ The minimum value is 1000 and the maximum is 72000.
+ model_display_name (str):
+ Optional. If the script produces a managed Vertex AI Model, the display name of
+ the Model. The name can be up to 128 characters long and can consist
+ of any UTF-8 characters.
+
+ If not provided upon creation, the job's display_name is used.
+ model_labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
+ disable_early_stopping (bool):
+ Required. If true, the entire budget is used. This disables the early stopping
+ feature. By default, the early stopping feature is enabled, which means
+ that training might stop before the entire training budget has been
+ used, if further training no longer brings significant improvement
+ to the model.
export_evaluated_data_items (bool):
Whether to export the test set predictions to a BigQuery table.
If False, then the export is not performed.
@@ -2858,155 +4425,138 @@ def run(
Applies only if [export_evaluated_data_items] is True and
[export_evaluated_data_items_bigquery_destination_uri] is specified.
- quantiles (List[float]):
- Quantiles to use for the `minizmize-quantile-loss`
- [AutoMLForecastingTrainingJob.optimization_objective]. This argument is required in
- this case.
-
- Accepts up to 5 quantiles in the form of a double from 0 to 1, exclusive.
- Each quantile must be unique.
- validation_options (str):
- Validation options for the data validation component. The available options are:
- "fail-pipeline" - (default), will validate against the validation and fail the pipeline
- if it fails.
- "ignore-validation" - ignore the results of the validation and continue the pipeline
- budget_milli_node_hours (int):
- Optional. The train budget of creating this Model, expressed in milli node
- hours i.e. 1,000 value in this field means 1 node hour.
- The training cost of the model will not exceed this budget. The final
- cost will be attempted to be close to the budget, though may end up
- being (even) noticeably smaller - at the backend's discretion. This
- especially may happen when further model training ceases to provide
- any improvements.
- If the budget is set to a value known to be insufficient to train a
- Model for the given training set, the training won't be attempted and
- will error.
- The minimum value is 1000 and the maximum is 72000.
- model_display_name (str):
- Optional. If the script produces a managed Vertex AI Model. The display name of
- the Model. The name can be up to 128 characters long and can be consist
- of any UTF-8 characters.
-
- If not provided upon creation, the job's display_name is used.
+ additional_experiments (List[str]):
+ Optional. Additional experiment flags for the AutoML Tables training.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
Returns:
model: The trained Vertex AI Model resource or None if training did not
- produce an Vertex AI Model.
+ produce a Vertex AI Model.
Raises:
- RuntimeError if Training job has already been run or is waiting to run.
+ RuntimeError: If Training job has already been run or is waiting to run.
"""
+ if model_display_name:
+ utils.validate_display_name(model_display_name)
+ if model_labels:
+ utils.validate_labels(model_labels)
if self._is_waiting_to_run():
- raise RuntimeError(
- "AutoML Forecasting Training is already scheduled to run."
- )
+ raise RuntimeError("AutoML Tabular Training is already scheduled to run.")
if self._has_run:
- raise RuntimeError("AutoML Forecasting Training has already run.")
+ raise RuntimeError("AutoML Tabular Training has already run.")
+
+ if additional_experiments:
+ self._add_additional_experiments(additional_experiments)
return self._run(
dataset=dataset,
target_column=target_column,
- time_column=time_column,
- time_series_identifier_column=time_series_identifier_column,
- unavailable_at_forecast_columns=unavailable_at_forecast_columns,
- available_at_forecast_columns=available_at_forecast_columns,
- forecast_horizon=forecast_horizon,
- data_granularity_unit=data_granularity_unit,
- data_granularity_count=data_granularity_count,
+ training_fraction_split=training_fraction_split,
+ validation_fraction_split=validation_fraction_split,
+ test_fraction_split=test_fraction_split,
predefined_split_column_name=predefined_split_column_name,
+ timestamp_split_column_name=timestamp_split_column_name,
weight_column=weight_column,
- time_series_attribute_columns=time_series_attribute_columns,
- context_window=context_window,
budget_milli_node_hours=budget_milli_node_hours,
+ model_display_name=model_display_name,
+ model_labels=model_labels,
+ disable_early_stopping=disable_early_stopping,
export_evaluated_data_items=export_evaluated_data_items,
export_evaluated_data_items_bigquery_destination_uri=export_evaluated_data_items_bigquery_destination_uri,
export_evaluated_data_items_override_destination=export_evaluated_data_items_override_destination,
- quantiles=quantiles,
- validation_options=validation_options,
- model_display_name=model_display_name,
sync=sync,
+ create_request_timeout=create_request_timeout,
)
@base.optional_sync()
def _run(
self,
- dataset: datasets.TimeSeriesDataset,
+ dataset: datasets.TabularDataset,
target_column: str,
- time_column: str,
- time_series_identifier_column: str,
- unavailable_at_forecast_columns: List[str],
- available_at_forecast_columns: List[str],
- forecast_horizon: int,
- data_granularity_unit: str,
- data_granularity_count: int,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
predefined_split_column_name: Optional[str] = None,
+ timestamp_split_column_name: Optional[str] = None,
weight_column: Optional[str] = None,
- time_series_attribute_columns: Optional[List[str]] = None,
- context_window: Optional[int] = None,
+ budget_milli_node_hours: int = 1000,
+ model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
+ disable_early_stopping: bool = False,
export_evaluated_data_items: bool = False,
export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
export_evaluated_data_items_override_destination: bool = False,
- quantiles: Optional[List[float]] = None,
- validation_options: Optional[str] = None,
- budget_milli_node_hours: int = 1000,
- model_display_name: Optional[str] = None,
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
) -> models.Model:
"""Runs the training job and returns a model.
- The training data splits are set by default: Roughly 80% will be used for training,
- 10% for validation, and 10% for test.
+ If training on a Vertex AI dataset, you can use one of the following split configurations:
+ Data fraction splits:
+ Any of ``training_fraction_split``, ``validation_fraction_split`` and
+ ``test_fraction_split`` may optionally be provided; they must sum to at most 1. If
+ the provided ones sum to less than 1, the remainder is assigned to sets as
+ decided by Vertex AI. If none of the fractions are set, by default roughly 80%
+ of data will be used for training, 10% for validation, and 10% for test.
+
+ Predefined splits:
+ Assigns input data to training, validation, and test sets based on the value of a provided key.
+ If using predefined splits, ``predefined_split_column_name`` must be provided.
+ Supported only for tabular Datasets.
+
+ Timestamp splits:
+ Assigns input data to training, validation, and test sets
+ based on a provided timestamp column. The youngest data pieces are
+ assigned to the training set, the next to the validation set, and the
+ oldest to the test set.
+ Supported only for tabular Datasets.
Args:
- dataset (datasets.Dataset):
+ dataset (datasets.TabularDataset):
Required. The dataset within the same Project from which data will be used to train the Model. The
Dataset must use schema compatible with Model being trained,
and what is compatible should be described in the used
TrainingPipeline's [training_task_definition]
[google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
- For time series Datasets, all their data is exported to
+ For tabular Datasets, all their data is exported to
training, to pick and choose from.
target_column (str):
- Required. Name of the column that the Model is to predict values for.
- time_column (str):
- Required. Name of the column that identifies time order in the time series.
- time_series_identifier_column (str):
- Required. Name of the column that identifies the time series.
- unavailable_at_forecast_columns (List[str]):
- Required. Column names of columns that are unavailable at forecast.
- Each column contains information for the given entity (identified by the
- [time_series_identifier_column]) that is unknown before the forecast
- (e.g. population of a city in a given year, or weather on a given day).
- available_at_forecast_columns (List[str]):
- Required. Column names of columns that are available at forecast.
- Each column contains information for the given entity (identified by the
- [time_series_identifier_column]) that is known at forecast.
- forecast_horizon: (int):
- Required. The amount of time into the future for which forecasted values for the target are
- returned. Expressed in number of units defined by the [data_granularity_unit] and
- [data_granularity_count] field. Inclusive.
- data_granularity_unit (str):
- Required. The data granularity unit. Accepted values are ``minute``,
- ``hour``, ``day``, ``week``, ``month``, ``year``.
- data_granularity_count (int):
- Required. The number of data granularity units between data points in the training
- data. If [data_granularity_unit] is `minute`, can be 1, 5, 10, 15, or 30. For all other
- values of [data_granularity_unit], must be 1.
+ Required. The name of the column whose values the Model is to predict.
+ training_fraction_split (float):
+ Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
+ validation_fraction_split (float):
+ Optional. The fraction of the input data that is to be used to validate
+ the Model. This is ignored if Dataset is not provided.
+ test_fraction_split (float):
+ Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
predefined_split_column_name (str):
Optional. The key is a name of one of the Dataset's data
columns. The value of the key (either the label's value or
- value in the column) must be one of {``TRAIN``,
- ``VALIDATE``, ``TEST``}, and it defines to which set the
+ value in the column) must be one of {``training``,
+ ``validation``, ``test``}, and it defines to which set the
given piece of data is assigned. If for a piece of data the
key is not present or has an invalid value, that piece is
ignored by the pipeline.
Supported only for tabular and time series Datasets.
+ timestamp_split_column_name (str):
+ Optional. The key is a name of one of the Dataset's data
+ columns. The values of the key (the values in the column)
+ must be in RFC 3339 `date-time` format, where
+ `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
+ piece of data the key is not present or has an invalid value,
+ that piece is ignored by the pipeline.
+ Supported only for tabular and time series Datasets.
+ This parameter must be used with training_fraction_split,
+ validation_fraction_split, and test_fraction_split.
weight_column (str):
Optional. Name of the column that should be used as the weight column.
Higher values in this column give more importance to the row
@@ -3014,14 +4564,40 @@ def _run(
10000 inclusively, and 0 value means that the row is ignored.
If the weight column field is not set, then all rows are assumed to have
equal weight of 1.
- time_series_attribute_columns (List[str]):
- Optional. Column names that should be used as attribute columns.
- Each column is constant within a time series.
- context_window (int):
- Optional. The number of periods offset into the past to restrict past sequence, where each
- period is one unit of granularity as defined by [period]. When not provided uses the
- default value of 0 which means the model sets each series historical window to be 0 (also
- known as "cold start"). Inclusive.
+ budget_milli_node_hours (int):
+ Optional. The train budget of creating this Model, expressed in milli node
+ hours, i.e. a value of 1,000 in this field means 1 node hour.
+ The training cost of the model will not exceed this budget. The final
+ cost will be attempted to be close to the budget, though it may end up
+ being (even) noticeably smaller - at the backend's discretion. This
+ especially may happen when further model training ceases to provide
+ any improvements.
+ If the budget is set to a value known to be insufficient to train a
+ Model for the given training set, the training won't be attempted and
+ will error.
+ The minimum value is 1000 and the maximum is 72000.
+ model_display_name (str):
+ Optional. If the script produces a managed Vertex AI Model, the display name of
+ the Model. The name can be up to 128 characters long and can consist
+ of any UTF-8 characters.
+
+ If not provided upon creation, the job's display_name is used.
+ model_labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
+ disable_early_stopping (bool):
+ Required. If true, the entire budget is used. This disables the early stopping
+ feature. By default, the early stopping feature is enabled, which means
+ that training might stop before the entire training budget has been
+ used, if further training no longer brings significant improvement
+ to the model.
export_evaluated_data_items (bool):
Whether to export the test set predictions to a BigQuery table.
If False, then the export is not performed.
@@ -3043,68 +4619,50 @@ def _run(
Applies only if [export_evaluated_data_items] is True and
[export_evaluated_data_items_bigquery_destination_uri] is specified.
- quantiles (List[float]):
- Quantiles to use for the `minizmize-quantile-loss`
- [AutoMLForecastingTrainingJob.optimization_objective]. This argument is required in
- this case.
-
- Accepts up to 5 quantiles in the form of a double from 0 to 1, exclusive.
- Each quantile must be unique.
- validation_options (str):
- Validation options for the data validation component. The available options are:
- "fail-pipeline" - (default), will validate against the validation and fail the pipeline
- if it fails.
- "ignore-validation" - ignore the results of the validation and continue the pipeline
- budget_milli_node_hours (int):
- Optional. The train budget of creating this Model, expressed in milli node
- hours i.e. 1,000 value in this field means 1 node hour.
- The training cost of the model will not exceed this budget. The final
- cost will be attempted to be close to the budget, though may end up
- being (even) noticeably smaller - at the backend's discretion. This
- especially may happen when further model training ceases to provide
- any improvements.
- If the budget is set to a value known to be insufficient to train a
- Model for the given training set, the training won't be attempted and
- will error.
- The minimum value is 1000 and the maximum is 72000.
- model_display_name (str):
- Optional. If the script produces a managed Vertex AI Model. The display name of
- the Model. The name can be up to 128 characters long and can be consist
- of any UTF-8 characters.
-
- If not provided upon creation, the job's display_name is used.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
+
Returns:
model: The trained Vertex AI Model resource or None if training did not
- produce an Vertex AI Model.
+ produce a Vertex AI Model.
"""
- training_task_definition = schema.training_job.definition.automl_forecasting
+ training_task_definition = schema.training_job.definition.automl_tabular
+
+ # auto-populate transformations
+ if self._column_transformations is None:
+ _LOGGER.info(
+ "No column transformations provided, so now retrieving columns from dataset in order to set default column transformations."
+ )
+
+ (
+ self._column_transformations,
+ column_names,
+ ) = column_transformations_utils.get_default_column_transformations(
+ dataset=dataset, target_column=target_column
+ )
+
+ _LOGGER.info(
+ "The column transformation of type 'auto' was set for the following columns: %s."
+ % column_names
+ )
training_task_inputs_dict = {
# required inputs
"targetColumn": target_column,
- "timeColumn": time_column,
- "timeSeriesIdentifierColumn": time_series_identifier_column,
- "timeSeriesAttributeColumns": time_series_attribute_columns,
- "unavailableAtForecastColumns": unavailable_at_forecast_columns,
- "availableAtForecastColumns": available_at_forecast_columns,
- "forecastHorizon": forecast_horizon,
- "dataGranularity": {
- "unit": data_granularity_unit,
- "quantity": data_granularity_count,
- },
"transformations": self._column_transformations,
"trainBudgetMilliNodeHours": budget_milli_node_hours,
# optional inputs
- "weightColumn": weight_column,
- "contextWindow": context_window,
- "quantiles": quantiles,
- "validationOptions": validation_options,
+ "weightColumnName": weight_column,
+ "disableEarlyStopping": disable_early_stopping,
"optimizationObjective": self._optimization_objective,
+ "predictionType": self._optimization_prediction_type,
+ "optimizationObjectiveRecallValue": self._optimization_objective_recall_value,
+ "optimizationObjectivePrecisionValue": self._optimization_objective_precision_value,
}
final_export_eval_bq_uri = export_evaluated_data_items_bigquery_destination_uri
@@ -3119,20 +4677,28 @@ def _run(
"overrideExistingTable": export_evaluated_data_items_override_destination,
}
- if model_display_name is None:
- model_display_name = self._display_name
+ if self._additional_experiments:
+ training_task_inputs_dict[
+ "additionalExperiments"
+ ] = self._additional_experiments
- model = gca_model.Model(display_name=model_display_name)
+ model = gca_model.Model(
+ display_name=model_display_name or self._display_name,
+ labels=model_labels or self._labels,
+ encryption_spec=self._model_encryption_spec,
+ )
return self._run_job(
training_task_definition=training_task_definition,
training_task_inputs=training_task_inputs_dict,
dataset=dataset,
- training_fraction_split=0.8,
- validation_fraction_split=0.1,
- test_fraction_split=0.1,
+ training_fraction_split=training_fraction_split,
+ validation_fraction_split=validation_fraction_split,
+ test_fraction_split=test_fraction_split,
predefined_split_column_name=predefined_split_column_name,
+ timestamp_split_column_name=timestamp_split_column_name,
model=model,
+ create_request_timeout=create_request_timeout,
)
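For reviewers, the dict built above serializes into `trainingTaskInputs` roughly as follows. This is an illustrative sketch only; the values are made up, and the transformation entry format shown is the standard AutoML Tables one (treat it as an assumption here):

    training_task_inputs = {
        # required inputs
        "targetColumn": "price",
        "transformations": [{"auto": {"column_name": "sqft"}}],
        "trainBudgetMilliNodeHours": 1000,
        # optional inputs
        "weightColumnName": "row_weight",
        "disableEarlyStopping": False,
        "optimizationObjective": "minimize-rmse",
        "predictionType": "regression",
        "optimizationObjectiveRecallValue": None,
        "optimizationObjectivePrecisionValue": None,
    }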
@property
@@ -3143,6 +4709,223 @@ def _model_upload_fail_string(self) -> str:
"Model."
)
+ def _add_additional_experiments(self, additional_experiments: List[str]):
+ """Add experiment flags to the training job.
+ Args:
+ additional_experiments (List[str]):
+ Experiment flags that can enable some experimental training features.
+ """
+ self._additional_experiments.extend(additional_experiments)
+
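Callers reach this helper through the new `additional_experiments` argument on `run`; the flags land in `trainingTaskInputs.additionalExperiments`. A sketch, reusing the tabular job from the earlier example (the flag name is hypothetical):

    model = job.run(
        dataset=ds,
        target_column="price",
        additional_experiments=["enable_feature_x"],  # hypothetical flag name
    )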
+ @staticmethod
+ def get_auto_column_specs(
+ dataset: datasets.TabularDataset,
+ target_column: str,
+ ) -> Dict[str, str]:
+ """Returns a dict with all non-target columns as keys and 'auto' as values.
+
+ Example usage:
+
+ column_specs = training_jobs.AutoMLTabularTrainingJob.get_auto_column_specs(
+ dataset=my_dataset,
+ target_column="my_target_column",
+ )
+
+ Args:
+ dataset (datasets.TabularDataset):
+ Required. Intended dataset.
+ target_column (str):
+ Required. Intended target column.
+ Returns:
+ Dict[str, str]
+ Column names as keys and 'auto' as values
+ """
+ column_names = [
+ column for column in dataset.column_names if column != target_column
+ ]
+ column_specs = {column: "auto" for column in column_names}
+ return column_specs
+
+ class column_data_types:
+ AUTO = "auto"
+ NUMERIC = "numeric"
+ CATEGORICAL = "categorical"
+ TIMESTAMP = "timestamp"
+ TEXT = "text"
+ REPEATED_NUMERIC = "repeated_numeric"
+ REPEATED_CATEGORICAL = "repeated_categorical"
+ REPEATED_TEXT = "repeated_text"
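A sketch of how `get_auto_column_specs` and the new `column_data_types` constants compose; this assumes the job constructor's `column_specs` argument, which is defined elsewhere in this module, and `my_dataset` is a placeholder TabularDataset:

    from google.cloud import aiplatform

    column_specs = aiplatform.AutoMLTabularTrainingJob.get_auto_column_specs(
        dataset=my_dataset,
        target_column="my_target_column",
    )
    # Pin one column to an explicit type instead of "auto".
    column_specs["zip_code"] = (
        aiplatform.AutoMLTabularTrainingJob.column_data_types.CATEGORICAL
    )

    job = aiplatform.AutoMLTabularTrainingJob(
        display_name="tabular-train",
        optimization_prediction_type="classification",
        column_specs=column_specs,
    )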
+
+
+class AutoMLForecastingTrainingJob(_ForecastingTrainingJob):
+ _model_type = "AutoML"
+ _training_task_definition = schema.training_job.definition.automl_forecasting
+ _supported_training_schemas = (schema.training_job.definition.automl_forecasting,)
+
+ def run(
+ self,
+ dataset: datasets.TimeSeriesDataset,
+ target_column: str,
+ time_column: str,
+ time_series_identifier_column: str,
+ unavailable_at_forecast_columns: List[str],
+ available_at_forecast_columns: List[str],
+ forecast_horizon: int,
+ data_granularity_unit: str,
+ data_granularity_count: int,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ predefined_split_column_name: Optional[str] = None,
+ timestamp_split_column_name: Optional[str] = None,
+ weight_column: Optional[str] = None,
+ time_series_attribute_columns: Optional[List[str]] = None,
+ context_window: Optional[int] = None,
+ export_evaluated_data_items: bool = False,
+ export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
+ export_evaluated_data_items_override_destination: bool = False,
+ quantiles: Optional[List[float]] = None,
+ validation_options: Optional[str] = None,
+ budget_milli_node_hours: int = 1000,
+ model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
+ additional_experiments: Optional[List[str]] = None,
+ hierarchy_group_columns: Optional[List[str]] = None,
+ hierarchy_group_total_weight: Optional[float] = None,
+ hierarchy_temporal_total_weight: Optional[float] = None,
+ hierarchy_group_temporal_total_weight: Optional[float] = None,
+ window_column: Optional[str] = None,
+ window_stride_length: Optional[int] = None,
+ window_max_count: Optional[int] = None,
+ holiday_regions: Optional[List[str]] = None,
+ sync: bool = True,
+ create_request_timeout: Optional[float] = None,
+ ) -> models.Model:
+ return super().run(
+ dataset=dataset,
+ target_column=target_column,
+ time_column=time_column,
+ time_series_identifier_column=time_series_identifier_column,
+ unavailable_at_forecast_columns=unavailable_at_forecast_columns,
+ available_at_forecast_columns=available_at_forecast_columns,
+ forecast_horizon=forecast_horizon,
+ data_granularity_unit=data_granularity_unit,
+ data_granularity_count=data_granularity_count,
+ training_fraction_split=training_fraction_split,
+ validation_fraction_split=validation_fraction_split,
+ test_fraction_split=test_fraction_split,
+ predefined_split_column_name=predefined_split_column_name,
+ timestamp_split_column_name=timestamp_split_column_name,
+ weight_column=weight_column,
+ time_series_attribute_columns=time_series_attribute_columns,
+ context_window=context_window,
+ budget_milli_node_hours=budget_milli_node_hours,
+ export_evaluated_data_items=export_evaluated_data_items,
+ export_evaluated_data_items_bigquery_destination_uri=export_evaluated_data_items_bigquery_destination_uri,
+ export_evaluated_data_items_override_destination=export_evaluated_data_items_override_destination,
+ quantiles=quantiles,
+ validation_options=validation_options,
+ model_display_name=model_display_name,
+ model_labels=model_labels,
+ additional_experiments=additional_experiments,
+ hierarchy_group_columns=hierarchy_group_columns,
+ hierarchy_group_total_weight=hierarchy_group_total_weight,
+ hierarchy_temporal_total_weight=hierarchy_temporal_total_weight,
+ hierarchy_group_temporal_total_weight=hierarchy_group_temporal_total_weight,
+ window_column=window_column,
+ window_stride_length=window_stride_length,
+ window_max_count=window_max_count,
+ holiday_regions=holiday_regions,
+ sync=sync,
+ create_request_timeout=create_request_timeout,
+ )
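A hedged usage sketch for the extracted forecasting subclass (the dataset resource name and column names are placeholders; `quantiles` is omitted since it only applies with the `minimize-quantile-loss` objective):

    from google.cloud import aiplatform

    ts_ds = aiplatform.TimeSeriesDataset(
        "projects/my-project/locations/us-central1/datasets/1234567890"
    )

    job = aiplatform.AutoMLForecastingTrainingJob(
        display_name="forecasting-train",
        optimization_objective="minimize-rmse",
    )

    model = job.run(
        dataset=ts_ds,
        target_column="sales",
        time_column="date",
        time_series_identifier_column="store_id",
        available_at_forecast_columns=["date", "planned_promo"],
        unavailable_at_forecast_columns=["sales"],
        forecast_horizon=30,
        data_granularity_unit="day",
        data_granularity_count=1,
        budget_milli_node_hours=1000,
        create_request_timeout=300.0,
    )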
+
+
+class SequenceToSequencePlusForecastingTrainingJob(_ForecastingTrainingJob):
+ _model_type = "Seq2Seq"
+ _training_task_definition = schema.training_job.definition.seq2seq_plus_forecasting
+ _supported_training_schemas = (
+ schema.training_job.definition.seq2seq_plus_forecasting,
+ )
+
+ def run(
+ self,
+ dataset: datasets.TimeSeriesDataset,
+ target_column: str,
+ time_column: str,
+ time_series_identifier_column: str,
+ unavailable_at_forecast_columns: List[str],
+ available_at_forecast_columns: List[str],
+ forecast_horizon: int,
+ data_granularity_unit: str,
+ data_granularity_count: int,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ predefined_split_column_name: Optional[str] = None,
+ timestamp_split_column_name: Optional[str] = None,
+ weight_column: Optional[str] = None,
+ time_series_attribute_columns: Optional[List[str]] = None,
+ context_window: Optional[int] = None,
+ export_evaluated_data_items: bool = False,
+ export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
+ export_evaluated_data_items_override_destination: bool = False,
+ quantiles: Optional[List[float]] = None,
+ validation_options: Optional[str] = None,
+ budget_milli_node_hours: int = 1000,
+ model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
+ additional_experiments: Optional[List[str]] = None,
+ hierarchy_group_columns: Optional[List[str]] = None,
+ hierarchy_group_total_weight: Optional[float] = None,
+ hierarchy_temporal_total_weight: Optional[float] = None,
+ hierarchy_group_temporal_total_weight: Optional[float] = None,
+ window_column: Optional[str] = None,
+ window_stride_length: Optional[int] = None,
+ window_max_count: Optional[int] = None,
+ holiday_regions: Optional[List[str]] = None,
+ sync: bool = True,
+ create_request_timeout: Optional[float] = None,
+ ) -> models.Model:
+ return super().run(
+ dataset=dataset,
+ target_column=target_column,
+ time_column=time_column,
+ time_series_identifier_column=time_series_identifier_column,
+ unavailable_at_forecast_columns=unavailable_at_forecast_columns,
+ available_at_forecast_columns=available_at_forecast_columns,
+ forecast_horizon=forecast_horizon,
+ data_granularity_unit=data_granularity_unit,
+ data_granularity_count=data_granularity_count,
+ training_fraction_split=training_fraction_split,
+ validation_fraction_split=validation_fraction_split,
+ test_fraction_split=test_fraction_split,
+ predefined_split_column_name=predefined_split_column_name,
+ timestamp_split_column_name=timestamp_split_column_name,
+ weight_column=weight_column,
+ time_series_attribute_columns=time_series_attribute_columns,
+ context_window=context_window,
+ budget_milli_node_hours=budget_milli_node_hours,
+ export_evaluated_data_items=export_evaluated_data_items,
+ export_evaluated_data_items_bigquery_destination_uri=export_evaluated_data_items_bigquery_destination_uri,
+ export_evaluated_data_items_override_destination=export_evaluated_data_items_override_destination,
+ quantiles=quantiles,
+ validation_options=validation_options,
+ model_display_name=model_display_name,
+ model_labels=model_labels,
+ additional_experiments=additional_experiments,
+ hierarchy_group_columns=hierarchy_group_columns,
+ hierarchy_group_total_weight=hierarchy_group_total_weight,
+ hierarchy_temporal_total_weight=hierarchy_temporal_total_weight,
+ hierarchy_group_temporal_total_weight=hierarchy_group_temporal_total_weight,
+ window_column=window_column,
+ window_stride_length=window_stride_length,
+ window_max_count=window_max_count,
+ holiday_regions=holiday_regions,
+ sync=sync,
+ create_request_timeout=create_request_timeout,
+ )
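Since both subclasses delegate to `_ForecastingTrainingJob.run` with identical signatures, switching model families is just a class swap (a sketch, assuming the same `aiplatform` import as above):

    # run(...) takes the same arguments as the AutoMLForecastingTrainingJob
    # sketch above; only the training-task definition sent to the service differs.
    job = aiplatform.SequenceToSequencePlusForecastingTrainingJob(
        display_name="seq2seq-forecasting-train",
        optimization_objective="minimize-rmse",
    )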
+
class AutoMLImageTrainingJob(_TrainingJob):
_supported_training_schemas = (
@@ -3152,7 +4935,7 @@ class AutoMLImageTrainingJob(_TrainingJob):
def __init__(
self,
- display_name: str,
+ display_name: Optional[str] = None,
prediction_type: str = "classification",
multi_label: bool = False,
model_type: str = "CLOUD",
@@ -3160,6 +4943,7 @@ def __init__(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
+ labels: Optional[Dict[str, str]] = None,
training_encryption_spec_key_name: Optional[str] = None,
model_encryption_spec_key_name: Optional[str] = None,
):
@@ -3167,7 +4951,7 @@ def __init__(
Args:
display_name (str):
- Required. The user-defined name of this TrainingPipeline.
+ Optional. The user-defined name of this TrainingPipeline.
prediction_type (str):
The type of prediction the Model is to produce, one of:
"classification" - Predict one out of multiple target values is
@@ -3225,6 +5009,16 @@ def __init__(
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to run call training service. Overrides
credentials set in aiplatform.init.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize TrainingPipelines.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
training_encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the training pipeline. Has the
@@ -3253,6 +5047,8 @@ def __init__(
Raises:
ValueError: When an invalid prediction_type or model_type is provided.
"""
+ if not display_name:
+ display_name = self.__class__._generate_display_name()
valid_model_types = constants.AUTOML_IMAGE_PREDICTION_MODEL_TYPES.get(
prediction_type, None
@@ -3285,6 +5081,7 @@ def __init__(
project=project,
location=location,
credentials=credentials,
+ labels=labels,
training_encryption_spec_key_name=training_encryption_spec_key_name,
model_encryption_spec_key_name=model_encryption_spec_key_name,
)
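With `display_name` now optional and `labels` threaded through to the base class, constructing the image job looks roughly like this (the KMS key name is a placeholder):

    from google.cloud import aiplatform

    job = aiplatform.AutoMLImageTrainingJob(
        # display_name may now be omitted; a name is auto-generated.
        prediction_type="classification",
        multi_label=False,
        model_type="CLOUD",
        labels={"team": "vision"},
        training_encryption_spec_key_name=(
            "projects/my-project/locations/us-central1/"
            "keyRings/my-ring/cryptoKeys/my-key"
        ),
    )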
@@ -3297,22 +5094,39 @@ def __init__(
def run(
self,
dataset: datasets.ImageDataset,
- training_fraction_split: float = 0.8,
- validation_fraction_split: float = 0.1,
- test_fraction_split: float = 0.1,
- budget_milli_node_hours: int = 1000,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ training_filter_split: Optional[str] = None,
+ validation_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
+ budget_milli_node_hours: Optional[int] = None,
model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
disable_early_stopping: bool = False,
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
) -> models.Model:
"""Runs the AutoML Image training job and returns a model.
- Data fraction splits:
- Any of ``training_fraction_split``, ``validation_fraction_split`` and
- ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If
- the provided ones sum to less than 1, the remainder is assigned to sets as
- decided by Vertex AI. If none of the fractions are set, by default roughly 80%
- of data will be used for training, 10% for validation, and 10% for test.
+ If training on a Vertex AI dataset, you can use one of the following split configurations:
+ Data fraction splits:
+ Any of ``training_fraction_split``, ``validation_fraction_split`` and
+ ``test_fraction_split`` may optionally be provided; they must sum to at most 1. If
+ the provided ones sum to less than 1, the remainder is assigned to sets as
+ decided by Vertex AI. If none of the fractions are set, by default roughly 80%
+ of data will be used for training, 10% for validation, and 10% for test.
+
+ Data filter splits:
+ Assigns input data to training, validation, and test sets
+ based on the given filters; data pieces not matched by any
+ filter are ignored. Currently only supported for Datasets
+ containing DataItems.
+ If any of the filters in this message are to match nothing, then
+ they can be set as '-' (the minus sign).
+ If using filter splits, all of ``training_filter_split``, ``validation_filter_split`` and
+ ``test_filter_split`` must be provided.
+ Supported only for unstructured Datasets.
Args:
dataset (datasets.ImageDataset):
@@ -3323,31 +5137,70 @@ def run(
[google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
For tabular Datasets, all their data is exported to
training, to pick and choose from.
- training_fraction_split: float = 0.8
- Required. The fraction of the input data that is to be
- used to train the Model. This is ignored if Dataset is not provided.
- validation_fraction_split: float = 0.1
- Required. The fraction of the input data that is to be
- used to validate the Model. This is ignored if Dataset is not provided.
- test_fraction_split: float = 0.1
- Required. The fraction of the input data that is to be
- used to evaluate the Model. This is ignored if Dataset is not provided.
- budget_milli_node_hours: int = 1000
+ training_fraction_split (float):
+ Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
+ validation_fraction_split (float):
+ Optional. The fraction of the input data that is to be used to validate
+ the Model. This is ignored if Dataset is not provided.
+ test_fraction_split (float):
+ Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ training_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to train the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ validation_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to validate the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ test_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to test the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ budget_milli_node_hours (int):
Optional. The train budget of creating this Model, expressed in milli node
hours i.e. 1,000 value in this field means 1 node hour.
+
+ Defaults by `prediction_type`:
+
+ `classification` - For Cloud models the budget must be: 8,000 - 800,000
+ milli node hours (inclusive). The default value is 192,000 which
+ represents one day in wall time, assuming 8 nodes are used.
+ `object_detection` - For Cloud models the budget must be: 20,000 - 900,000
+ milli node hours (inclusive). The default value is 216,000 which represents
+ one day in wall time, assuming 9 nodes are used.
+
The training cost of the model will not exceed this budget. The final
cost will be attempted to be close to the budget, though may end up
being (even) noticeably smaller - at the backend's discretion. This
especially may happen when further model training ceases to provide
- any improvements.
- If the budget is set to a value known to be insufficient to train a
- Model for the given training set, the training won't be attempted and
+ any improvements. If the budget is set to a value known to be insufficient to
+ train a Model for the given training set, the training won't be attempted and
will error.
- The minimum value is 1000 and the maximum is 72000.
model_display_name (str):
Optional. The display name of the managed Vertex AI Model. The name
can be up to 128 characters long and can be consist of any UTF-8
characters. If not provided upon creation, the job's display_name is used.
+ model_labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
disable_early_stopping: bool = False
Required. If true, the entire budget is used. This disables the early stopping
feature. By default, the early stopping feature is enabled, which means
@@ -3358,14 +5211,21 @@ def run(
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
Returns:
model: The trained Vertex AI Model resource or None if training did not
- produce an Vertex AI Model.
+ produce a Vertex AI Model.
Raises:
RuntimeError: If Training job has already been run or is waiting to run.
"""
+ if model_display_name:
+ utils.validate_display_name(model_display_name)
+ if model_labels:
+ utils.validate_labels(model_labels)
+
if self._is_waiting_to_run():
raise RuntimeError("AutoML Image Training is already scheduled to run.")
@@ -3378,10 +5238,15 @@ def run(
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
+ training_filter_split=training_filter_split,
+ validation_filter_split=validation_filter_split,
+ test_filter_split=test_filter_split,
budget_milli_node_hours=budget_milli_node_hours,
model_display_name=model_display_name,
+ model_labels=model_labels,
disable_early_stopping=disable_early_stopping,
sync=sync,
+ create_request_timeout=create_request_timeout,
)
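A sketch of the new filter-split path on the image job above; the `ml_use` label filter shown is the commonly documented pattern, but treat the exact filter strings and `image_ds` as assumptions:

    model = job.run(
        dataset=image_ds,  # placeholder ImageDataset
        training_filter_split='labels.aiplatform.googleapis.com/ml_use="training"',
        validation_filter_split='labels.aiplatform.googleapis.com/ml_use="validation"',
        test_filter_split='labels.aiplatform.googleapis.com/ml_use="test"',
        budget_milli_node_hours=8000,
        model_labels={"team": "vision"},
        create_request_timeout=300.0,
    )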
@base.optional_sync()
@@ -3389,22 +5254,39 @@ def _run(
self,
dataset: datasets.ImageDataset,
base_model: Optional[models.Model] = None,
- training_fraction_split: float = 0.8,
- validation_fraction_split: float = 0.1,
- test_fraction_split: float = 0.1,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ training_filter_split: Optional[str] = None,
+ validation_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
budget_milli_node_hours: int = 1000,
model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
disable_early_stopping: bool = False,
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
) -> models.Model:
"""Runs the training job and returns a model.
- Data fraction splits:
- Any of ``training_fraction_split``, ``validation_fraction_split`` and
- ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If
- the provided ones sum to less than 1, the remainder is assigned to sets as
- decided by Vertex AI. If none of the fractions are set, by default roughly 80%
- of data will be used for training, 10% for validation, and 10% for test.
+ If training on a Vertex AI dataset, you can use one of the following split configurations:
+ Data fraction splits:
+ Any of ``training_fraction_split``, ``validation_fraction_split`` and
+ ``test_fraction_split`` may optionally be provided; they must sum to at most 1. If
+ the provided ones sum to less than 1, the remainder is assigned to sets as
+ decided by Vertex AI. If none of the fractions are set, by default roughly 80%
+ of data will be used for training, 10% for validation, and 10% for test.
+
+ Data filter splits:
+ Assigns input data to training, validation, and test sets
+ based on the given filters; data pieces not matched by any
+ filter are ignored. Currently only supported for Datasets
+ containing DataItems.
+ If any of the filters in this message are to match nothing, then
+ they can be set as '-' (the minus sign).
+ If using filter splits, all of ``training_filter_split``, ``validation_filter_split`` and
+ ``test_filter_split`` must be provided.
+ Supported only for unstructured Datasets.
Args:
dataset (datasets.ImageDataset):
@@ -3422,14 +5304,35 @@ def _run(
must be in the same Project and Location as the new Model to train,
and have the same model_type.
training_fraction_split (float):
- Required. The fraction of the input data that is to be
- used to train the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
validation_fraction_split (float):
- Required. The fraction of the input data that is to be
- used to validate the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to validate
+ the Model. This is ignored if Dataset is not provided.
test_fraction_split (float):
- Required. The fraction of the input data that is to be
- used to evaluate the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ training_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to train the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ validation_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to validate the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ test_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to test the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
budget_milli_node_hours (int):
Optional. The train budget of creating this Model, expressed in milli node
hours i.e. 1,000 value in this field means 1 node hour.
@@ -3448,6 +5351,16 @@ def _run(
characters. If a `base_model` was provided, the display_name in the
base_model will be overwritten with this value. If not provided upon
creation, the job's display_name is used.
+ model_labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
disable_early_stopping (bool):
Required. If true, the entire budget is used. This disables the early stopping
feature. By default, the early stopping feature is enabled, which means
@@ -3458,10 +5371,12 @@ def _run(
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
Returns:
model: The trained Vertex AI Model resource or None if training did not
- produce an Vertex AI Model.
+ produce a Vertex AI Model.
"""
# Retrieve the objective-specific training task schema based on prediction_type
@@ -3484,6 +5399,7 @@ def _run(
model_tbt = gca_model.Model(encryption_spec=self._model_encryption_spec)
model_tbt.display_name = model_display_name or self._display_name
+ model_tbt.labels = model_labels or self._labels
if base_model:
# Use provided base_model to pass to model_to_upload causing the
@@ -3501,7 +5417,11 @@ def _run(
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
+ training_filter_split=training_filter_split,
+ validation_filter_split=validation_filter_split,
+ test_filter_split=test_filter_split,
model=model_tbt,
+ create_request_timeout=create_request_timeout,
)
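Per the docstring above, '-' can deliberately leave a set empty when filter splits are used; a sketch (the filter strings remain assumptions):

    model = job.run(
        dataset=image_ds,
        training_filter_split='labels.aiplatform.googleapis.com/ml_use="training"',
        validation_filter_split="-",  # '-' matches nothing: no validation set
        test_filter_split='labels.aiplatform.googleapis.com/ml_use="test"',
    )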
@property
@@ -3523,6 +5443,7 @@ class CustomPythonPackageTrainingJob(_CustomTrainingJob):
def __init__(
self,
+ # TODO(b/223262536): Make display_name parameter fully optional in next major release
display_name: str,
python_package_gcs_uri: str,
python_module_name: str,
@@ -3541,6 +5462,7 @@ def __init__(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
+ labels: Optional[Dict[str, str]] = None,
training_encryption_spec_key_name: Optional[str] = None,
model_encryption_spec_key_name: Optional[str] = None,
staging_bucket: Optional[str] = None,
@@ -3554,7 +5476,8 @@ def __init__(
container_uri='gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest',
model_serving_container_image_uri='gcr.io/my-trainer/serving:1',
model_serving_container_predict_route='predict',
- model_serving_container_health_route='metadata
+ model_serving_container_health_route='metadata',
+ labels={'key': 'value'},
)
Usage with Dataset:
@@ -3566,14 +5489,16 @@ def __init__(
job.run(
ds,
replica_count=1,
- model_display_name='my-trained-model'
+ model_display_name='my-trained-model',
+ model_labels={'key': 'value'},
)
Usage without Dataset:
job.run(
replica_count=1,
- model_display_name='my-trained-model'
+ model_display_name='my-trained-model',
+ model_labels={'key': 'value'},
)
To ensure your model gets saved in Vertex AI, write your saved model to
@@ -3682,6 +5607,16 @@ def __init__(
credentials (auth_credentials.Credentials):
Custom credentials to use to run call training service. Overrides
credentials set in aiplatform.init.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize TrainingPipelines.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
training_encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the training pipeline. Has the
@@ -3711,11 +5646,14 @@ def __init__(
Bucket used to stage source and training artifacts. Overrides
staging_bucket set in aiplatform.init.
"""
+ if not display_name:
+ display_name = self.__class__._generate_display_name()
super().__init__(
display_name=display_name,
project=project,
location=location,
credentials=credentials,
+ labels=labels,
training_encryption_spec_key_name=training_encryption_spec_key_name,
model_encryption_spec_key_name=model_encryption_spec_key_name,
container_uri=container_uri,
@@ -3748,21 +5686,36 @@ def run(
] = None,
annotation_schema_uri: Optional[str] = None,
model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
base_output_dir: Optional[str] = None,
service_account: Optional[str] = None,
network: Optional[str] = None,
bigquery_destination: Optional[str] = None,
args: Optional[List[Union[str, float, int]]] = None,
environment_variables: Optional[Dict[str, str]] = None,
- replica_count: int = 0,
+ replica_count: int = 1,
machine_type: str = "n1-standard-4",
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
accelerator_count: int = 0,
- training_fraction_split: float = 0.8,
- validation_fraction_split: float = 0.1,
- test_fraction_split: float = 0.1,
+ boot_disk_type: str = "pd-ssd",
+ boot_disk_size_gb: int = 100,
+ reduction_server_replica_count: int = 0,
+ reduction_server_machine_type: Optional[str] = None,
+ reduction_server_container_uri: Optional[str] = None,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ training_filter_split: Optional[str] = None,
+ validation_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
predefined_split_column_name: Optional[str] = None,
+ timestamp_split_column_name: Optional[str] = None,
+ timeout: Optional[int] = None,
+ restart_job_on_worker_restart: bool = False,
+ enable_web_access: bool = False,
+ tensorboard: Optional[str] = None,
sync=True,
+ create_request_timeout: Optional[float] = None,
) -> Optional[models.Model]:
"""Runs the custom training job.
@@ -3772,12 +5725,36 @@ def run(
ie: replica_count = 10 will result in 1 chief and 9 workers
All replicas have same machine_type, accelerator_type, and accelerator_count
- Data fraction splits:
- Any of ``training_fraction_split``, ``validation_fraction_split`` and
- ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If
- the provided ones sum to less than 1, the remainder is assigned to sets as
- decided by Vertex AI.If none of the fractions are set, by default roughly 80%
- of data will be used for training, 10% for validation, and 10% for test.
+ If training on a Vertex AI dataset, you can use one of the following split configurations:
+ Data fraction splits:
+ Any of ``training_fraction_split``, ``validation_fraction_split`` and
+ ``test_fraction_split`` may optionally be provided; they must sum to at most 1. If
+ the provided ones sum to less than 1, the remainder is assigned to sets as
+ decided by Vertex AI. If none of the fractions are set, by default roughly 80%
+ of data will be used for training, 10% for validation, and 10% for test.
+
+ Data filter splits:
+ Assigns input data to training, validation, and test sets
+ based on the given filters; data pieces not matched by any
+ filter are ignored. Currently only supported for Datasets
+ containing DataItems.
+ If any of the filters in this message are to match nothing, then
+ they can be set as '-' (the minus sign).
+ If using filter splits, all of ``training_filter_split``, ``validation_filter_split`` and
+ ``test_filter_split`` must be provided.
+ Supported only for unstructured Datasets.
+
+ Predefined splits:
+ Assigns input data to training, validation, and test sets based on the value of a provided key.
+ If using predefined splits, ``predefined_split_column_name`` must be provided.
+ Supported only for tabular Datasets.
+
+ Timestamp splits:
+ Assigns input data to training, validation, and test sets
+ based on a provided timestamp column. The youngest data pieces are
+ assigned to the training set, the next to the validation set, and the
+ oldest to the test set.
+ Supported only for tabular Datasets.
Args:
dataset (Union[datasets.ImageDataset,datasets.TabularDataset,datasets.TextDataset,datasets.VideoDataset,]):
@@ -3794,7 +5771,7 @@ def run(
annotation_schema_uri (str):
Google Cloud Storage URI points to a YAML file describing
annotation schema. The schema is defined as an OpenAPI 3.0.2
- [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#schema-object) The schema files
+ [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) The schema files
that can be used here are found in
gs://google-cloud-aiplatform/schema/dataset/annotation/,
note that the chosen schema must be consistent with
@@ -3819,6 +5796,16 @@ def run(
of any UTF-8 characters.
If not provided upon creation, the job's display_name is used.
+ model_labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
base_output_dir (str):
GCS output directory of job. If not provided a
timestamped directory in the staging directory will be used.
@@ -3876,15 +5863,50 @@ def run(
NVIDIA_TESLA_T4
accelerator_count (int):
The number of accelerators to attach to a worker replica.
+ boot_disk_type (str):
+ Type of the boot disk; the default is `pd-ssd`.
+ Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
+ `pd-standard` (Persistent Disk Hard Disk Drive).
+ boot_disk_size_gb (int):
+ Size in GB of the boot disk; the default is 100GB.
+ Boot disk size must be within the range of [100, 64000].
+ reduction_server_replica_count (int):
+ The number of reduction server replicas; the default is 0.
+ reduction_server_machine_type (str):
+ Optional. The type of machine to use for the reduction server.
+ reduction_server_container_uri (str):
+ Optional. The URI of the reduction server container image.
+ See details: https://cloud.google.com/vertex-ai/docs/training/distributed-training#reduce_training_time_with_reduction_server
training_fraction_split (float):
- The fraction of the input data that is to be
- used to train the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
validation_fraction_split (float):
- The fraction of the input data that is to be
- used to validate the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to validate
+ the Model. This is ignored if Dataset is not provided.
test_fraction_split (float):
- The fraction of the input data that is to be
- used to evaluate the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ training_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to train the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ validation_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to validate the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ test_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to test the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
predefined_split_column_name (str):
Optional. The key is a name of one of the Dataset's data
columns. The value of the key (either the label's value or
@@ -3895,21 +5917,63 @@ def run(
ignored by the pipeline.
Supported only for tabular and time series Datasets.
+ timestamp_split_column_name (str):
+ Optional. The key is a name of one of the Dataset's data
+ columns. The values of the key (the values in the column)
+ must be in RFC 3339 `date-time` format, where
+ `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
+ piece of data the key is not present or has an invalid value,
+ that piece is ignored by the pipeline.
+
+ Supported only for tabular and time series Datasets.
+ timeout (int):
+ The maximum job running time in seconds. The default is 7 days.
+ restart_job_on_worker_restart (bool):
+ Restarts the entire CustomJob if a worker
+ gets restarted. This feature can be used by
+ distributed training jobs that are not resilient
+ to workers leaving and joining a job.
+ enable_web_access (bool):
+ Whether you want Vertex AI to enable interactive shell access
+ to training containers.
+ https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
+ tensorboard (str):
+ Optional. The name of a Vertex AI
+ [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
+ resource to which this CustomJob will upload Tensorboard
+ logs. Format:
+ ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
+
+ The training script should write Tensorboard logs to the following Vertex AI environment
+ variable:
+
+ AIP_TENSORBOARD_LOG_DIR
+
+ `service_account` is required when `tensorboard` is provided.
+ For more information on configuring your service account please visit:
+ https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
Returns:
model: The trained Vertex AI Model resource or None if training did not
- produce an Vertex AI Model.
+ produce a Vertex AI Model.
"""
worker_pool_specs, managed_model = self._prepare_and_validate_run(
model_display_name=model_display_name,
+ model_labels=model_labels,
replica_count=replica_count,
machine_type=machine_type,
accelerator_count=accelerator_count,
accelerator_type=accelerator_type,
+ boot_disk_type=boot_disk_type,
+ boot_disk_size_gb=boot_disk_size_gb,
+ reduction_server_replica_count=reduction_server_replica_count,
+ reduction_server_machine_type=reduction_server_machine_type,
)
return self._run(
@@ -3925,9 +5989,21 @@ def run(
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
+ training_filter_split=training_filter_split,
+ validation_filter_split=validation_filter_split,
+ test_filter_split=test_filter_split,
predefined_split_column_name=predefined_split_column_name,
+ timestamp_split_column_name=timestamp_split_column_name,
bigquery_destination=bigquery_destination,
+ timeout=timeout,
+ restart_job_on_worker_restart=restart_job_on_worker_restart,
+ enable_web_access=enable_web_access,
+ tensorboard=tensorboard,
+ reduction_server_container_uri=reduction_server_container_uri
+ if reduction_server_replica_count > 0
+ else None,
sync=sync,
+ create_request_timeout=create_request_timeout,
)
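Pulling the new knobs together for the packaged-trainer path, a hedged sketch (the bucket, module, service account, and Tensorboard resource names are placeholders; the reduction server image URI shown is the one documented for Vertex AI, but verify it before use):

    from google.cloud import aiplatform

    job = aiplatform.CustomPythonPackageTrainingJob(
        display_name="custom-train",
        python_package_gcs_uri="gs://my-bucket/trainer-0.1.tar.gz",
        python_module_name="trainer.task",
        container_uri="gcr.io/cloud-aiplatform/training/tf-cpu.2-2:latest",
        labels={"team": "ml"},
    )

    model = job.run(
        replica_count=4,  # 1 chief + 3 workers
        machine_type="n1-standard-8",
        reduction_server_replica_count=1,
        reduction_server_machine_type="n1-highcpu-16",
        reduction_server_container_uri=(
            "us-docker.pkg.dev/vertex-ai-restricted/training/reductionserver:latest"
        ),
        tensorboard=(
            "projects/my-project/locations/us-central1/tensorboards/1234567890"
        ),
        service_account="trainer@my-project.iam.gserviceaccount.com",
        enable_web_access=True,
        timeout=86400,  # 1 day, in seconds
        create_request_timeout=300.0,
    )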
@base.optional_sync(construct_object_on_arg="managed_model")
@@ -3949,12 +6025,22 @@ def _run(
base_output_dir: Optional[str] = None,
service_account: Optional[str] = None,
network: Optional[str] = None,
- training_fraction_split: float = 0.8,
- validation_fraction_split: float = 0.1,
- test_fraction_split: float = 0.1,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ training_filter_split: Optional[str] = None,
+ validation_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
predefined_split_column_name: Optional[str] = None,
+ timestamp_split_column_name: Optional[str] = None,
bigquery_destination: Optional[str] = None,
+ timeout: Optional[int] = None,
+ restart_job_on_worker_restart: bool = False,
+ enable_web_access: bool = False,
+ tensorboard: Optional[str] = None,
+ reduction_server_container_uri: Optional[str] = None,
sync=True,
+ create_request_timeout: Optional[float] = None,
) -> Optional[models.Model]:
"""Packages local script and launches training_job.
@@ -4006,14 +6092,35 @@ def _run(
Private services access must already be configured for the network.
If left unspecified, the job is not peered with any network.
training_fraction_split (float):
- The fraction of the input data that is to be
- used to train the Model.
+ Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
validation_fraction_split (float):
- The fraction of the input data that is to be
- used to validate the Model.
+ Optional. The fraction of the input data that is to be used to validate
+ the Model. This is ignored if Dataset is not provided.
test_fraction_split (float):
- The fraction of the input data that is to be
- used to evaluate the Model.
+ Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ training_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to train the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ validation_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to validate the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ test_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to test the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
predefined_split_column_name (str):
Optional. The key is a name of one of the Dataset's data
columns. The value of the key (either the label's value or
@@ -4024,30 +6131,81 @@ def _run(
ignored by the pipeline.
Supported only for tabular and time series Datasets.
+ timestamp_split_column_name (str):
+ Optional. The key is a name of one of the Dataset's data
+ columns. The values of the key (the values in
+ the column) must be in RFC 3339 `date-time` format, where
+ `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
+ piece of data the key is not present or has an invalid value,
+ that piece is ignored by the pipeline.
+
+ Supported only for tabular and time series Datasets.
+ timeout (int):
+ The maximum job running time in seconds. The default is 7 days.
+ restart_job_on_worker_restart (bool):
+ Restarts the entire CustomJob if a worker
+ gets restarted. This feature can be used by
+ distributed training jobs that are not resilient
+ to workers leaving and joining a job.
+ enable_web_access (bool):
+ Whether you want Vertex AI to enable interactive shell access
+ to training containers.
+ https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell
+ tensorboard (str):
+ Optional. The name of a Vertex AI
+ [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
+ resource to which this CustomJob will upload Tensorboard
+ logs. Format:
+ ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
+
+ The training script should write its Tensorboard logs to the directory given by the following Vertex AI environment
+ variable:
+
+ AIP_TENSORBOARD_LOG_DIR
+
+ `service_account` is required when `tensorboard` is provided.
+ For more information on configuring your service account, please visit:
+ https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
+ reduction_server_container_uri (str):
+ Optional. The Uri of the reduction server container image.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
Returns:
model: The trained Vertex AI Model resource or None if training did not
- produce an Vertex AI Model.
+ produce a Vertex AI Model.
"""
- for spec in worker_pool_specs:
- spec["python_package_spec"] = {
- "executor_image_uri": self._container_uri,
- "python_module": self._python_module,
- "package_uris": [self._package_gcs_uri],
- }
+ for spec_order, spec in enumerate(worker_pool_specs):
- if args:
- spec["python_package_spec"]["args"] = args
+ if not spec:
+ continue
- if environment_variables:
- spec["python_package_spec"]["env"] = [
- {"name": key, "value": value}
- for key, value in environment_variables.items()
- ]
+ if (
+ spec_order == worker_spec_utils._SPEC_ORDERS["server_spec"]
+ and reduction_server_container_uri
+ ):
+ spec["container_spec"] = {
+ "image_uri": reduction_server_container_uri,
+ }
+ else:
+ spec["python_package_spec"] = {
+ "executor_image_uri": self._container_uri,
+ "python_module": self._python_module,
+ "package_uris": [self._package_gcs_uri],
+ }
+
+ if args:
+ spec["python_package_spec"]["args"] = args
+
+ if environment_variables:
+ spec["python_package_spec"]["env"] = [
+ {"name": key, "value": value}
+ for key, value in environment_variables.items()
+ ]
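
To make the loop's effect concrete, a sketch of the resulting specs under hypothetical values, assuming the conventional chief/worker/server/evaluator slot ordering from worker_spec_utils._SPEC_ORDERS: only the reduction-server slot carries a container_spec, while every other populated slot carries the python_package_spec built from this job.

    worker_pool_specs = [
        {   # chief/worker slots run the user's training package
            "machine_spec": {"machine_type": "n1-standard-4"},
            "replica_count": 1,
            "python_package_spec": {
                "executor_image_uri": "us-docker.pkg.dev/my-project/training/trainer:latest",
                "python_module": "trainer.task",
                "package_uris": ["gs://my-bucket/trainer-0.1.tar.gz"],
            },
        },
        {   # reduction-server slot runs only the reduction server image
            "machine_spec": {"machine_type": "n1-highcpu-16"},
            "replica_count": 2,
            "container_spec": {"image_uri": reduction_server_container_uri},
        },
    ]
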
(
training_task_inputs,
@@ -4057,6 +6215,10 @@ def _run(
base_output_dir=base_output_dir,
service_account=service_account,
network=network,
+ timeout=timeout,
+ restart_job_on_worker_restart=restart_job_on_worker_restart,
+ enable_web_access=enable_web_access,
+ tensorboard=tensorboard,
)
model = self._run_job(
@@ -4067,10 +6229,15 @@ def _run(
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
+ training_filter_split=training_filter_split,
+ validation_filter_split=validation_filter_split,
+ test_filter_split=test_filter_split,
predefined_split_column_name=predefined_split_column_name,
+ timestamp_split_column_name=timestamp_split_column_name,
model=managed_model,
gcs_destination_uri_prefix=base_output_dir,
bigquery_destination=bigquery_destination,
+ create_request_timeout=create_request_timeout,
)
return model
@@ -4086,12 +6253,13 @@ class AutoMLVideoTrainingJob(_TrainingJob):
def __init__(
self,
- display_name: str,
+ display_name: Optional[str] = None,
prediction_type: str = "classification",
model_type: str = "CLOUD",
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
+ labels: Optional[Dict[str, str]] = None,
training_encryption_spec_key_name: Optional[str] = None,
model_encryption_spec_key_name: Optional[str] = None,
):
@@ -4108,7 +6276,7 @@ def __init__(
multiple objects in shots and segments. You can use these
models to track objects in your videos according to your
own pre-defined, custom labels.
- "action_recognition" - A video action reconition model pinpoints
+ "action_recognition" - A video action recognition model pinpoints
the location of actions with short temporal durations (~1 second).
model_type: str = "CLOUD"
Required. One of the following:
@@ -4141,6 +6309,16 @@ def __init__(
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to run call training service. Overrides
credentials set in aiplatform.init.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize TrainingPipelines.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
training_encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the training pipeline. Has the
@@ -4169,6 +6347,9 @@ def __init__(
Raises:
ValueError: When an invalid prediction_type and/or model_type is provided.
"""
+ if not display_name:
+ display_name = self.__class__._generate_display_name()
+
valid_model_types = constants.AUTOML_VIDEO_PREDICTION_MODEL_TYPES.get(
prediction_type, None
)
@@ -4190,6 +6371,7 @@ def __init__(
project=project,
location=location,
credentials=credentials,
+ labels=labels,
training_encryption_spec_key_name=training_encryption_spec_key_name,
model_encryption_spec_key_name=model_encryption_spec_key_name,
)
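
A hypothetical construction showing the new `labels` argument, and that `display_name` may now be omitted and auto-generated; all values below are placeholders:

    from google.cloud import aiplatform

    job = aiplatform.AutoMLVideoTrainingJob(
        display_name="video-action-job",            # optional; generated if omitted
        prediction_type="action_recognition",
        model_type="CLOUD",
        labels={"team": "research", "env": "dev"},  # lowercase, <= 64 chars each
    )
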
@@ -4200,17 +6382,33 @@ def __init__(
def run(
self,
dataset: datasets.VideoDataset,
- training_fraction_split: float = 0.8,
- test_fraction_split: float = 0.2,
+ training_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ training_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
) -> models.Model:
- """Runs the AutoML Image training job and returns a model.
-
- Data fraction splits:
- ``training_fraction_split``, and ``test_fraction_split`` may optionally
- be provided, they must sum to up to 1. If none of the fractions are set,
- by default roughly 80% of data will be used for training, and 20% for test.
+ """Runs the AutoML Video training job and returns a model.
+
+ If training on a Vertex AI dataset, you can use one of the following split configurations:
+ Data fraction splits:
+ ``training_fraction_split`` and ``test_fraction_split`` may optionally
+ be provided; they must sum to at most 1. If none of the fractions are set,
+ by default roughly 80% of data will be used for training, and 20% for test.
+
+ Data filter splits:
+ Assigns input data to training, validation, and test sets
+ based on the given filters; data pieces not matched by any
+ filter are ignored. Currently only supported for Datasets
+ containing DataItems.
+ If any of the filters in this message are to match nothing, then
+ they can be set as '-' (the minus sign).
+ If using filter splits, both ``training_filter_split`` and
+ ``test_filter_split`` must be provided; this job supplies the validation filter ('-') automatically.
+ Supported only for unstructured Datasets.
Args:
dataset (datasets.VideoDataset):
@@ -4221,28 +6419,59 @@ def run(
[google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
For tabular Datasets, all their data is exported to
training, to pick and choose from.
- training_fraction_split: float = 0.8
- Required. The fraction of the input data that is to be
- used to train the Model. This is ignored if Dataset is not provided.
- test_fraction_split: float = 0.2
- Required. The fraction of the input data that is to be
- used to evaluate the Model. This is ignored if Dataset is not provided.
+ training_fraction_split (float):
+ Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
+ test_fraction_split (float):
+ Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ training_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to train the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ test_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to test the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
model_display_name (str):
Optional. The display name of the managed Vertex AI Model. The name
can be up to 128 characters long and can consist of any UTF-8
characters. If not provided upon creation, the job's display_name is used.
+ model_labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
sync: bool = True
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
Returns:
model: The trained Vertex AI Model resource or None if training did not
- produce an Vertex AI Model.
+ produce a Vertex AI Model.
Raises:
RuntimeError: If Training job has already been run or is waiting to run.
"""
+ if model_display_name:
+ utils.validate_display_name(model_display_name)
+ if model_labels:
+ utils.validate_labels(model_labels)
+
if self._is_waiting_to_run():
raise RuntimeError("AutoML Video Training is already scheduled to run.")
@@ -4253,25 +6482,45 @@ def run(
dataset=dataset,
training_fraction_split=training_fraction_split,
test_fraction_split=test_fraction_split,
+ training_filter_split=training_filter_split,
+ test_filter_split=test_filter_split,
model_display_name=model_display_name,
+ model_labels=model_labels,
sync=sync,
+ create_request_timeout=create_request_timeout,
)
@base.optional_sync()
def _run(
self,
dataset: datasets.VideoDataset,
- training_fraction_split: float = 0.8,
- test_fraction_split: float = 0.2,
+ training_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ training_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
) -> models.Model:
"""Runs the training job and returns a model.
- Data fraction splits:
- Any of ``training_fraction_split``, and ``test_fraction_split`` may optionally
- be provided, they must sum to up to 1. If none of the fractions are set,
- by default roughly 80% of data will be used for training, and 20% for test.
+ If training on a Vertex AI dataset, you can use one of the following split configurations:
+ Data fraction splits:
+ Any of ``training_fraction_split`` and ``test_fraction_split`` may optionally
+ be provided; they must sum to at most 1. If none of the fractions are set,
+ by default roughly 80% of data will be used for training, and 20% for test.
+
+ Data filter splits:
+ Assigns input data to training, validation, and test sets
+ based on the given filters; data pieces not matched by any
+ filter are ignored. Currently only supported for Datasets
+ containing DataItems.
+ If any of the filters in this message are to match nothing, then
+ they can be set as '-' (the minus sign).
+ If using filter splits, both ``training_filter_split`` and
+ ``test_filter_split`` must be provided; the validation filter is set to '-' internally.
+ Supported only for unstructured Datasets.
Args:
dataset (datasets.VideoDataset):
@@ -4283,25 +6532,51 @@ def _run(
For tabular Datasets, all their data is exported to
training, to pick and choose from.
training_fraction_split (float):
- Required. The fraction of the input data that is to be
- used to train the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
test_fraction_split (float):
- Required. The fraction of the input data that is to be
- used to evaluate the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ training_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to train the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ test_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to test the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
model_display_name (str):
Optional. The display name of the managed Vertex AI Model. The name
can be up to 128 characters long and can consist
of any UTF-8 characters. If a `base_model` was provided, the display_name in the
base_model will be overwritten with this value. If not provided upon
creation, the job's display_name is used.
+ model_labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
Returns:
model: The trained Vertex AI Model resource or None if training did not
- produce an Vertex AI Model.
+ produce a Vertex AI Model.
"""
# Retrieve the objective-specific training task schema based on prediction_type
@@ -4316,15 +6591,26 @@ def _run(
# gca Model to be trained
model_tbt = gca_model.Model(encryption_spec=self._model_encryption_spec)
model_tbt.display_name = model_display_name or self._display_name
+ model_tbt.labels = model_labels or self._labels
+
+ # AutoMLVideo does not support a validation split, so pass '-' (match nothing) when both other filter splits are provided.
+ validation_filter_split = (
+ "-"
+ if all([training_filter_split is not None, test_filter_split is not None])
+ else None
+ )
return self._run_job(
training_task_definition=training_task_definition,
training_task_inputs=training_task_inputs_dict,
dataset=dataset,
training_fraction_split=training_fraction_split,
- validation_fraction_split=0.0,
test_fraction_split=test_fraction_split,
+ training_filter_split=training_filter_split,
+ validation_filter_split=validation_filter_split,
+ test_filter_split=test_filter_split,
model=model_tbt,
+ create_request_timeout=create_request_timeout,
)
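
Because this job has no validation split, a filter-split run only takes the training and test filters; a minimal sketch, where the dataset variable and filter strings are illustrative rather than prescribed values:

    model = job.run(
        dataset=my_video_dataset,  # a datasets.VideoDataset
        training_filter_split='labels.aiplatform.googleapis.com/ml_use="training"',
        test_filter_split='labels.aiplatform.googleapis.com/ml_use="test"',
        # validation is handled internally: "-" (match nothing) is passed
        # on your behalf when both filters above are set, as shown above.
        model_display_name="video-action-model",
        model_labels={"team": "research"},
        create_request_timeout=60.0,
    )
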
@property
@@ -4345,6 +6631,7 @@ class AutoMLTextTrainingJob(_TrainingJob):
def __init__(
self,
+ # TODO(b/223262536): Make display_name parameter fully optional in next major release
display_name: str,
prediction_type: str,
multi_label: bool = False,
@@ -4352,6 +6639,7 @@ def __init__(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
+ labels: Optional[Dict[str, str]] = None,
training_encryption_spec_key_name: Optional[str] = None,
model_encryption_spec_key_name: Optional[str] = None,
):
@@ -4395,6 +6683,16 @@ def __init__(
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to run call training service. Overrides
credentials set in aiplatform.init.
+ labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize TrainingPipelines.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
training_encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the training pipeline. Has the
@@ -4421,11 +6719,14 @@ def __init__(
Overrides encryption_spec_key_name set in aiplatform.init.
"""
+ if not display_name:
+ display_name = self.__class__._generate_display_name()
super().__init__(
display_name=display_name,
project=project,
location=location,
credentials=credentials,
+ labels=labels,
training_encryption_spec_key_name=training_encryption_spec_key_name,
model_encryption_spec_key_name=model_encryption_spec_key_name,
)
@@ -4438,8 +6739,10 @@ def __init__(
schema.training_job.definition.automl_text_classification
)
- training_task_inputs_dict = training_job_inputs.AutoMlTextClassificationInputs(
- multi_label=multi_label
+ training_task_inputs_dict = (
+ training_job_inputs.AutoMlTextClassificationInputs(
+ multi_label=multi_label
+ )
)
elif prediction_type == "extraction":
training_task_definition = (
@@ -4466,20 +6769,37 @@ def __init__(
def run(
self,
dataset: datasets.TextDataset,
- training_fraction_split: float = 0.8,
- validation_fraction_split: float = 0.1,
- test_fraction_split: float = 0.1,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ training_filter_split: Optional[str] = None,
+ validation_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
) -> models.Model:
"""Runs the training job and returns a model.
- Data fraction splits:
- Any of ``training_fraction_split``, ``validation_fraction_split`` and
- ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If
- the provided ones sum to less than 1, the remainder is assigned to sets as
- decided by Vertex AI. If none of the fractions are set, by default roughly 80%
- of data will be used for training, 10% for validation, and 10% for test.
+ If training on a Vertex AI dataset, you can use one of the following split configurations:
+ Data fraction splits:
+ Any of ``training_fraction_split``, ``validation_fraction_split`` and
+ ``test_fraction_split`` may optionally be provided; they must sum to at most 1. If
+ the provided ones sum to less than 1, the remainder is assigned to sets as
+ decided by Vertex AI. If none of the fractions are set, by default roughly 80%
+ of data will be used for training, 10% for validation, and 10% for test.
+
+ Data filter splits:
+ Assigns input data to training, validation, and test sets
+ based on the given filters; data pieces not matched by any
+ filter are ignored. Currently only supported for Datasets
+ containing DataItems.
+ If any of the filters in this message are to match nothing, then
+ they can be set as '-' (the minus sign).
+ If using filter splits, all of ``training_filter_split``, ``validation_filter_split`` and
+ ``test_filter_split`` must be provided.
+ Supported only for unstructured Datasets.
Args:
dataset (datasets.TextDataset):
@@ -4488,25 +6808,58 @@ def run(
and what is compatible should be described in the used
TrainingPipeline's [training_task_definition]
[google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
- training_fraction_split: float = 0.8
- Required. The fraction of the input data that is to be
- used to train the Model. This is ignored if Dataset is not provided.
- validation_fraction_split: float = 0.1
- Required. The fraction of the input data that is to be
- used to validate the Model. This is ignored if Dataset is not provided.
- test_fraction_split: float = 0.1
- Required. The fraction of the input data that is to be
- used to evaluate the Model. This is ignored if Dataset is not provided.
+ training_fraction_split (float):
+ Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
+ validation_fraction_split (float):
+ Optional. The fraction of the input data that is to be used to validate
+ the Model. This is ignored if Dataset is not provided.
+ test_fraction_split (float):
+ Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ training_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to train the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ validation_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to validate the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ test_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to test the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
model_display_name (str):
Optional. The display name of the managed Vertex AI Model.
The name can be up to 128 characters long and can consist
of any UTF-8 characters.
If not provided upon creation, the job's display_name is used.
+ model_labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
Returns:
model: The trained Vertex AI Model resource.
@@ -4514,6 +6867,11 @@ def run(
RuntimeError: If Training job has already been run or is waiting to run.
"""
+ if model_display_name:
+ utils.validate_display_name(model_display_name)
+ if model_labels:
+ utils.validate_labels(model_labels)
+
if self._is_waiting_to_run():
raise RuntimeError("AutoML Text Training is already scheduled to run.")
@@ -4525,28 +6883,50 @@ def run(
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
+ training_filter_split=training_filter_split,
+ validation_filter_split=validation_filter_split,
+ test_filter_split=test_filter_split,
model_display_name=model_display_name,
+ model_labels=model_labels,
sync=sync,
+ create_request_timeout=create_request_timeout,
)
@base.optional_sync()
def _run(
self,
dataset: datasets.TextDataset,
- training_fraction_split: float = 0.8,
- validation_fraction_split: float = 0.1,
- test_fraction_split: float = 0.1,
+ training_fraction_split: Optional[float] = None,
+ validation_fraction_split: Optional[float] = None,
+ test_fraction_split: Optional[float] = None,
+ training_filter_split: Optional[str] = None,
+ validation_filter_split: Optional[str] = None,
+ test_filter_split: Optional[str] = None,
model_display_name: Optional[str] = None,
+ model_labels: Optional[Dict[str, str]] = None,
sync: bool = True,
+ create_request_timeout: Optional[float] = None,
) -> models.Model:
"""Runs the training job and returns a model.
- Data fraction splits:
- Any of ``training_fraction_split``, ``validation_fraction_split`` and
- ``test_fraction_split`` may optionally be provided, they must sum to up to 1. If
- the provided ones sum to less than 1, the remainder is assigned to sets as
- decided by Vertex AI. If none of the fractions are set, by default roughly 80%
- of data will be used for training, 10% for validation, and 10% for test.
+ If training on a Vertex AI dataset, you can use one of the following split configurations:
+ Data fraction splits:
+ Any of ``training_fraction_split``, ``validation_fraction_split`` and
+ ``test_fraction_split`` may optionally be provided; they must sum to at most 1. If
+ the provided ones sum to less than 1, the remainder is assigned to sets as
+ decided by Vertex AI. If none of the fractions are set, by default roughly 80%
+ of data will be used for training, 10% for validation, and 10% for test.
+
+ Data filter splits:
+ Assigns input data to training, validation, and test sets
+ based on the given filters; data pieces not matched by any
+ filter are ignored. Currently only supported for Datasets
+ containing DataItems.
+ If any of the filters in this message are to match nothing, then
+ they can be set as '-' (the minus sign).
+ If using filter splits, all of ``training_filter_split``, ``validation_filter_split`` and
+ ``test_filter_split`` must be provided.
+ Supported only for unstructured Datasets.
Args:
dataset (datasets.TextDataset):
@@ -4558,35 +6938,66 @@ def _run(
For Text Datasets, all their data is exported to
training, to pick and choose from.
training_fraction_split (float):
- Required. The fraction of the input data that is to be
- used to train the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to train
+ the Model. This is ignored if Dataset is not provided.
validation_fraction_split (float):
- Required. The fraction of the input data that is to be
- used to validate the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to validate
+ the Model. This is ignored if Dataset is not provided.
test_fraction_split (float):
- Required. The fraction of the input data that is to be
- used to evaluate the Model. This is ignored if Dataset is not provided.
+ Optional. The fraction of the input data that is to be used to evaluate
+ the Model. This is ignored if Dataset is not provided.
+ training_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to train the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ validation_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to validate the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
+ test_filter_split (str):
+ Optional. A filter on DataItems of the Dataset. DataItems that match
+ this filter are used to test the Model. A filter with the same syntax
+ as the one used in DatasetService.ListDataItems may be used. If a
+ single DataItem is matched by more than one of the FilterSplit filters,
+ then it is assigned to the first set that applies to it in the training,
+ validation, test order. This is ignored if Dataset is not provided.
model_display_name (str):
Optional. If the script produces a managed Vertex AI Model. The display name of
the Model. The name can be up to 128 characters long and can consist
of any UTF-8 characters.
If not provided upon creation, the job's display_name is used.
+ model_labels (Dict[str, str]):
+ Optional. The labels with user-defined metadata to
+ organize your Models.
+ Label keys and values can be no longer than 64
+ characters (Unicode codepoints), can only
+ contain lowercase letters, numeric characters,
+ underscores and dashes. International characters
+ are allowed.
+ See https://goo.gl/xmQnxf for more information
+ and examples of labels.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
+ create_request_timeout (float):
+ Optional. The timeout for the create request in seconds.
Returns:
model: The trained Vertex AI Model resource or None if training did not
- produce an Vertex AI Model.
+ produce a Vertex AI Model.
"""
- if model_display_name is None:
- model_display_name = self._display_name
-
model = gca_model.Model(
- display_name=model_display_name,
+ display_name=model_display_name or self._display_name,
+ labels=model_labels or self._labels,
encryption_spec=self._model_encryption_spec,
)
@@ -4597,8 +7008,11 @@ def _run(
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
- predefined_split_column_name=None,
+ training_filter_split=training_filter_split,
+ validation_filter_split=validation_filter_split,
+ test_filter_split=test_filter_split,
model=model,
+ create_request_timeout=create_request_timeout,
)
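
A comparable sketch for the text job, exercising fraction splits and the new model_labels and create_request_timeout arguments; names and the dataset variable are placeholders:

    job = aiplatform.AutoMLTextTrainingJob(
        display_name="text-clf-job",
        prediction_type="classification",
        multi_label=False,
        labels={"team": "nlp"},
    )
    model = job.run(
        dataset=my_text_dataset,  # a datasets.TextDataset
        training_fraction_split=0.8,
        validation_fraction_split=0.1,
        test_fraction_split=0.1,
        model_display_name="text-clf-model",
        model_labels={"team": "nlp"},
        create_request_timeout=60.0,
    )
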
@property
diff --git a/google/cloud/aiplatform/training_utils/__init__.py b/google/cloud/aiplatform/training_utils/__init__.py
new file mode 100644
index 0000000000..0e973c9a40
--- /dev/null
+++ b/google/cloud/aiplatform/training_utils/__init__.py
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/google/cloud/aiplatform/training_utils/cloud_profiler/README.rst b/google/cloud/aiplatform/training_utils/cloud_profiler/README.rst
new file mode 100644
index 0000000000..6c6cfc1af9
--- /dev/null
+++ b/google/cloud/aiplatform/training_utils/cloud_profiler/README.rst
@@ -0,0 +1,20 @@
+Cloud Profiler
+=================================
+
+Cloud Profiler allows you to profile your remote Vertex AI Training jobs on demand and visualize the results in Vertex Tensorboard.
+
+Quick Start
+------------
+
+To start using the profiler with TensorFlow, update your training script to include the following:
+
+.. code-block:: Python
+
+ from google.cloud.aiplatform.training_utils import cloud_profiler
+ ...
+ cloud_profiler.init()
+
+
+Next, run the job with a Vertex TensorBoard instance. For full details on how to do this, visit https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-overview
+
+Finally, visit your TensorBoard in the Google Cloud Console, navigate to the "Profile" tab, and click the `Capture Profile` button. This lets you capture profiling statistics for your running jobs.
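
For context, one plausible way to launch such a job from the SDK so the profiler has a Tensorboard to upload to; the resource names are placeholders, and `worker_pool_specs` is assumed to be defined elsewhere:

    from google.cloud import aiplatform

    aiplatform.init(project="my-project", location="us-central1")

    job = aiplatform.CustomJob(
        display_name="profiled-job",
        worker_pool_specs=worker_pool_specs,
    )
    job.run(
        tensorboard="projects/123/locations/us-central1/tensorboards/456",
        service_account="trainer@my-project.iam.gserviceaccount.com",
    )
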
diff --git a/google/cloud/aiplatform/training_utils/cloud_profiler/__init__.py b/google/cloud/aiplatform/training_utils/cloud_profiler/__init__.py
new file mode 100644
index 0000000000..1b0c5eb925
--- /dev/null
+++ b/google/cloud/aiplatform/training_utils/cloud_profiler/__init__.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from google.cloud.aiplatform.training_utils.cloud_profiler import initializer
+
+"""
+Initialize the cloud profiler for tensorflow.
+
+Usage:
+from google.cloud.aiplatform.training_utils import cloud_profiler
+
+cloud_profiler.init(plugin='tensorflow')
+"""
+
+init = initializer.initialize
diff --git a/google/cloud/aiplatform/training_utils/cloud_profiler/cloud_profiler_utils.py b/google/cloud/aiplatform/training_utils/cloud_profiler/cloud_profiler_utils.py
new file mode 100644
index 0000000000..f7f6e8d8f6
--- /dev/null
+++ b/google/cloud/aiplatform/training_utils/cloud_profiler/cloud_profiler_utils.py
@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import_error_msg = (
+ "Could not load the cloud profiler. To use the profiler, "
+ "install the SDK using 'pip install google-cloud-aiplatform[cloud-profiler]'"
+)
diff --git a/google/cloud/aiplatform/training_utils/cloud_profiler/initializer.py b/google/cloud/aiplatform/training_utils/cloud_profiler/initializer.py
new file mode 100644
index 0000000000..7abc815078
--- /dev/null
+++ b/google/cloud/aiplatform/training_utils/cloud_profiler/initializer.py
@@ -0,0 +1,130 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import threading
+from typing import Optional, Type
+
+from google.cloud.aiplatform.training_utils.cloud_profiler import cloud_profiler_utils
+
+try:
+ from werkzeug import serving
+except ImportError as err:
+ raise ImportError(cloud_profiler_utils.import_error_msg) from err
+
+
+from google.cloud.aiplatform.training_utils import environment_variables
+from google.cloud.aiplatform.training_utils.cloud_profiler import webserver
+from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import base_plugin
+from google.cloud.aiplatform.training_utils.cloud_profiler.plugins.tensorflow import (
+ tf_profiler,
+)
+
+
+# Mapping of available plugins to use
+_AVAILABLE_PLUGINS = {"tensorflow": tf_profiler.TFProfiler}
+
+
+class MissingEnvironmentVariableException(Exception):
+ pass
+
+
+def _build_plugin(
+ plugin: Type[base_plugin.BasePlugin],
+) -> Optional[base_plugin.BasePlugin]:
+ """Builds the plugin given the object.
+
+ Args:
+ plugin (Type[base_plugin]):
+ Required. An uninitialized plugin class.
+
+ Returns:
+ An initialized plugin, or None if the plugin cannot be
+ initialized.
+ """
+ if not plugin.can_initialize():
+ logging.warning("Cannot initialize the plugin")
+ return
+
+ plugin.setup()
+
+ if not plugin.post_setup_check():
+ return
+
+ return plugin()
+
+
+def _run_app_thread(server: webserver.WebServer, port: int):
+ """Run the webserver in a separate thread.
+
+ Args:
+ server (webserver.WebServer):
+ Required. A webserver to accept requests.
+ port (int):
+ Required. The port to run the webserver on.
+ """
+ daemon = threading.Thread(
+ name="profile_server",
+ target=serving.run_simple,
+ args=(
+ "0.0.0.0",
+ port,
+ server,
+ ),
+ )
+ daemon.daemon = True
+ daemon.start()
+
+
+def initialize(plugin: str = "tensorflow"):
+ """Initializes the profiling SDK.
+
+ Args:
+ plugin (str):
+ Required. Name of the plugin to initialize.
+ Current options are ["tensorflow"]
+
+ Raises:
+ ValueError:
+ The plugin does not exist.
+ MissingEnvironmentVariableException:
+ An environment variable that is needed is not set.
+ """
+ plugin_obj = _AVAILABLE_PLUGINS.get(plugin)
+
+ if not plugin_obj:
+ raise ValueError(
+ "Plugin {} not available, must choose from {}".format(
+ plugin, _AVAILABLE_PLUGINS.keys()
+ )
+ )
+
+ prof_plugin = _build_plugin(plugin_obj)
+
+ if prof_plugin is None:
+ return
+
+ server = webserver.WebServer([prof_plugin])
+
+ if not environment_variables.http_handler_port:
+ raise MissingEnvironmentVariableException(
+ "'AIP_HTTP_HANDLER_PORT' must be set."
+ )
+
+ port = int(environment_variables.http_handler_port)
+
+ _run_app_thread(server, port)
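
A sketch of how a training script might call this initializer defensively, since profiling is an optional add-on and init() can raise on misconfiguration:

    from google.cloud.aiplatform.training_utils import cloud_profiler

    try:
        cloud_profiler.init(plugin="tensorflow")
    except Exception as e:  # e.g. ValueError, MissingEnvironmentVariableException
        print(f"Profiler not started, continuing without it: {e}")

    # ... build and train the model as usual ...
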
diff --git a/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/base_plugin.py b/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/base_plugin.py
new file mode 100644
index 0000000000..67b6b40ae9
--- /dev/null
+++ b/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/base_plugin.py
@@ -0,0 +1,71 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import abc
+from typing import Callable, Dict
+from werkzeug import Response
+
+
+class BasePlugin(abc.ABC):
+ """Base plugin for cloud training tools endpoints.
+
+ The plugins support registering http handlers to be used for
+ AI Platform training jobs.
+ """
+
+ @staticmethod
+ @abc.abstractmethod
+ def setup() -> None:
+ """Run any setup code for the plugin before webserver is launched."""
+ raise NotImplementedError
+
+ @staticmethod
+ @abc.abstractmethod
+ def can_initialize() -> bool:
+ """Check whether a plugin is able to be initialized.
+
+ Used for checking whether the correct dependencies are installed, system requirements are met, etc.
+
+ Returns:
+ Bool indicating whether the plugin can be initialized.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ @abc.abstractmethod
+ def post_setup_check() -> bool:
+ """Check if after initialization, we need to use the plugin.
+
+ Example: Web server only needs to run for main node for training, others
+ just need to have 'setup()' run to start the rpc server.
+
+ Returns:
+ A boolean indicating whether post setup checks pass.
+ """
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def get_routes(self) -> Dict[str, Callable[..., Response]]:
+ """Get the mapping from path to handler.
+
+ This is the method in which plugins can assign different routes to
+ different handlers.
+
+ Returns:
+ A mapping from a route to a handler.
+ """
+ raise NotImplementedError
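
To make the contract concrete, a minimal hypothetical plugin; the handler signature assumes the WSGI-style dispatch of the webserver module, which is not shown in this diff:

    from typing import Callable, Dict
    from werkzeug import Response

    class HealthPlugin(BasePlugin):
        """Toy plugin exposing a single health-check route."""

        @staticmethod
        def setup() -> None:
            pass  # nothing to start before the webserver launches

        @staticmethod
        def can_initialize() -> bool:
            return True  # no extra dependencies to verify

        @staticmethod
        def post_setup_check() -> bool:
            return True  # every node may serve this route

        def health(self, environ, start_response) -> Response:
            return Response("ok", status=200)

        def get_routes(self) -> Dict[str, Callable[..., Response]]:
            return {"/health": self.health}
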
diff --git a/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tensorboard_api.py b/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tensorboard_api.py
new file mode 100644
index 0000000000..c74d4cfa55
--- /dev/null
+++ b/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tensorboard_api.py
@@ -0,0 +1,194 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Helpers for creating a profile request sender for tf profiler plugin."""
+
+import os
+import re
+from typing import Tuple
+
+from tensorboard.uploader import upload_tracker
+from tensorboard.uploader import util
+from tensorboard.uploader.proto import server_info_pb2
+from tensorboard.util import tb_logging
+
+from google.api_core import exceptions
+from google.cloud import aiplatform
+from google.cloud import storage
+from google.cloud.aiplatform.utils import TensorboardClientWithOverride
+from google.cloud.aiplatform.tensorboard import uploader_utils
+from google.cloud.aiplatform.compat.types import tensorboard_experiment
+from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import profile_uploader
+from google.cloud.aiplatform import training_utils
+
+logger = tb_logging.get_logger()
+
+
+def _get_api_client() -> TensorboardClientWithOverride:
+ """Creates an Tensorboard API client."""
+ m = re.match(
+ "projects/.*/locations/(.*)/tensorboards/.*",
+ training_utils.environment_variables.tensorboard_resource_name,
+ )
+ region = m[1]
+
+ api_client = aiplatform.initializer.global_config.create_client(
+ client_class=TensorboardClientWithOverride,
+ location_override=region,
+ api_base_path_override=training_utils.environment_variables.tensorboard_api_uri,
+ )
+
+ return api_client
+
+
+def _get_project_id() -> str:
+ """Gets the project id from the tensorboard resource name.
+
+ Returns:
+ Project ID for current project.
+
+ Raises:
+ ValueError: Cannot parse the tensorboard resource name.
+ """
+ m = re.match(
+ "projects/(.*)/locations/.*/tensorboards/.*",
+ training_utils.environment_variables.tensorboard_resource_name,
+ )
+ if not m:
+ raise ValueError(
+ "Incorrect format for tensorboard resource name: %s",
+ training_utils.environment_variables.tensorboard_resource_name,
+ )
+ return m[1]
+
+
+def _make_upload_limits() -> server_info_pb2.UploadLimits:
+ """Creates the upload limits for tensorboard.
+
+ Returns:
+ An UploadLimits object.
+ """
+ upload_limits = server_info_pb2.UploadLimits()
+ upload_limits.min_blob_request_interval = 10
+ upload_limits.max_blob_request_size = 4 * (2**20) - 256 * (2**10)
+ upload_limits.max_blob_size = 10 * (2**30) # 10GiB
+
+ return upload_limits
+
+
+def _get_blob_items(
+ api_client: TensorboardClientWithOverride,
+) -> Tuple[storage.bucket.Bucket, str]:
+ """Gets the blob storage items for the tensorboard resource.
+
+ Args:
+ api_client (TensorboardClientWithOverride):
+ Required. Client used to get information about the tensorboard instance.
+
+ Returns:
+ A tuple of storage buckets and the blob storage folder name.
+ """
+ project_id = _get_project_id()
+ tensorboard = api_client.get_tensorboard(
+ name=training_utils.environment_variables.tensorboard_resource_name
+ )
+
+ path_prefix = tensorboard.blob_storage_path_prefix + "/"
+ first_slash_index = path_prefix.find("/")
+ bucket_name = path_prefix[:first_slash_index]
+ blob_storage_bucket = storage.Client(project=project_id).bucket(bucket_name)
+ blob_storage_folder = path_prefix[first_slash_index + 1 :]
+
+ return blob_storage_bucket, blob_storage_folder
+
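A worked example of the slicing above, with a hypothetical path prefix:

    path_prefix = "my-tb-bucket/tensorboards/12345" + "/"
    first_slash_index = path_prefix.find("/")
    assert path_prefix[:first_slash_index] == "my-tb-bucket"             # bucket name
    assert path_prefix[first_slash_index + 1:] == "tensorboards/12345/"  # folder keeps trailing slash
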
+
+def _get_or_create_experiment(
+ api: TensorboardClientWithOverride, experiment_name: str
+) -> str:
+ """Creates a tensorboard experiment.
+
+ Args:
+ api (TensorboardClientWithOverride):
+ Required. An api for interfacing with tensorboard resources.
+ experiment_name (str):
+ Required. The name of the experiment to get or create.
+
+ Returns:
+ The name of the experiment.
+ """
+ tb_experiment = tensorboard_experiment.TensorboardExperiment()
+
+ try:
+ experiment = api.create_tensorboard_experiment(
+ parent=training_utils.environment_variables.tensorboard_resource_name,
+ tensorboard_experiment=tb_experiment,
+ tensorboard_experiment_id=experiment_name,
+ )
+ except exceptions.AlreadyExists:
+ logger.info("Creating experiment failed. Retrieving experiment.")
+ experiment_name = os.path.join(
+ training_utils.environment_variables.tensorboard_resource_name,
+ "experiments",
+ experiment_name,
+ )
+ experiment = api.get_tensorboard_experiment(name=experiment_name)
+
+ return experiment.name
+
+
+def create_profile_request_sender() -> profile_uploader.ProfileRequestSender:
+ """Creates the `ProfileRequestSender` for the profile plugin.
+
+ A profile request sender is created for the plugin so that after profiling runs
+ have finished, data can be uploaded to the tensorboard backend.
+
+ Returns:
+ A ProfileRequestSender object.
+ """
+ api_client = _get_api_client()
+
+ experiment_name = _get_or_create_experiment(
+ api_client, training_utils.environment_variables.cloud_ml_job_id
+ )
+
+ upload_limits = _make_upload_limits()
+
+ blob_rpc_rate_limiter = util.RateLimiter(
+ upload_limits.min_blob_request_interval / 100
+ )
+
+ blob_storage_bucket, blob_storage_folder = _get_blob_items(
+ api_client,
+ )
+
+ source_bucket = uploader_utils.get_source_bucket(
+ training_utils.environment_variables.tensorboard_log_dir
+ )
+
+ profile_request_sender = profile_uploader.ProfileRequestSender(
+ experiment_name,
+ api_client,
+ upload_limits=upload_limits,
+ blob_rpc_rate_limiter=blob_rpc_rate_limiter,
+ blob_storage_bucket=blob_storage_bucket,
+ blob_storage_folder=blob_storage_folder,
+ source_bucket=source_bucket,
+ tracker=upload_tracker.UploadTracker(verbosity=1),
+ logdir=training_utils.environment_variables.tensorboard_log_dir,
+ )
+
+ return profile_request_sender
diff --git a/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tf_profiler.py b/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tf_profiler.py
new file mode 100644
index 0000000000..81b43145b3
--- /dev/null
+++ b/google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tf_profiler.py
@@ -0,0 +1,352 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A plugin to handle remote tensoflow profiler sessions for Vertex AI."""
+
+from google.cloud.aiplatform.training_utils.cloud_profiler import cloud_profiler_utils
+
+try:
+ import tensorflow as tf
+ from tensorboard_plugin_profile.profile_plugin import ProfilePlugin
+except ImportError as err:
+ raise ImportError(cloud_profiler_utils.import_error_msg) from err
+
+import argparse
+from collections import namedtuple
+import importlib.util
+import json
+import logging
+from typing import Callable, Dict, Optional
+from urllib import parse
+
+import tensorboard.plugins.base_plugin as tensorboard_base_plugin
+from werkzeug import Response
+
+from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import profile_uploader
+from google.cloud.aiplatform.training_utils import environment_variables
+from google.cloud.aiplatform.training_utils.cloud_profiler import wsgi_types
+from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import base_plugin
+from google.cloud.aiplatform.training_utils.cloud_profiler.plugins.tensorflow import (
+ tensorboard_api,
+)
+
+
+# TF version information.
+Version = namedtuple("Version", ["major", "minor", "patch"])
+
+logger = logging.Logger("tf-profiler")
+
+_BASE_TB_ENV_WARNING = (
+ "To set this environment variable, run your training with the 'tensorboard' "
+ "option. For more information on how to run with training with tensorboard, visit "
+ "https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training"
+)
+
+
+def _get_tf_versioning() -> Optional[Version]:
+ """Convert version string to a Version namedtuple for ease of parsing.
+
+ Returns:
+ A version object if finding the version was successful, None otherwise.
+ """
+ version = tf.__version__
+
+ versioning = version.split(".")
+ if len(versioning) != 3:
+ return
+
+ return Version(int(versioning[0]), int(versioning[1]), int(versioning[2]))
+
+
+def _is_compatible_version(version: Version) -> bool:
+ """Check if version is compatible with tf profiling.
+
+ The profiling plugin can be used with tensorflow versions >= 2.4.0.
+ While the profiler itself ships with versions >= 2.2.0, additional dependencies
+ that are only included in >= 2.4.0 are needed by tensorboard-plugin-profile.
+
+ Profiler:
+ https://www.tensorflow.org/guide/profiler
+ Required commit for tensorboard-plugin-profile:
+ https://github.com/tensorflow/tensorflow/commit/8b9c207242db515daef033e74d69ea5d8e023dc6
+
+ Args:
+ version (Version):
+ Required. `Version` of tensorflow.
+
+ Returns:
+ Bool indicating whether the version is compatible with the profiler.
+ """
+ # Accept 2.4+ within the 2.x series, and any later major version.
+ return version.major > 2 or (version.major == 2 and version.minor >= 4)
+
+
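A few sanity checks for the predicate; these mirror the docstring's 2.4.0 floor and assume the corrected comparison above:

    assert _is_compatible_version(Version(2, 4, 0))      # first fully supported release
    assert not _is_compatible_version(Version(2, 3, 1))  # plugin dependencies missing
    assert _is_compatible_version(Version(3, 0, 0))      # later majors remain allowed
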
+def _check_tf() -> bool:
+ """Check whether all the tensorflow prereqs are met.
+
+ Returns:
+ True if all requirements met, False otherwise.
+ """
+ # Check tf is installed
+ if importlib.util.find_spec("tensorflow") is None:
+ logger.warning("Tensorflow not installed, cannot initialize profiling plugin")
+ return False
+
+ # Check tensorflow version
+ version = _get_tf_versioning()
+ if version is None:
+ logger.warning(
+ "Could not find major, minor, and patch versions of tensorflow. Version found: %s",
+ version,
+ )
+ return False
+
+ # Check compatibility; the profiling plugin requires tensorflow >= 2.4.0
+ if not _is_compatible_version(version):
+ logger.warning(
+ "Version %s is incompatible with tf profiler."
+ "To use the profiler, choose a version >= 2.2.0",
+ "%s.%s.%s" % (version.major, version.minor, version.patch),
+ )
+ return False
+
+ # Check for the tf profiler plugin
+ if importlib.util.find_spec("tensorboard_plugin_profile") is None:
+ logger.warning(
+ "Could not import tensorboard_plugin_profile, will not run tf profiling service"
+ )
+ return False
+
+ return True
+
+
+def _create_profiling_context() -> tensorboard_base_plugin.TBContext:
+ """Creates the base context needed for TB Profiler.
+
+ Returns:
+ An initialized `TBContext`.
+ """
+
+ context_flags = argparse.Namespace(master_tpu_unsecure_channel=None)
+
+ context = tensorboard_base_plugin.TBContext(
+ logdir=environment_variables.tensorboard_log_dir,
+ multiplexer=None,
+ flags=context_flags,
+ )
+
+ return context
+
+
+def _host_to_grpc(hostname: str) -> str:
+ """Format a hostname to a grpc address.
+
+ Args:
+ hostname (str):
+ Required. Address in form: `{hostname}:{port}`
+
+ Returns:
+ Address in the form 'grpc://{hostname}:{port}', where the port is taken from AIP_TF_PROFILER_PORT.
+ """
+ return (
+ "grpc://"
+ + "".join(hostname.split(":")[:-1])
+ + ":"
+ + environment_variables.tf_profiler_port
+ )
+
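+# Illustrative (editor's sketch): with AIP_TF_PROFILER_PORT="6009",
+# _host_to_grpc("workerpool0-0:2222") returns "grpc://workerpool0-0:6009".
+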
+
+def _get_hostnames() -> Optional[str]:
+ """Get the hostnames for all servers running.
+
+ Returns:
+ A comma-separated string of hosts formatted by `_host_to_grpc` if
+ obtaining the cluster spec is successful, None otherwise.
+ """
+ cluster_spec = environment_variables.cluster_spec
+ if cluster_spec is None:
+ return None
+
+ cluster = cluster_spec.get("cluster", "")
+ if not cluster:
+ return None
+
+ hostnames = []
+ for value in cluster.values():
+ hostnames.extend(value)
+
+ return ",".join([_host_to_grpc(x) for x in hostnames])
+
+
+def _update_environ(environ: wsgi_types.Environment) -> bool:
+ """Add parameters to the query that are retrieved from training side.
+
+ Args:
+ environ (wsgi_types.Environment):
+ Required. The WSGI Environment.
+
+ Returns:
+ Whether the environment was successfully updated.
+ """
+ hosts = _get_hostnames()
+
+ if hosts is None:
+ return False
+
+ query_dict = {}
+ query_dict["service_addr"] = hosts
+
+ # Update service address and worker list
+ # Use parse_qsl and then convert list to dictionary so we can update
+ # attributes
+ prev_query_string = dict(parse.parse_qsl(environ["QUERY_STRING"]))
+ prev_query_string.update(query_dict)
+
+ environ["QUERY_STRING"] = parse.urlencode(prev_query_string)
+
+ return True
+
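+# Illustrative (editor's sketch): a QUERY_STRING of "duration_ms=2000" is
+# rewritten to include the worker list, e.g.
+# "duration_ms=2000&service_addr=grpc%3A%2F%2Fworkerpool0-0%3A6009".
+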
+
+def warn_tensorboard_env_var(var_name: str):
+ """Warns if a tensorboard related environment variable is missing.
+
+ Args:
+ var_name (str):
+ Required. The name of the missing environment variable.
+ """
+ logger.warning(
+ f"Environment variable `{var_name}` must be set. " + _BASE_TB_ENV_WARNING
+ )
+
+
+def _check_env_vars() -> bool:
+ """Determine whether the correct environment variables are set.
+
+ Returns:
+ bool indicating all necessary variables are set.
+ """
+ # The below are tensorboard specific environment variables.
+ if environment_variables.tf_profiler_port is None:
+ warn_tensorboard_env_var("AIP_TF_PROFILER_PORT")
+ return False
+
+ if environment_variables.tensorboard_log_dir is None:
+ warn_tensorboard_env_var("AIP_TENSORBOARD_LOG_DIR")
+ return False
+
+ if environment_variables.tensorboard_api_uri is None:
+ warn_tensorboard_env_var("AIP_TENSORBOARD_API_URI")
+ return False
+
+ if environment_variables.tensorboard_resource_name is None:
+ warn_tensorboard_env_var("AIP_TENSORBOARD_RESOURCE_NAME")
+ return False
+
+ # These environment variables are not tensorboard related, they are
+ # variables set for any Vertex training run.
+ cluster_spec = environment_variables.cluster_spec
+ if cluster_spec is None:
+ logger.warning("Environment variable `CLUSTER_SPEC` is not set.")
+ return False
+
+ if environment_variables.cloud_ml_job_id is None:
+ logger.warning("Environment variable `CLOUD_ML_JOB_ID` is not set")
+ return False
+
+ return True
+
+
+class TFProfiler(base_plugin.BasePlugin):
+ """Handler for Tensorflow Profiling."""
+
+ PLUGIN_NAME = "profile"
+
+ def __init__(self):
+ """Build a TFProfiler object."""
+ context = _create_profiling_context()
+ self._profile_request_sender: profile_uploader.ProfileRequestSender = (
+ tensorboard_api.create_profile_request_sender()
+ )
+ self._profile_plugin: ProfilePlugin = ProfilePlugin(context)
+
+ def get_routes(
+ self,
+ ) -> Dict[str, Callable[[Dict[str, str], Callable[..., None]], Response]]:
+ """List of routes to serve.
+
+ Returns:
+ A callable that takes an werkzeug env and start response and returns a response.
+ """
+ return {"/capture_profile": self.capture_profile_wrapper}
+
+ # Define routes below
+ def capture_profile_wrapper(
+ self, environ: wsgi_types.Environment, start_response: wsgi_types.StartResponse
+ ) -> Response:
+ """Take a request from tensorboard.gcp and run the profiling for the available servers.
+
+ Args:
+ environ (wsgi_types.Environment):
+ Required. The WSGI environment.
+ start_response (wsgi_types.StartResponse):
+ Required. The response callable provided by the WSGI server.
+
+ Returns:
+ A response iterable.
+ """
+ # The service address (localhost) and worker list are populated locally
+ if not _update_environ(environ):
+ err = {"error": "Could not parse the environ: %s"}
+ return Response(
+ json.dumps(err), content_type="application/json", status=500
+ )
+
+ response = self._profile_plugin.capture_route(environ, start_response)
+
+ self._profile_request_sender.send_request("")
+
+ return response
+
+ # End routes
+
+ @staticmethod
+ def setup() -> None:
+ """Sets up the plugin."""
+ tf.profiler.experimental.server.start(
+ int(environment_variables.tf_profiler_port)
+ )
+
+ @staticmethod
+ def post_setup_check() -> bool:
+ """Only chief and task 0 should run the webserver."""
+ cluster_spec = environment_variables.cluster_spec
+ task_type = cluster_spec.get("task", {}).get("type", "")
+ task_index = cluster_spec.get("task", {}).get("index", -1)
+
+ return task_type in {"workerpool0", "chief"} and task_index == 0
+
+ @staticmethod
+ def can_initialize() -> bool:
+ """Check that we can use the TF Profiler plugin.
+
+ This function checks a number of dependencies for the plugin to ensure we have the
+ right packages installed, the necessary versions, and the correct environment variables set.
+
+ Returns:
+ True if can initialize, False otherwise.
+ """
+
+ return _check_env_vars() and _check_tf()
diff --git a/google/cloud/aiplatform/training_utils/cloud_profiler/webserver.py b/google/cloud/aiplatform/training_utils/cloud_profiler/webserver.py
new file mode 100644
index 0000000000..3f7706bb34
--- /dev/null
+++ b/google/cloud/aiplatform/training_utils/cloud_profiler/webserver.py
@@ -0,0 +1,114 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A basic webserver for hosting plugin routes."""
+
+import os
+
+from google.cloud.aiplatform.training_utils.cloud_profiler import wsgi_types
+from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import base_plugin
+from typing import List
+from werkzeug import wrappers, Response
+
+
+class WebServer:
+ """A basic web server for handling requests."""
+
+ def __init__(self, plugins: List[base_plugin.BasePlugin]):
+ """Creates a web server to host plugin routes.
+
+ Args:
+ plugins (List[base_plugin.BasePlugin]):
+ Required. A list of `BasePlugin` objects.
+
+ Raises:
+ ValueError:
+ When there is an invalid route passed from
+ one of the plugins.
+ """
+
+ self._plugins = plugins
+ self._routes = {}
+
+ # Routes are in the form "/{plugin_name}{route}"
+ for plugin in self._plugins:
+ for route, handler in plugin.get_routes().items():
+ if not route.startswith("/"):
+ raise ValueError(
+ 'Routes should start with a "/", '
+ "invalid route for plugin %s, route %s"
+ % (plugin.PLUGIN_NAME, route)
+ )
+
+ app_route = os.path.join("/", plugin.PLUGIN_NAME)
+
+ app_route += route
+ self._routes[app_route] = handler
+
+ def dispatch_request(
+ self, environ: wsgi_types.Environment, start_response: wsgi_types.StartResponse
+ ) -> Response:
+ """Handles the routing of requests.
+
+ Args:
+ environ (wsgi_types.Environment):
+ Required. The WSGI environment.
+ start_response (wsgi_types.StartResponse):
+ Required. The response callable provided by the WSGI server.
+
+ Returns:
+ A response iterable.
+ """
+ # Check for existence of the route
+ request = wrappers.Request(environ)
+
+ if request.path in self._routes:
+ return self._routes[request.path](environ, start_response)
+
+ response = wrappers.Response("Not Found", status=404)
+ return response(environ, start_response)
+
+ def wsgi_app(
+ self, environ: wsgi_types.Environment, start_response: wsgi_types.StartResponse
+ ) -> Response:
+ """Entrypoint for wsgi application.
+
+ Args:
+ environ (wsgi_types.Environment):
+ Required. The WSGI environment.
+ start_response (wsgi_types.StartResponse):
+ Required. The response callable provided by the WSGI server.
+
+ Returns:
+ A response iterable.
+ """
+ response = self.dispatch_request(environ, start_response)
+ return response
+
+ def __call__(self, environ, start_response):
+ """Entrypoint for wsgi application.
+
+ Args:
+ environ (wsgi_types.Environment):
+ Required. The WSGI environment.
+ start_response (wsgi_types.StartResponse):
+ Required. The response callable provided by the WSGI server.
+
+ Returns:
+ A response iterable.
+ """
+ return self.wsgi_app(environ, start_response)
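+
+
+# Illustrative usage (editor's sketch; the plugin instance and host/port are
+# hypothetical):
+#
+#   from werkzeug.serving import run_simple
+#
+#   server = WebServer([my_plugin])
+#   run_simple("0.0.0.0", 6010, server)  # serves /{plugin_name}{route}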
diff --git a/google/cloud/aiplatform/training_utils/cloud_profiler/wsgi_types.py b/google/cloud/aiplatform/training_utils/cloud_profiler/wsgi_types.py
new file mode 100644
index 0000000000..0348c5b91e
--- /dev/null
+++ b/google/cloud/aiplatform/training_utils/cloud_profiler/wsgi_types.py
@@ -0,0 +1,28 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Typing description for the WSGI App callables
+# For more information on WSGI, see PEP 3333
+
+from typing import Any, Dict, Text, Callable
+
+# Contains the CGI environment variables, as defined by the Common Gateway
+# Interface specification.
+Environment = Dict[Text, Any]
+
+# Used to begin the HTTP response.
+StartResponse = Callable[..., Callable[[bytes], None]]
diff --git a/google/cloud/aiplatform/training_utils/environment_variables.py b/google/cloud/aiplatform/training_utils/environment_variables.py
new file mode 100644
index 0000000000..0783e78251
--- /dev/null
+++ b/google/cloud/aiplatform/training_utils/environment_variables.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Environment variables used in Vertex AI Training.
+
+import json
+import os
+
+from typing import Dict, Optional
+
+
+def _json_helper(env_var: str) -> Optional[Dict]:
+ """Helper to convert a dictionary represented as a string to a dictionary.
+
+ Args:
+ env_var (str):
+ Required. The name of the environment variable.
+
+ Returns:
+ A dictionary if the variable was found, None otherwise.
+ """
+ env = os.environ.get(env_var)
+ if env is not None:
+ return json.loads(env)
+ else:
+ return None
+
+
+# Cloud Storage URI of a directory intended for training data.
+training_data_uri = os.environ.get("AIP_TRAINING_DATA_URI")
+
+# Cloud Storage URI of a directory intended for validation data.
+validation_data_uri = os.environ.get("AIP_VALIDATION_DATA_URI")
+
+# Cloud Storage URI of a directory intended for test data.
+test_data_uri = os.environ.get("AIP_TEST_DATA_URI")
+
+# Cloud Storage URI of a directory intended for saving model artifacts.
+model_dir = os.environ.get("AIP_MODEL_DIR")
+
+# Cloud Storage URI of a directory intended for saving checkpoints.
+checkpoint_dir = os.environ.get("AIP_CHECKPOINT_DIR")
+
+# Cloud Storage URI of a directory intended for saving TensorBoard logs.
+tensorboard_log_dir = os.environ.get("AIP_TENSORBOARD_LOG_DIR")
+
+# json string as described in https://cloud.google.com/ai-platform-unified/docs/training/distributed-training#cluster-variables
+cluster_spec = _json_helper("CLUSTER_SPEC")
+
+# json string as described in https://cloud.google.com/ai-platform-unified/docs/training/distributed-training#tf-config
+tf_config = _json_helper("TF_CONFIG")
+
+# Profiler port used for capturing profiling samples.
+tf_profiler_port = os.environ.get("AIP_TF_PROFILER_PORT")
+
+# API URI used for the tensorboard uploader.
+tensorboard_api_uri = os.environ.get("AIP_TENSORBOARD_API_URI")
+
+# The name of the tensorboard resource, in the form:
+# `projects/{project_id}/locations/{location}/tensorboards/{tensorboard_name}`
+tensorboard_resource_name = os.environ.get("AIP_TENSORBOARD_RESOURCE_NAME")
+
+# The name given to the training job.
+cloud_ml_job_id = os.environ.get("CLOUD_ML_JOB_ID")
+
+# The HTTP Handler port to use to host the profiling webserver.
+http_handler_port = os.environ.get("AIP_HTTP_HANDLER_PORT")
diff --git a/google/cloud/aiplatform/utils/__init__.py b/google/cloud/aiplatform/utils/__init__.py
index 4404defb21..9ec2f27779 100644
--- a/google/cloud/aiplatform/utils/__init__.py
+++ b/google/cloud/aiplatform/utils/__init__.py
@@ -19,10 +19,11 @@
import abc
import datetime
import pathlib
-from collections import namedtuple
import logging
import re
-from typing import Any, Match, Optional, Type, TypeVar, Tuple
+from typing import Any, Callable, Dict, Optional, Type, TypeVar, Tuple
+
+from google.protobuf import timestamp_pb2
from google.api_core import client_options
from google.api_core import gapic_v1
@@ -30,26 +31,36 @@
from google.cloud import storage
from google.cloud.aiplatform import compat
-from google.cloud.aiplatform import constants
+from google.cloud.aiplatform.constants import base as constants
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.compat.services import (
dataset_service_client_v1beta1,
endpoint_service_client_v1beta1,
+ featurestore_online_serving_service_client_v1beta1,
+ featurestore_service_client_v1beta1,
+ index_service_client_v1beta1,
+ index_endpoint_service_client_v1beta1,
job_service_client_v1beta1,
+ metadata_service_client_v1beta1,
model_service_client_v1beta1,
pipeline_service_client_v1beta1,
prediction_service_client_v1beta1,
- metadata_service_client_v1beta1,
tensorboard_service_client_v1beta1,
)
from google.cloud.aiplatform.compat.services import (
dataset_service_client_v1,
endpoint_service_client_v1,
+ featurestore_online_serving_service_client_v1,
+ featurestore_service_client_v1,
+ index_service_client_v1,
+ index_endpoint_service_client_v1,
job_service_client_v1,
+ metadata_service_client_v1,
model_service_client_v1,
pipeline_service_client_v1,
prediction_service_client_v1,
+ tensorboard_service_client_v1,
)
from google.cloud.aiplatform.compat.types import (
@@ -61,85 +72,55 @@
# v1beta1
dataset_service_client_v1beta1.DatasetServiceClient,
endpoint_service_client_v1beta1.EndpointServiceClient,
+ featurestore_online_serving_service_client_v1beta1.FeaturestoreOnlineServingServiceClient,
+ featurestore_service_client_v1beta1.FeaturestoreServiceClient,
+ index_service_client_v1beta1.IndexServiceClient,
+ index_endpoint_service_client_v1beta1.IndexEndpointServiceClient,
model_service_client_v1beta1.ModelServiceClient,
prediction_service_client_v1beta1.PredictionServiceClient,
pipeline_service_client_v1beta1.PipelineServiceClient,
job_service_client_v1beta1.JobServiceClient,
metadata_service_client_v1beta1.MetadataServiceClient,
+ tensorboard_service_client_v1beta1.TensorboardServiceClient,
# v1
dataset_service_client_v1.DatasetServiceClient,
endpoint_service_client_v1.EndpointServiceClient,
+ featurestore_online_serving_service_client_v1.FeaturestoreOnlineServingServiceClient,
+ featurestore_service_client_v1.FeaturestoreServiceClient,
+ metadata_service_client_v1.MetadataServiceClient,
model_service_client_v1.ModelServiceClient,
prediction_service_client_v1.PredictionServiceClient,
pipeline_service_client_v1.PipelineServiceClient,
job_service_client_v1.JobServiceClient,
+ tensorboard_service_client_v1.TensorboardServiceClient,
)
-RESOURCE_NAME_PATTERN = re.compile(
- r"^projects\/(?P[\w-]+)\/locations\/(?P[\w-]+)\/(?P[\w\-\/]+)\/(?P[\w-]+)$"
-)
-RESOURCE_ID_PATTERN = re.compile(r"^[\w-]+$")
-
-Fields = namedtuple("Fields", ["project", "location", "resource", "id"],)
-
-def _match_to_fields(match: Match) -> Optional[Fields]:
- """Normalize RegEx groups from resource name pattern Match to class
- Fields."""
- if not match:
- return None
-
- return Fields(
- project=match["project"],
- location=match["location"],
- resource=match["resource"],
- id=match["id"],
- )
-
-
-def validate_id(resource_id: str) -> bool:
- """Validate int64 resource ID number."""
- return bool(RESOURCE_ID_PATTERN.match(resource_id))
+RESOURCE_ID_PATTERN = re.compile(r"^[\w-]+$")
-def extract_fields_from_resource_name(
- resource_name: str, resource_noun: Optional[str] = None
-) -> Optional[Fields]:
- """Validates and returns extracted fields from a fully-qualified resource
- name. Returns None if name is invalid.
+def validate_id(resource_id: str):
+ """Validate resource ID.
Args:
- resource_name (str):
- Required. A fully-qualified Vertex AI resource name
+ resource_id (str): Required. The resource ID to validate.
+ Raises:
+ ValueError: If the resource ID is not in a valid format.
- resource_noun (str):
- A resource noun to validate the resource name against.
- For example, you would pass "datasets" to validate
- "projects/123/locations/us-central1/datasets/456".
- In the case of deeper naming structures, e.g.,
- "projects/123/locations/us-central1/metadataStores/123/contexts/456",
- you would pass "metadataStores/123/contexts" as the resource_noun.
- Returns:
- fields (Fields):
- A named tuple containing four extracted fields from a resource name:
- project, location, resource, and id. These fields can be used for
- subsequent method calls in the SDK.
"""
- fields = _match_to_fields(RESOURCE_NAME_PATTERN.match(resource_name))
-
- if not fields:
- return None
- if resource_noun and fields.resource != resource_noun:
- return None
-
- return fields
+ if not RESOURCE_ID_PATTERN.match(resource_id):
+ raise ValueError(f"Resource {resource_id} is not a valid resource id.")
def full_resource_name(
resource_name: str,
resource_noun: str,
+ parse_resource_name_method: Callable[[str], Dict[str, str]],
+ format_resource_name_method: Callable[..., str],
+ parent_resource_name_fields: Optional[Dict[str, str]] = None,
project: Optional[str] = None,
location: Optional[str] = None,
+ resource_id_validator: Optional[Callable[[str], None]] = None,
) -> str:
"""Returns fully qualified resource name.
@@ -148,85 +129,92 @@ def full_resource_name(
Required. A fully-qualified Vertex AI resource name or
resource ID.
resource_noun (str):
- A resource noun to validate the resource name against.
+ Required. A resource noun to validate the resource name against.
For example, you would pass "datasets" to validate
"projects/123/locations/us-central1/datasets/456".
- In the case of deeper naming structures, e.g.,
- "projects/123/locations/us-central1/metadataStores/123/contexts/456",
- you would pass "metadataStores/123/contexts" as the resource_noun.
+ parse_resource_name_method (Callable[[str], Dict[str,str]]):
+ Required. Method that parses a resource name into its segment parts.
+ These are generally included with GAPIC clients.
+ format_resource_name_method (Callable[..., str]):
+ Required. Method that takes segment parts of resource names and returns
+ the formatted resource name. These are generally included with GAPIC clients.
+ parent_resource_name_fields (Dict[str, str]):
+ Optional. Dictionary of segment parts where key is the resource noun and
+ values are the resource ids.
+ For example:
+ {
+ "metadataStores": "123"
+ }
project (str):
- Optional project to retrieve resource_noun from. If not set, project
+ Optional. Project to retrieve resource_noun from. If not set, project
set in aiplatform.init will be used.
location (str):
- Optional location to retrieve resource_noun from. If not set, location
+ Optional. Location to retrieve resource_noun from. If not set, location
set in aiplatform.init will be used.
+ resource_id_validator (Callable[[str], None]):
+ Optional. Function that validates the resource ID. Overrides the default validator, validate_id.
+ Should take a resource ID as string and raise ValueError if invalid.
Returns:
resource_name (str):
A fully-qualified Vertex AI resource name.
-
- Raises:
- ValueError:
- If resource name, resource ID or project ID not provided.
"""
- validate_resource_noun(resource_noun)
# Fully qualified resource name, e.g., "projects/.../locations/.../datasets/12345" or
# "projects/.../locations/.../metadataStores/.../contexts/12345"
- valid_name = extract_fields_from_resource_name(
- resource_name=resource_name, resource_noun=resource_noun
- )
+ fields = parse_resource_name_method(resource_name)
+ if fields:
+ return resource_name
+
+ resource_id_validator = resource_id_validator or validate_id
user_project = project or initializer.global_config.project
user_location = location or initializer.global_config.location
- # Partial resource name (i.e. "12345") with known project and location
- if (
- not valid_name
- and validate_project(user_project)
- and validate_region(user_location)
- and validate_id(resource_name)
- ):
- resource_name = f"projects/{user_project}/locations/{user_location}/{resource_noun}/{resource_name}"
- # Invalid resource_name parameter
- elif not valid_name:
- raise ValueError(f"Please provide a valid {resource_noun[:-1]} name or ID")
-
- return resource_name
+ validate_region(user_location)
+ resource_id_validator(resource_name)
+
+ format_args = {
+ "location": user_location,
+ "project": user_project,
+ convert_camel_case_resource_noun_to_snake_case(resource_noun): resource_name,
+ }
+
+ if parent_resource_name_fields:
+ format_args.update(
+ {
+ convert_camel_case_resource_noun_to_snake_case(key): value
+ for key, value in parent_resource_name_fields.items()
+ }
+ )
+ return format_resource_name_method(**format_args)
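+
+# Illustrative usage (editor's sketch, borrowing the v1 Dataset GAPIC client
+# imported above):
+#
+#   full_resource_name(
+#       resource_name="456",
+#       resource_noun="datasets",
+#       parse_resource_name_method=dataset_service_client_v1.DatasetServiceClient.parse_dataset_path,
+#       format_resource_name_method=dataset_service_client_v1.DatasetServiceClient.dataset_path,
+#       project="my-project",
+#       location="us-central1",
+#   )
+#   # -> "projects/my-project/locations/us-central1/datasets/456"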
-# TODO(b/172286889) validate resource noun
-def validate_resource_noun(resource_noun: str) -> bool:
- """Validates resource noun.
- Args:
- resource_noun: resource noun to validate
- Returns:
- bool: True if no errors raised
- Raises:
- ValueError: If resource noun not supported.
- """
- if resource_noun:
- return True
- raise ValueError("Please provide a valid resource noun")
+# Resource nouns that are not plural in their resource names.
+# Used below to avoid plural-to-singular conversion.
+_SINGULAR_RESOURCE_NOUNS = {"time_series"}
+_SINGULAR_RESOURCE_NOUNS_MAP = {"indexes": "index"}
-# TODO(b/172288287) validate project
-def validate_project(project: str) -> bool:
- """Validates project.
+def convert_camel_case_resource_noun_to_snake_case(resource_noun: str) -> str:
+ """Converts camel case to snake case to map resource name parts to GAPIC parameter names.
Args:
- project: project to validate
+ resource_noun (str): The resource noun in camel case to convert.
Returns:
- bool: True if no errors raised
- Raises:
- ValueError: If project does not exist.
+ Singular snake case resource noun.
"""
- if project:
- return True
- raise ValueError("Please provide a valid project ID")
+ snake_case = re.sub("([A-Z]+)", r"_\1", resource_noun).lower()
+
+ # plural to singular
+ if snake_case in _SINGULAR_RESOURCE_NOUNS or not snake_case.endswith("s"):
+ return snake_case
+ elif snake_case in _SINGULAR_RESOURCE_NOUNS_MAP:
+ return _SINGULAR_RESOURCE_NOUNS_MAP[snake_case]
+ else:
+ return snake_case[:-1]
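+
+# Illustrative conversions (editor's sketch):
+#
+#   convert_camel_case_resource_noun_to_snake_case("metadataStores")  # "metadata_store"
+#   convert_camel_case_resource_noun_to_snake_case("timeSeries")      # "time_series"
+#   convert_camel_case_resource_noun_to_snake_case("indexes")         # "index"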
-# TODO(b/172932277) verify display name only contains utf-8 chars
def validate_display_name(display_name: str):
"""Verify display name is at most 128 chars.
@@ -239,6 +227,22 @@ def validate_display_name(display_name: str):
raise ValueError("Display name needs to be less than 128 characters.")
+def validate_labels(labels: Dict[str, str]):
+ """Validate labels.
+
+ Args:
+ labels: labels to verify
+ Raises:
+ ValueError: if labels is not a mapping of string key value pairs.
+ """
+ for k, v in labels.items():
+ if not isinstance(k, str) or not isinstance(v, str):
+ raise ValueError(
+ "Expect labels to be a mapping of string key value pairs. "
+ 'Got "{}".'.format(labels)
+ )
+
+
def validate_region(region: str) -> bool:
"""Validates region against supported regions.
@@ -373,6 +377,16 @@ def _default_version(self) -> str:
def _version_map(self) -> Tuple:
pass
+ @property
+ def api_endpoint(self) -> str:
+ """Default API endpoint used by this client."""
+ client = self._clients[self._default_version]
+
+ if self._is_temporary:
+ return client._client_options.api_endpoint
+ else:
+ return client._transport._host.split(":")[0]
+
def __init__(
self,
client_options: client_options.ClientOptions,
@@ -413,6 +427,22 @@ def __getattr__(self, name: str) -> Any:
def select_version(self, version: str) -> VertexAiServiceClient:
return self._clients[version]
+ @classmethod
+ def get_gapic_client_class(
+ cls, version: Optional[str] = None
+ ) -> Type[VertexAiServiceClient]:
+ """Gets the underyilng GAPIC client.
+
+ Used to access class and static methods without instantiating.
+
+ Args:
+ version (str):
+ Optional. Version of client to retrieve; otherwise, the default version is returned.
+ Returns:
+ Underlying GAPIC client for this wrapper and version.
+ """
+ return dict(cls._version_map)[version or cls._default_version]
+
class DatasetClientWithOverride(ClientWithOverride):
_is_temporary = True
@@ -432,6 +462,51 @@ class EndpointClientWithOverride(ClientWithOverride):
)
+class IndexClientWithOverride(ClientWithOverride):
+ _is_temporary = True
+ _default_version = compat.DEFAULT_VERSION
+ _version_map = (
+ (compat.V1, index_service_client_v1.IndexServiceClient),
+ (compat.V1BETA1, index_service_client_v1beta1.IndexServiceClient),
+ )
+
+
+class IndexEndpointClientWithOverride(ClientWithOverride):
+ _is_temporary = True
+ _default_version = compat.DEFAULT_VERSION
+ _version_map = (
+ (compat.V1, index_endpoint_service_client_v1.IndexEndpointServiceClient),
+ (
+ compat.V1BETA1,
+ index_endpoint_service_client_v1beta1.IndexEndpointServiceClient,
+ ),
+ )
+
+
+class FeaturestoreClientWithOverride(ClientWithOverride):
+ _is_temporary = True
+ _default_version = compat.DEFAULT_VERSION
+ _version_map = (
+ (compat.V1, featurestore_service_client_v1.FeaturestoreServiceClient),
+ (compat.V1BETA1, featurestore_service_client_v1beta1.FeaturestoreServiceClient),
+ )
+
+
+class FeaturestoreOnlineServingClientWithOverride(ClientWithOverride):
+ _is_temporary = False
+ _default_version = compat.DEFAULT_VERSION
+ _version_map = (
+ (
+ compat.V1,
+ featurestore_online_serving_service_client_v1.FeaturestoreOnlineServingServiceClient,
+ ),
+ (
+ compat.V1BETA1,
+ featurestore_online_serving_service_client_v1beta1.FeaturestoreOnlineServingServiceClient,
+ ),
+ )
+
+
class JobClientWithOverride(ClientWithOverride):
_is_temporary = True
_default_version = compat.DEFAULT_VERSION
@@ -459,6 +534,15 @@ class PipelineClientWithOverride(ClientWithOverride):
)
+class PipelineJobClientWithOverride(ClientWithOverride):
+ _is_temporary = True
+ _default_version = compat.DEFAULT_VERSION
+ _version_map = (
+ (compat.V1, pipeline_service_client_v1.PipelineServiceClient),
+ (compat.V1BETA1, pipeline_service_client_v1beta1.PipelineServiceClient),
+ )
+
+
class PredictionClientWithOverride(ClientWithOverride):
_is_temporary = False
_default_version = compat.DEFAULT_VERSION
@@ -470,16 +554,18 @@ class PredictionClientWithOverride(ClientWithOverride):
class MetadataClientWithOverride(ClientWithOverride):
_is_temporary = True
- _default_version = compat.V1BETA1
+ _default_version = compat.DEFAULT_VERSION
_version_map = (
+ (compat.V1, metadata_service_client_v1.MetadataServiceClient),
(compat.V1BETA1, metadata_service_client_v1beta1.MetadataServiceClient),
)
class TensorboardClientWithOverride(ClientWithOverride):
_is_temporary = False
- _default_version = compat.V1BETA1
+ _default_version = compat.DEFAULT_VERSION
_version_map = (
+ (compat.V1, tensorboard_service_client_v1.TensorboardServiceClient),
(compat.V1BETA1, tensorboard_service_client_v1beta1.TensorboardServiceClient),
)
@@ -488,9 +574,11 @@ class TensorboardClientWithOverride(ClientWithOverride):
"VertexAiServiceClientWithOverride",
DatasetClientWithOverride,
EndpointClientWithOverride,
+ FeaturestoreClientWithOverride,
JobClientWithOverride,
ModelClientWithOverride,
PipelineClientWithOverride,
+ PipelineJobClientWithOverride,
PredictionClientWithOverride,
MetadataClientWithOverride,
TensorboardClientWithOverride,
@@ -566,3 +654,25 @@ def _timestamped_copy_to_gcs(
gcs_path = "".join(["gs://", "/".join([blob.bucket.name, blob.name])])
return gcs_path
+
+
+def get_timestamp_proto(
+ time: Optional[datetime.datetime] = None,
+) -> timestamp_pb2.Timestamp:
+ """Gets timestamp proto of a given time.
+ Args:
+ time (datetime.datetime):
+ Optional. A user-provided time. Defaults to datetime.datetime.now() if not given.
+ Returns:
+ timestamp_pb2.Timestamp: Timestamp proto of the given time, truncated to millisecond precision.
+ """
+ if not time:
+ time = datetime.datetime.now()
+
+ time_str = time.isoformat(sep=" ", timespec="milliseconds")
+ time = datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S.%f")
+
+ timestamp_proto = timestamp_pb2.Timestamp()
+ timestamp_proto.FromDatetime(time)
+
+ return timestamp_proto
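+
+# Illustrative (editor's sketch):
+#
+#   get_timestamp_proto(datetime.datetime(2021, 5, 4, 12, 0, 0, 123456))
+#   # -> Timestamp for 2021-05-04 12:00:00.123 (microseconds truncated
+#   #    to millisecond precision)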
diff --git a/google/cloud/aiplatform/utils/column_transformations_utils.py b/google/cloud/aiplatform/utils/column_transformations_utils.py
new file mode 100644
index 0000000000..fe7c16983c
--- /dev/null
+++ b/google/cloud/aiplatform/utils/column_transformations_utils.py
@@ -0,0 +1,111 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Dict, List, Optional, Tuple
+import warnings
+
+from google.cloud.aiplatform import datasets
+
+
+def get_default_column_transformations(
+ dataset: datasets._ColumnNamesDataset,
+ target_column: str,
+) -> Tuple[List[Dict[str, Dict[str, str]]], List[str]]:
+ """Get default column transformations from the column names, while omitting the target column.
+
+ Args:
+ dataset (_ColumnNamesDataset):
+ Required. The dataset
+ target_column (str):
+ Required. The name of the column values of which the Model is to predict.
+
+ Returns:
+ Tuple[List[Dict[str, Dict[str, str]]], List[str]]:
+ The default column transformations and the default column names.
+ """
+
+ column_names = [
+ column_name
+ for column_name in dataset.column_names
+ if column_name != target_column
+ ]
+ column_transformations = [
+ {"auto": {"column_name": column_name}} for column_name in column_names
+ ]
+
+ return (column_transformations, column_names)
+
+
+def validate_and_get_column_transformations(
+ column_specs: Optional[Dict[str, str]],
+ column_transformations: Optional[List[Dict[str, Dict[str, str]]]],
+) -> Optional[List[Dict[str, Dict[str, str]]]]:
+ """Validates column specs and transformations, then returns processed transformations.
+
+ Args:
+ column_specs (Dict[str, str]):
+ Optional. Alternative to column_transformations where the keys of the dict
+ are column names and their respective values are one of
+ AutoMLTabularTrainingJob.column_data_types.
+ When creating transformation for BigQuery Struct column, the column
+ should be flattened using "." as the delimiter. Only columns with no child
+ should have a transformation.
+ If an input column has no transformations on it, such a column is
+ ignored by the training, except for the targetColumn, which should have
+ no transformations defined on.
+ Only one of column_transformations or column_specs should be passed.
+ column_transformations (List[Dict[str, Dict[str, str]]]):
+ Optional. Transformations to apply to the input columns (i.e. columns other
+ than the targetColumn). Each transformation may produce multiple
+ result values from the column's value, and all are used for training.
+ When creating transformation for BigQuery Struct column, the column
+ should be flattened using "." as the delimiter. Only columns with no child
+ should have a transformation.
+ If an input column has no transformations on it, such a column is
+ ignored by the training, except for the targetColumn, which should have
+ no transformations defined on.
+ Only one of column_transformations or column_specs should be passed.
+ Consider using column_specs as column_transformations will be deprecated eventually.
+
+ Returns:
+ Optional[List[Dict[str, Dict[str, str]]]]:
+ The column transformations, or None if neither argument was provided.
+
+ Raises:
+ ValueError: If both column_transformations and column_specs were provided.
+ """
+ # user populated transformations
+ if column_transformations is not None and column_specs is not None:
+ raise ValueError(
+ "Both column_transformations and column_specs were passed. Only one is allowed."
+ )
+ if column_transformations is not None:
+ warnings.simplefilter("always", DeprecationWarning)
+ warnings.warn(
+ "consider using column_specs instead. column_transformations will be deprecated in the future.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return column_transformations
+ elif column_specs is not None:
+ return [
+ {transformation: {"column_name": column_name}}
+ for column_name, transformation in column_specs.items()
+ ]
+ else:
+ return None
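+
+# Illustrative usage (editor's sketch; column names and types are hypothetical):
+#
+#   validate_and_get_column_transformations(
+#       column_specs={"age": "numeric", "city": "categorical"},
+#       column_transformations=None,
+#   )
+#   # -> [{"numeric": {"column_name": "age"}},
+#   #     {"categorical": {"column_name": "city"}}]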
diff --git a/google/cloud/aiplatform/utils/console_utils.py b/google/cloud/aiplatform/utils/console_utils.py
new file mode 100644
index 0000000000..c108b0605e
--- /dev/null
+++ b/google/cloud/aiplatform/utils/console_utils.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from google.cloud.aiplatform import jobs
+from google.cloud.aiplatform import tensorboard
+
+
+def custom_job_console_uri(custom_job_resource_name: str) -> str:
+ """Helper method to create console uri from custom job resource name."""
+ fields = jobs.CustomJob._parse_resource_name(custom_job_resource_name)
+ return f"https://console.cloud.google.com/ai/platform/locations/{fields['location']}/training/{fields['custom_job']}?project={fields['project']}"
+
+
+def custom_job_tensorboard_console_uri(
+ tensorboard_resource_name: str, custom_job_resource_name: str
+) -> str:
+ """Helper method to create console uri to tensorboard from custom job resource."""
+ # Example encoded experiment resource name: projects+40556267596+locations+us-central1+tensorboards+740208820004847616+experiments+2214368039829241856
+ fields = tensorboard.Tensorboard._parse_resource_name(tensorboard_resource_name)
+ experiment_resource_name = f"{tensorboard_resource_name}/experiments/{custom_job_resource_name.split('/')[-1]}"
+ uri_experiment_resource_name = experiment_resource_name.replace("/", "+")
+ return f"https://{fields['location']}.tensorboard.googleusercontent.com/experiment/{uri_experiment_resource_name}"
diff --git a/google/cloud/aiplatform/utils/enhanced_library/__init__.py b/google/cloud/aiplatform/utils/enhanced_library/__init__.py
new file mode 100644
index 0000000000..7e1ec16ec8
--- /dev/null
+++ b/google/cloud/aiplatform/utils/enhanced_library/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/google/cloud/aiplatform/helpers/_decorators.py b/google/cloud/aiplatform/utils/enhanced_library/_decorators.py
similarity index 97%
rename from google/cloud/aiplatform/helpers/_decorators.py
rename to google/cloud/aiplatform/utils/enhanced_library/_decorators.py
index 95aac31c4f..43e395393b 100644
--- a/google/cloud/aiplatform/helpers/_decorators.py
+++ b/google/cloud/aiplatform/utils/enhanced_library/_decorators.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
-from google.cloud.aiplatform.helpers import value_converter
+from google.cloud.aiplatform.utils.enhanced_library import value_converter
from proto.marshal import Marshal
from proto.marshal.rules.struct import ValueRule
diff --git a/google/cloud/aiplatform/helpers/value_converter.py b/google/cloud/aiplatform/utils/enhanced_library/value_converter.py
similarity index 100%
rename from google/cloud/aiplatform/helpers/value_converter.py
rename to google/cloud/aiplatform/utils/enhanced_library/value_converter.py
diff --git a/google/cloud/aiplatform/utils/featurestore_utils.py b/google/cloud/aiplatform/utils/featurestore_utils.py
new file mode 100644
index 0000000000..b57824e15f
--- /dev/null
+++ b/google/cloud/aiplatform/utils/featurestore_utils.py
@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import re
+from typing import Dict, NamedTuple, Optional
+
+from google.cloud.aiplatform.compat.services import featurestore_service_client
+from google.cloud.aiplatform.compat.types import (
+ feature as gca_feature,
+ featurestore_service as gca_featurestore_service,
+)
+from google.cloud.aiplatform import utils
+
+CompatFeaturestoreServiceClient = featurestore_service_client.FeaturestoreServiceClient
+
+RESOURCE_ID_PATTERN_REGEX = r"[a-z_][a-z0-9_]{0,59}"
+GCS_SOURCE_TYPE = {"csv", "avro"}
+GCS_DESTINATION_TYPE = {"csv", "tfrecord"}
+
+_FEATURE_VALUE_TYPE_UNSPECIFIED = "VALUE_TYPE_UNSPECIFIED"
+
+FEATURE_STORE_VALUE_TYPE_TO_BQ_DATA_TYPE_MAP = {
+ "BOOL": {"field_type": "BOOL"},
+ "BOOL_ARRAY": {"field_type": "BOOL", "mode": "REPEATED"},
+ "DOUBLE": {"field_type": "FLOAT64"},
+ "DOUBLE_ARRAY": {"field_type": "FLOAT64", "mode": "REPEATED"},
+ "INT64": {"field_type": "INT64"},
+ "INT64_ARRAY": {"field_type": "INT64", "mode": "REPEATED"},
+ "STRING": {"field_type": "STRING"},
+ "STRING_ARRAY": {"field_type": "STRING", "mode": "REPEATED"},
+ "BYTES": {"field_type": "BYTES"},
+}
+
+
+def validate_id(resource_id: str) -> None:
+ """Validates feature store resource ID pattern.
+
+ Args:
+ resource_id (str):
+ Required. Feature Store resource ID.
+
+ Raises:
+ ValueError if resource_id is invalid.
+ """
+ if not re.compile(r"^" + RESOURCE_ID_PATTERN_REGEX + r"$").match(resource_id):
+ raise ValueError("Resource ID {resource_id} is not a valied resource id.")
+
+
+def validate_feature_id(feature_id: str) -> None:
+ """Validates feature ID.
+
+ Args:
+ feature_id (str):
+ Required. Feature resource ID.
+
+ Raises:
+ ValueError if feature_id is invalid.
+ """
+ match = re.compile(r"^" + RESOURCE_ID_PATTERN_REGEX + r"$").match(feature_id)
+
+ if not match:
+ raise ValueError(
+ f"The value of feature_id may be up to 60 characters, and valid characters are `[a-z0-9_]`. "
+ f"The first character cannot be a number. Instead, get {feature_id}."
+ )
+
+ reserved_words = ["entity_id", "feature_timestamp", "arrival_timestamp"]
+ if feature_id.lower() in reserved_words:
+ raise ValueError(
+ "The feature_id can not be any of the reserved_words: `%s`"
+ % ("`, `".join(reserved_words))
+ )
+
+
+def validate_value_type(value_type: str) -> None:
+ """Validates user provided feature value_type string.
+
+ Args:
+ value_type (str):
+ Required. Immutable. Type of Feature value.
+ One of BOOL, BOOL_ARRAY, DOUBLE, DOUBLE_ARRAY, INT64, INT64_ARRAY, STRING, STRING_ARRAY, BYTES.
+
+ Raises:
+ ValueError if value_type is invalid or unspecified.
+ """
+ if getattr(gca_feature.Feature.ValueType, value_type, None) in (
+ gca_feature.Feature.ValueType.VALUE_TYPE_UNSPECIFIED,
+ None,
+ ):
+ raise ValueError(
+ f"Given value_type `{value_type}` invalid or unspecified. "
+ f"Choose one of {gca_feature.Feature.ValueType._member_names_} except `{_FEATURE_VALUE_TYPE_UNSPECIFIED}`"
+ )
+
+
+class _FeatureConfig(NamedTuple):
+ """Configuration for feature creation.
+
+ Usage:
+
+ config = _FeatureConfig(
+ feature_id='my_feature_id',
+ value_type='INT64',
+ description='my description',
+ labels={'my_key': 'my_value'},
+ )
+ """
+
+ feature_id: str
+ value_type: str = _FEATURE_VALUE_TYPE_UNSPECIFIED
+ description: Optional[str] = None
+ labels: Optional[Dict[str, str]] = None
+
+ def _get_feature_id(self) -> str:
+ """Validates and returns the feature_id.
+
+ Returns:
+ str - valid feature ID.
+
+ Raises:
+ ValueError: If feature_id is invalid.
+ """
+
+ # Raises ValueError if invalid feature_id
+ validate_feature_id(feature_id=self.feature_id)
+
+ return self.feature_id
+
+ def _get_value_type_enum(self) -> int:
+ """Validates value_type and returns the enum of the value type.
+
+ Returns:
+ int - valid value type enum.
+ """
+
+ # Raises ValueError if invalid value_type
+ validate_value_type(value_type=self.value_type)
+
+ value_type_enum = getattr(gca_feature.Feature.ValueType, self.value_type)
+
+ return value_type_enum
+
+ def get_create_feature_request(
+ self,
+ ) -> gca_featurestore_service.CreateFeatureRequest:
+ """Return create feature request."""
+
+ gapic_feature = gca_feature.Feature(
+ value_type=self._get_value_type_enum(),
+ )
+
+ if self.labels:
+ utils.validate_labels(self.labels)
+ gapic_feature.labels = self.labels
+
+ if self.description:
+ gapic_feature.description = self.description
+
+ create_feature_request = gca_featurestore_service.CreateFeatureRequest(
+ feature=gapic_feature, feature_id=self._get_feature_id()
+ )
+
+ return create_feature_request
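+
+# Illustrative (editor's sketch): turning the config from the class docstring
+# into a GAPIC request:
+#
+#   request = _FeatureConfig(
+#       feature_id="my_feature_id", value_type="INT64"
+#   ).get_create_feature_request()
+#   # request.feature.value_type == gca_feature.Feature.ValueType.INT64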
diff --git a/google/cloud/aiplatform/utils/gcs_utils.py b/google/cloud/aiplatform/utils/gcs_utils.py
new file mode 100644
index 0000000000..855b7991f1
--- /dev/null
+++ b/google/cloud/aiplatform/utils/gcs_utils.py
@@ -0,0 +1,165 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import datetime
+import glob
+import logging
+import pathlib
+from typing import Optional
+
+from google.auth import credentials as auth_credentials
+from google.cloud import storage
+
+from google.cloud.aiplatform import initializer
+
+
+_logger = logging.getLogger(__name__)
+
+
+def upload_to_gcs(
+ source_path: str,
+ destination_uri: str,
+ project: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+):
+ """Uploads local files to GCS.
+
+ After upload the `destination_uri` will contain the same data as the `source_path`.
+
+ Args:
+ source_path: Required. Path of the local data to copy to GCS.
+ destination_uri: Required. GCS URI where the data should be uploaded.
+ project: Optional. Google Cloud Project that contains the staging bucket.
+ credentials: The custom credentials to use when making API calls.
+ If not provided, default credentials will be used.
+
+ Raises:
+ RuntimeError: When source_path does not exist.
+ GoogleCloudError: When the upload process fails.
+ """
+ source_path_obj = pathlib.Path(source_path)
+ if not source_path_obj.exists():
+ raise RuntimeError(f"Source path does not exist: {source_path}")
+
+ project = project or initializer.global_config.project
+ credentials = credentials or initializer.global_config.credentials
+
+ storage_client = storage.Client(project=project, credentials=credentials)
+ if source_path_obj.is_dir():
+ source_file_paths = glob.glob(
+ pathname=str(source_path_obj / "**"), recursive=True
+ )
+ for source_file_path in source_file_paths:
+ source_file_path_obj = pathlib.Path(source_file_path)
+ if source_file_path_obj.is_dir():
+ continue
+ source_file_relative_path_obj = source_file_path_obj.relative_to(
+ source_path_obj
+ )
+ source_file_relative_posix_path = source_file_relative_path_obj.as_posix()
+ destination_file_uri = (
+ destination_uri.rstrip("/") + "/" + source_file_relative_posix_path
+ )
+ _logger.debug(f'Uploading "{source_file_path}" to "{destination_file_uri}"')
+ destination_blob = storage.Blob.from_string(
+ destination_file_uri, client=storage_client
+ )
+ destination_blob.upload_from_filename(filename=source_file_path)
+ else:
+ source_file_path = source_path
+ destination_file_uri = destination_uri
+ _logger.debug(f'Uploading "{source_file_path}" to "{destination_file_uri}"')
+ destination_blob = storage.Blob.from_string(
+ destination_file_uri, client=storage_client
+ )
+ destination_blob.upload_from_filename(filename=source_file_path)
+
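+# Illustrative usage (editor's sketch; bucket and paths are hypothetical):
+#
+#   upload_to_gcs("./model", "gs://my-bucket/model")
+#   # copies every file under ./model to gs://my-bucket/model/, preserving
+#   # relative paths.
+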
+
+def stage_local_data_in_gcs(
+ data_path: str,
+ staging_gcs_dir: Optional[str] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+) -> str:
+ """Stages a local data in GCS.
+
+ The data is copied into a timestamped "vertex_ai_auto_staging/{timestamp}"
+ subdirectory of the staging location, keeping the local file name.
+
+ Args:
+ data_path: Required. Path of the local data to copy to GCS.
+ staging_gcs_dir:
+ Optional. Google Cloud Storage bucket to be used for data staging.
+ project: Optional. Google Cloud Project that contains the staging bucket.
+ location: Optional. Google Cloud location to use for the staging bucket.
+ credentials: The custom credentials to use when making API calls.
+ If not provided, default credentials will be used.
+
+ Returns:
+ Google Cloud Storage URI of the staged data.
+
+ Raises:
+ RuntimeError: When source_path does not exist.
+ GoogleCloudError: When the upload process fails.
+ """
+ data_path_obj = pathlib.Path(data_path)
+
+ if not data_path_obj.exists():
+ raise RuntimeError(f"Local data does not exist: data_path='{data_path}'")
+
+ staging_gcs_dir = staging_gcs_dir or initializer.global_config.staging_bucket
+ if not staging_gcs_dir:
+ project = project or initializer.global_config.project
+ location = location or initializer.global_config.location
+ credentials = credentials or initializer.global_config.credentials
+ # Creating the bucket if it does not exist.
+ # Currently we only do this when staging_gcs_dir is not specified.
+ # The buckets that we create are regional.
+ # This prevents errors when a service requires a regional bucket.
+ # E.g. "FailedPrecondition: 400 The Cloud Storage bucket of `gs://...` is in location `us`. It must be in the same regional location as the service location `us-central1`."
+ # We are making the bucket name region-specific since the bucket is regional.
+ staging_bucket_name = project + "-vertex-staging-" + location
+ client = storage.Client(project=project, credentials=credentials)
+ staging_bucket = storage.Bucket(client=client, name=staging_bucket_name)
+ if not staging_bucket.exists():
+ _logger.info(f'Creating staging GCS bucket "{staging_bucket_name}"')
+ staging_bucket = client.create_bucket(
+ bucket_or_name=staging_bucket,
+ project=project,
+ location=location,
+ )
+ staging_gcs_dir = "gs://" + staging_bucket_name
+
+ timestamp = datetime.datetime.now().isoformat(sep="-", timespec="milliseconds")
+ staging_gcs_subdir = (
+ staging_gcs_dir.rstrip("/") + "/vertex_ai_auto_staging/" + timestamp
+ )
+
+ staged_data_uri = staging_gcs_subdir
+ if data_path_obj.is_file():
+ staged_data_uri = staging_gcs_subdir + "/" + data_path_obj.name
+
+ _logger.info(f'Uploading "{data_path}" to "{staged_data_uri}"')
+ upload_to_gcs(
+ source_path=data_path,
+ destination_uri=staged_data_uri,
+ project=project,
+ credentials=credentials,
+ )
+
+ return staged_data_uri
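+
+# Illustrative usage (editor's sketch; bucket and path are hypothetical):
+#
+#   uri = stage_local_data_in_gcs(
+#       data_path="./data/train.csv",
+#       staging_gcs_dir="gs://my-staging-bucket",
+#   )
+#   # uri == "gs://my-staging-bucket/vertex_ai_auto_staging/<timestamp>/train.csv"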
diff --git a/google/cloud/aiplatform/utils/pipeline_utils.py b/google/cloud/aiplatform/utils/pipeline_utils.py
new file mode 100644
index 0000000000..f988cc307e
--- /dev/null
+++ b/google/cloud/aiplatform/utils/pipeline_utils.py
@@ -0,0 +1,256 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+import json
+from typing import Any, Dict, Mapping, Optional, Union
+from google.cloud.aiplatform.compat.types import pipeline_failure_policy
+import packaging.version
+
+
+class PipelineRuntimeConfigBuilder(object):
+ """Pipeline RuntimeConfig builder.
+
+ Constructs a RuntimeConfig spec with pipeline_root and parameter overrides.
+ """
+
+ def __init__(
+ self,
+ pipeline_root: str,
+ schema_version: str,
+ parameter_types: Mapping[str, str],
+ parameter_values: Optional[Dict[str, Any]] = None,
+ failure_policy: Optional[pipeline_failure_policy.PipelineFailurePolicy] = None,
+ ):
+ """Creates a PipelineRuntimeConfigBuilder object.
+
+ Args:
+ pipeline_root (str):
+ Required. The root of the pipeline outputs.
+ schema_version (str):
+ Required. Schema version of the IR. This field determines the fields supported in current version of IR.
+ parameter_types (Mapping[str, str]):
+ Required. The mapping from pipeline parameter name to its type.
+ parameter_values (Dict[str, Any]):
+ Optional. The mapping from runtime parameter name to its value.
+ failure_policy (pipeline_failure_policy.PipelineFailurePolicy):
+ Optional. Represents the failure policy of a pipeline. Currently, the
+ default of a pipeline is that the pipeline will continue to
+ run until no more tasks can be executed, also known as
+ PIPELINE_FAILURE_POLICY_FAIL_SLOW. However, if a pipeline is
+ set to PIPELINE_FAILURE_POLICY_FAIL_FAST, it will stop
+ scheduling any new tasks when a task has failed. Any
+ scheduled tasks will continue to completion.
+ """
+ self._pipeline_root = pipeline_root
+ self._schema_version = schema_version
+ self._parameter_types = parameter_types
+ self._parameter_values = copy.deepcopy(parameter_values or {})
+ self._failure_policy = failure_policy
+
+ @classmethod
+ def from_job_spec_json(
+ cls,
+ job_spec: Mapping[str, Any],
+ ) -> "PipelineRuntimeConfigBuilder":
+ """Creates a PipelineRuntimeConfigBuilder object from PipelineJob json spec.
+
+ Args:
+ job_spec (Mapping[str, Any]):
+ Required. The PipelineJob spec.
+
+ Returns:
+ A PipelineRuntimeConfigBuilder object.
+ """
+ runtime_config_spec = job_spec["runtimeConfig"]
+ parameter_input_definitions = (
+ job_spec["pipelineSpec"]["root"]
+ .get("inputDefinitions", {})
+ .get("parameters", {})
+ )
+ schema_version = job_spec["pipelineSpec"]["schemaVersion"]
+
+ # 'type' is deprecated in IR and changed to 'parameterType'.
+ parameter_types = {
+ k: v.get("parameterType") or v.get("type")
+ for k, v in parameter_input_definitions.items()
+ }
+
+ pipeline_root = runtime_config_spec.get("gcsOutputDirectory")
+ parameter_values = _parse_runtime_parameters(runtime_config_spec)
+ failure_policy = runtime_config_spec.get("failurePolicy")
+ return cls(
+ pipeline_root,
+ schema_version,
+ parameter_types,
+ parameter_values,
+ failure_policy,
+ )
+
+ def update_pipeline_root(self, pipeline_root: Optional[str]) -> None:
+ """Updates pipeline_root value.
+
+ Args:
+ pipeline_root (str):
+ Optional. The root of the pipeline outputs.
+ """
+ if pipeline_root:
+ self._pipeline_root = pipeline_root
+
+ def update_runtime_parameters(
+ self, parameter_values: Optional[Mapping[str, Any]] = None
+ ) -> None:
+ """Merges runtime parameter values.
+
+ Args:
+ parameter_values (Mapping[str, Any]):
+ Optional. The mapping from runtime parameter names to their values.
+ """
+ if parameter_values:
+ parameters = dict(parameter_values)
+ if packaging.version.parse(self._schema_version) <= packaging.version.parse(
+ "2.0.0"
+ ):
+ for k, v in parameter_values.items():
+ if isinstance(v, (dict, list, bool)):
+ parameters[k] = json.dumps(v)
+ self._parameter_values.update(parameters)
+
+ def update_failure_policy(self, failure_policy: Optional[str] = None) -> None:
+ """Merges runtime failure policy.
+
+ Args:
+ failure_policy (str):
+ Optional. The failure policy - "slow" or "fast".
+
+ Raises:
+ ValueError: if failure_policy is not valid.
+ """
+ if failure_policy:
+ if failure_policy in _FAILURE_POLICY_TO_ENUM_VALUE:
+ self._failure_policy = _FAILURE_POLICY_TO_ENUM_VALUE[failure_policy]
+ else:
+ raise ValueError(
+ f'failure_policy should be either "slow" or "fast", but got: "{failure_policy}".'
+ )
+
+ def build(self) -> Dict[str, Any]:
+ """Build a RuntimeConfig proto.
+
+ Raises:
+ ValueError: if the pipeline root is not specified.
+ """
+ if not self._pipeline_root:
+ raise ValueError(
+ "Pipeline root must be specified, either during "
+ "compile time, or when calling the service."
+ )
+ if packaging.version.parse(self._schema_version) > packaging.version.parse(
+ "2.0.0"
+ ):
+ parameter_values_key = "parameterValues"
+ else:
+ parameter_values_key = "parameters"
+
+ runtime_config = {
+ "gcsOutputDirectory": self._pipeline_root,
+ parameter_values_key: {
+ k: self._get_vertex_value(k, v)
+ for k, v in self._parameter_values.items()
+ if v is not None
+ },
+ }
+
+ if self._failure_policy:
+ runtime_config["failurePolicy"] = self._failure_policy
+
+ return runtime_config
+
+ def _get_vertex_value(
+ self, name: str, value: Union[int, float, str, bool, list, dict]
+ ) -> Union[int, float, str, bool, list, dict]:
+ """Converts primitive values into Vertex pipeline Value proto message.
+
+ Args:
+ name (str):
+ Required. The name of the pipeline parameter.
+ value (Union[int, float, str, bool, list, dict]):
+ Required. The value of the pipeline parameter.
+
+ Returns:
+ A dictionary representing the Vertex pipeline Value proto message.
+
+ Raises:
+ ValueError: if the parameter name is not found in pipeline root
+ inputs, or value is none.
+ """
+ if value is None:
+ raise ValueError("None values should be filtered out.")
+
+ if name not in self._parameter_types:
+ raise ValueError(
+ "The pipeline parameter {} is not found in the "
+ "pipeline job input definitions.".format(name)
+ )
+
+ if packaging.version.parse(self._schema_version) <= packaging.version.parse(
+ "2.0.0"
+ ):
+ result = {}
+ if self._parameter_types[name] == "INT":
+ result["intValue"] = value
+ elif self._parameter_types[name] == "DOUBLE":
+ result["doubleValue"] = value
+ elif self._parameter_types[name] == "STRING":
+ result["stringValue"] = value
+ else:
+ raise TypeError("Got unknown type of value: {}".format(value))
+ return result
+ else:
+ return value
+
+
+def _parse_runtime_parameters(
+ runtime_config_spec: Mapping[str, Any]
+) -> Optional[Dict[str, Any]]:
+ """Extracts runtime parameters from runtime config json spec.
+
+ Raises:
+ TypeError: if the parameter type is not one of 'INT', 'DOUBLE', 'STRING'.
+ """
+    # 'parameters' is deprecated in IR and was changed to 'parameterValues'.
+ if runtime_config_spec.get("parameterValues") is not None:
+ return runtime_config_spec.get("parameterValues")
+
+ if runtime_config_spec.get("parameters") is not None:
+ result = {}
+ for name, value in runtime_config_spec.get("parameters").items():
+ if "intValue" in value:
+ result[name] = int(value["intValue"])
+ elif "doubleValue" in value:
+ result[name] = float(value["doubleValue"])
+ elif "stringValue" in value:
+ result[name] = value["stringValue"]
+ else:
+ raise TypeError("Got unknown type of value: {}".format(value))
+ return result
+
+
+_FAILURE_POLICY_TO_ENUM_VALUE = {
+ "slow": pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_FAIL_SLOW,
+ "fast": pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_FAIL_FAST,
+ None: pipeline_failure_policy.PipelineFailurePolicy.PIPELINE_FAILURE_POLICY_UNSPECIFIED,
+}
diff --git a/google/cloud/aiplatform/utils/resource_manager_utils.py b/google/cloud/aiplatform/utils/resource_manager_utils.py
new file mode 100644
index 0000000000..f918c766bf
--- /dev/null
+++ b/google/cloud/aiplatform/utils/resource_manager_utils.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from typing import Optional
+
+from google.auth import credentials as auth_credentials
+from google.cloud import resourcemanager
+
+from google.cloud.aiplatform import initializer
+
+
+def get_project_id(
+ project_number: str,
+ credentials: Optional[auth_credentials.Credentials] = None,
+) -> str:
+ """Gets project ID given the project number
+
+ Args:
+ project_number (str):
+ Required. The automatically generated unique identifier for your GCP project.
+        credentials (auth_credentials.Credentials):
+            Optional. The custom credentials to use when making API calls.
+            If not provided, default credentials will be used.
+
+ Returns:
+ str - The unique string used to differentiate your GCP project from all others in Google Cloud.
+
+ """
+
+ credentials = credentials or initializer.global_config.credentials
+
+ projects_client = resourcemanager.ProjectsClient(credentials=credentials)
+
+ project = projects_client.get_project(name=f"projects/{project_number}")
+
+ return project.project_id
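
A usage sketch for the helper above; the project number is a placeholder:

    from google.cloud.aiplatform.utils import resource_manager_utils

    # Resolves an automatically generated project number (as it appears in
    # resource names) back to the human-readable project ID via the Resource
    # Manager API.
    project_id = resource_manager_utils.get_project_id("123456789012")
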
diff --git a/google/cloud/aiplatform/utils/rest_utils.py b/google/cloud/aiplatform/utils/rest_utils.py
new file mode 100644
index 0000000000..4d8db45c47
--- /dev/null
+++ b/google/cloud/aiplatform/utils/rest_utils.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from google.cloud.aiplatform import base
+
+
+def make_gcp_resource_rest_url(resource: base.VertexAiResourceNoun) -> str:
+ """Helper function to format the GCP resource url for google.X metadata schemas.
+
+ Args:
+ resource (base.VertexAiResourceNoun): Required. A Vertex resource instance.
+ Returns:
+        The formatted URL of the resource.
+ """
+ resource_name = resource.resource_name
+ version = resource.api_client._default_version
+ api_uri = resource.api_client.api_endpoint
+
+ return f"https://{api_uri}/{version}/{resource_name}"
diff --git a/google/cloud/aiplatform/utils/source_utils.py b/google/cloud/aiplatform/utils/source_utils.py
index b7fcef806f..dc3c14a759 100644
--- a/google/cloud/aiplatform/utils/source_utils.py
+++ b/google/cloud/aiplatform/utils/source_utils.py
@@ -16,6 +16,7 @@
import functools
+import os
import pathlib
import shutil
import subprocess
@@ -62,7 +63,7 @@ class _TrainingScriptPythonPackager:
Constant command to generate the source distribution package.
Attributes:
- script_path: local path of script to package
+ script_path: local path of script or folder to package
requirements: list of Python dependencies to add to package
Usage:
@@ -70,7 +71,7 @@ class _TrainingScriptPythonPackager:
packager = TrainingScriptPythonPackager('my_script.py', ['pandas', 'pytorch'])
gcs_path = packager.package_and_copy_to_gcs(
gcs_staging_dir='my-bucket',
- project='my-prject')
+ project='my-project')
module_name = packager.module_name
    After installation, the package can be executed as:
@@ -79,7 +80,6 @@ class _TrainingScriptPythonPackager:
_TRAINER_FOLDER = "trainer"
_ROOT_MODULE = "aiplatform_custom_trainer_script"
- _TASK_MODULE_NAME = "task"
_SETUP_PY_VERSION = "0.1"
_SETUP_PY_TEMPLATE = """from setuptools import find_packages
@@ -96,10 +96,12 @@ class _TrainingScriptPythonPackager:
_SETUP_PY_SOURCE_DISTRIBUTION_CMD = "setup.py sdist --formats=gztar"
- # Module name that can be executed during training. ie. python -m
- module_name = f"{_ROOT_MODULE}.{_TASK_MODULE_NAME}"
-
- def __init__(self, script_path: str, requirements: Optional[Sequence[str]] = None):
+ def __init__(
+ self,
+ script_path: str,
+ task_module_name: str = "task",
+ requirements: Optional[Sequence[str]] = None,
+ ):
"""Initializes packager.
Args:
@@ -109,8 +111,14 @@ def __init__(self, script_path: str, requirements: Optional[Sequence[str]] = Non
"""
self.script_path = script_path
+ self.task_module_name = task_module_name
self.requirements = requirements or []
+ @property
+ def module_name(self) -> str:
+        # Module name that can be executed during training, i.e. via `python -m`.
+ return f"{self._ROOT_MODULE}.{self.task_module_name}"
+
def make_package(self, package_directory: str) -> str:
"""Converts script into a Python package suitable for python module
execution.
@@ -134,9 +142,6 @@ def make_package(self, package_directory: str) -> str:
# __init__.py path in root module
init_path = trainer_path / "__init__.py"
- # The module that will contain the script
- script_out_path = trainer_path / f"{self._TASK_MODULE_NAME}.py"
-
# The path to setup.py in the package.
setup_py_path = trainer_root_path / "setup.py"
@@ -165,8 +170,18 @@ def make_package(self, package_directory: str) -> str:
with setup_py_path.open("w") as fp:
fp.write(setup_py_output)
- # Copy script as module of python package.
- shutil.copy(self.script_path, script_out_path)
+ if os.path.isdir(self.script_path):
+ # Remove destination path if it already exists
+ shutil.rmtree(trainer_path)
+
+ # Copy folder recursively
+ shutil.copytree(src=self.script_path, dst=trainer_path)
+ else:
+ # The module that will contain the script
+ script_out_path = trainer_path / f"{self.task_module_name}.py"
+
+ # Copy script as module of python package.
+ shutil.copy(self.script_path, script_out_path)
# Run setup.py to create the source distribution.
setup_cmd = [
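
A sketch of the new directory-packaging path introduced above; the paths are placeholders:

    from google.cloud.aiplatform.utils import source_utils

    # A folder is now accepted in place of a single script; it is copied
    # recursively into the root module, and `task_module_name` selects the
    # module executed via `python -m`.
    packager = source_utils._TrainingScriptPythonPackager(
        script_path="training_code/",
        task_module_name="task",
        requirements=["pandas", "torch"],
    )
    assert packager.module_name == "aiplatform_custom_trainer_script.task"
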
diff --git a/google/cloud/aiplatform/utils/tensorboard_utils.py b/google/cloud/aiplatform/utils/tensorboard_utils.py
new file mode 100644
index 0000000000..32962aa1ab
--- /dev/null
+++ b/google/cloud/aiplatform/utils/tensorboard_utils.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Sequence, Dict
+from google.cloud.aiplatform_v1beta1.services.tensorboard_service.client import (
+ TensorboardServiceClient,
+)
+
+_SERVING_DOMAIN = "tensorboard.googleusercontent.com"
+
+
+def _parse_experiment_name(experiment_name: str) -> Dict[str, str]:
+ """Parses an experiment_name into its component segments.
+
+ Args:
+ experiment_name: Resource name of the TensorboardExperiment. E.g.
+ "projects/123/locations/asia-east1/tensorboards/456/experiments/exp1"
+
+ Returns:
+ Components of the experiment name.
+
+ Raises:
+ ValueError: If the experiment_name is invalid.
+ """
+ matched = TensorboardServiceClient.parse_tensorboard_experiment_path(
+ experiment_name
+ )
+ if not matched:
+ raise ValueError(f"Invalid experiment name: {experiment_name}.")
+ return matched
+
+
+def get_experiment_url(experiment_name: str) -> str:
+ """Get URL for comparing experiments.
+
+ Args:
+ experiment_name: Resource name of the TensorboardExperiment. E.g.
+ "projects/123/locations/asia-east1/tensorboards/456/experiments/exp1"
+
+ Returns:
+ URL for the tensorboard web app.
+ """
+ location = _parse_experiment_name(experiment_name)["location"]
+ name_for_url = experiment_name.replace("/", "+")
+ return f"https://{location}.{_SERVING_DOMAIN}/experiment/{name_for_url}"
+
+
+def get_experiments_compare_url(experiment_names: Sequence[str]) -> str:
+ """Get URL for comparing experiments.
+
+ Args:
+        experiment_names: Resource names of the TensorboardExperiments that need
+            to be compared.
+
+    Returns:
+        URL for the tensorboard web app.
+
+    Raises:
+        ValueError: If fewer than two experiment names are provided, or if the
+            experiments are not all in the same location.
+    """
+ if len(experiment_names) < 2:
+ raise ValueError("At least two experiment_names are required.")
+
+ locations = {
+ _parse_experiment_name(experiment_name)["location"]
+ for experiment_name in experiment_names
+ }
+ if len(locations) != 1:
+ raise ValueError(
+ f"Got experiments from different locations: {', '.join(locations)}."
+ )
+ location = locations.pop()
+
+ experiment_url_segments = []
+ for idx, experiment_name in enumerate(experiment_names):
+ name_segments = _parse_experiment_name(experiment_name)
+ experiment_url_segments.append(
+ "{cnt}-{experiment}:{project}+{location}+{tensorboard}+{experiment}".format(
+ cnt=idx + 1, **name_segments
+ )
+ )
+ encoded_names = ",".join(experiment_url_segments)
+ return f"https://{location}.{_SERVING_DOMAIN}/compare/{encoded_names}"
diff --git a/google/cloud/aiplatform/utils/worker_spec_utils.py b/google/cloud/aiplatform/utils/worker_spec_utils.py
index 385ac83979..2de1bf2f28 100644
--- a/google/cloud/aiplatform/utils/worker_spec_utils.py
+++ b/google/cloud/aiplatform/utils/worker_spec_utils.py
@@ -21,17 +21,31 @@
accelerator_type as gca_accelerator_type_compat,
)
+# `_SPEC_ORDERS` maps each worker pool spec type to its position in the worker
+# pool specs list. The `server_spec` supports either a reduction server or a
+# parameter server, each with a different configuration for its
+# `container_spec`. This mapping is used when configuring the `container_spec`
+# for all worker pool specs.
+_SPEC_ORDERS = {
+ "chief_spec": 0,
+ "worker_spec": 1,
+ "server_spec": 2,
+ "evaluator_spec": 3,
+}
-class _MachineSpec(NamedTuple):
- """Specification container for Machine specs used for distributed training.
+
+class _WorkerPoolSpec(NamedTuple):
+ """Specification container for Worker Pool specs used for distributed training.
Usage:
- spec = _MachineSpec(
+ spec = _WorkerPoolSpec(
replica_count=10,
machine_type='n1-standard-4',
accelerator_count=2,
- accelerator_type='NVIDIA_TESLA_K80')
+ accelerator_type='NVIDIA_TESLA_K80',
+ boot_disk_type='pd-ssd',
+ boot_disk_size_gb=100,
+ )
Note that container and python package specs are not stored with this spec.
"""
@@ -40,6 +54,8 @@ class _MachineSpec(NamedTuple):
machine_type: str = "n1-standard-4"
accelerator_count: int = 0
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED"
+ boot_disk_type: str = "pd-ssd"
+ boot_disk_size_gb: int = 100
def _get_accelerator_type(self) -> Optional[str]:
"""Validates accelerator_type and returns the name of the accelerator.
@@ -70,7 +86,12 @@ def spec_dict(self) -> Dict[str, Union[int, str, Dict[str, Union[int, str]]]]:
spec = {
"machine_spec": {"machine_type": self.machine_type},
"replica_count": self.replica_count,
+ "disk_spec": {
+ "boot_disk_type": self.boot_disk_type,
+ "boot_disk_size_gb": self.boot_disk_size_gb,
+ },
}
+
accelerator_type = self._get_accelerator_type()
if accelerator_type and self.accelerator_count:
spec["machine_spec"]["accelerator_type"] = accelerator_type
@@ -98,25 +119,29 @@ class _DistributedTrainingSpec(NamedTuple):
Usage:
dist_training_spec = _DistributedTrainingSpec(
- chief_spec = _MachineSpec(
+ chief_spec = _WorkerPoolSpec(
replica_count=1,
machine_type='n1-standard-4',
accelerator_count=2,
- accelerator_type='NVIDIA_TESLA_K80'
- ),
- worker_spec = _MachineSpec(
+ accelerator_type='NVIDIA_TESLA_K80',
+ boot_disk_type='pd-ssd',
+ boot_disk_size_gb=100,
+ ),
+ worker_spec = _WorkerPoolSpec(
replica_count=10,
machine_type='n1-standard-4',
accelerator_count=2,
- accelerator_type='NVIDIA_TESLA_K80'
- )
+ accelerator_type='NVIDIA_TESLA_K80',
+ boot_disk_type='pd-ssd',
+ boot_disk_size_gb=100,
+ ),
)
"""
- chief_spec: _MachineSpec = _MachineSpec()
- worker_spec: _MachineSpec = _MachineSpec()
- parameter_server_spec: _MachineSpec = _MachineSpec()
- evaluator_spec: _MachineSpec = _MachineSpec()
+ chief_spec: _WorkerPoolSpec = _WorkerPoolSpec()
+ worker_spec: _WorkerPoolSpec = _WorkerPoolSpec()
+ server_spec: _WorkerPoolSpec = _WorkerPoolSpec()
+ evaluator_spec: _WorkerPoolSpec = _WorkerPoolSpec()
@property
def pool_specs(
@@ -138,10 +163,10 @@ def pool_specs(
spec_order = [
self.chief_spec,
self.worker_spec,
- self.parameter_server_spec,
+ self.server_spec,
self.evaluator_spec,
]
- specs = [s.spec_dict for s in spec_order]
+ specs = [{} if s.is_empty else s.spec_dict for s in spec_order]
for i in reversed(range(len(spec_order))):
if spec_order[i].is_empty:
specs.pop()
@@ -156,8 +181,12 @@ def chief_worker_pool(
machine_type: str = "n1-standard-4",
accelerator_count: int = 0,
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED",
+ boot_disk_type: str = "pd-ssd",
+ boot_disk_size_gb: int = 100,
+ reduction_server_replica_count: int = 0,
+        reduction_server_machine_type: Optional[str] = None,
) -> "_DistributedTrainingSpec":
- """Parameterizes Config to support only chief with worker replicas.
+ """Parametrizes Config to support only chief with worker replicas.
One replica is assigned to the chief and the remainder to workers. All specs have
the same machine type, accelerator count, and accelerator type.
@@ -174,26 +203,51 @@ def chief_worker_pool(
NVIDIA_TESLA_T4
accelerator_count (int):
The number of accelerators to attach to a worker replica.
+ boot_disk_type (str):
+ Type of the boot disk (default is `pd-ssd`).
+ Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or
+ `pd-standard` (Persistent Disk Hard Disk Drive).
+ boot_disk_size_gb (int):
+ Size in GB of the boot disk (default is 100GB).
+                Boot disk size must be within the range [100, 64000].
+ reduction_server_replica_count (int):
+ The number of reduction server replicas, default is 0.
+ reduction_server_machine_type (str):
+                The type of machine to use for the reduction server; default is `n1-highcpu-16`.
Returns:
- _DistributedTrainingSpec representing one chief and n workers all of same
- type. If replica_count <= 0 then an empty spec is returned.
+ _DistributedTrainingSpec representing one chief and n workers all of
+ same type, optional with reduction server(s). If replica_count <= 0
+ then an empty spec is returned.
"""
if replica_count <= 0:
return cls()
- chief_spec = _MachineSpec(
+ chief_spec = _WorkerPoolSpec(
replica_count=1,
machine_type=machine_type,
accelerator_count=accelerator_count,
accelerator_type=accelerator_type,
+ boot_disk_type=boot_disk_type,
+ boot_disk_size_gb=boot_disk_size_gb,
)
- worker_spec = _MachineSpec(
+ worker_spec = _WorkerPoolSpec(
replica_count=replica_count - 1,
machine_type=machine_type,
accelerator_count=accelerator_count,
accelerator_type=accelerator_type,
+ boot_disk_type=boot_disk_type,
+ boot_disk_size_gb=boot_disk_size_gb,
)
- return cls(chief_spec=chief_spec, worker_spec=worker_spec)
+ reduction_server_spec = _WorkerPoolSpec(
+ replica_count=reduction_server_replica_count,
+ machine_type=reduction_server_machine_type,
+ )
+
+ return cls(
+ chief_spec=chief_spec,
+ worker_spec=worker_spec,
+ server_spec=reduction_server_spec,
+ )
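
A sketch of the extended factory above; the values are placeholders:

    from google.cloud.aiplatform.utils import worker_spec_utils

    spec = worker_spec_utils._DistributedTrainingSpec.chief_worker_pool(
        replica_count=4,
        machine_type="n1-standard-8",
        boot_disk_type="pd-ssd",
        boot_disk_size_gb=200,
        reduction_server_replica_count=2,
        reduction_server_machine_type="n1-highcpu-16",
    )
    # pool_specs yields one chief, three workers, and two reduction-server
    # replicas; every non-empty spec now carries a `disk_spec` entry.
    worker_pool_specs = spec.pool_specs
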
diff --git a/google/cloud/aiplatform/utils/yaml_utils.py b/google/cloud/aiplatform/utils/yaml_utils.py
new file mode 100644
index 0000000000..bac33733dc
--- /dev/null
+++ b/google/cloud/aiplatform/utils/yaml_utils.py
@@ -0,0 +1,139 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import re
+from typing import Any, Dict, Optional
+from urllib import request
+
+from google.auth import credentials as auth_credentials
+from google.auth import transport
+from google.cloud import storage
+
+# Pattern for an Artifact Registry URL.
+_VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*")
+
+
+def load_yaml(
+ path: str,
+ project: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+) -> Dict[str, Any]:
+ """Loads data from a YAML document.
+
+ Args:
+ path (str):
+            Required. The path of the YAML document: a local path, a Google
+            Cloud Storage URI, or an Artifact Registry URL.
+ project (str):
+ Optional. Project to initiate the Storage client with.
+ credentials (auth_credentials.Credentials):
+ Optional. Credentials to use with Storage Client.
+
+ Returns:
+ A Dict object representing the YAML document.
+ """
+ if path.startswith("gs://"):
+ return _load_yaml_from_gs_uri(path, project, credentials)
+ elif _VALID_AR_URL.match(path):
+ return _load_yaml_from_ar_uri(path, credentials)
+ else:
+ return _load_yaml_from_local_file(path)
+
+
+def _load_yaml_from_gs_uri(
+ uri: str,
+ project: Optional[str] = None,
+ credentials: Optional[auth_credentials.Credentials] = None,
+) -> Dict[str, Any]:
+ """Loads data from a YAML document referenced by a GCS URI.
+
+ Args:
+        uri (str):
+            Required. The GCS URI of the YAML document.
+ project (str):
+ Optional. Project to initiate the Storage client with.
+ credentials (auth_credentials.Credentials):
+ Optional. Credentials to use with Storage Client.
+
+ Returns:
+ A Dict object representing the YAML document.
+ """
+ try:
+ import yaml
+ except ImportError:
+ raise ImportError(
+ "pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. "
+ 'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"'
+ )
+ storage_client = storage.Client(project=project, credentials=credentials)
+ blob = storage.Blob.from_string(uri, storage_client)
+ return yaml.safe_load(blob.download_as_bytes())
+
+
+def _load_yaml_from_local_file(file_path: str) -> Dict[str, Any]:
+ """Loads data from a YAML local file.
+
+ Args:
+ file_path (str):
+ Required. The local file path of the YAML document.
+
+ Returns:
+ A Dict object representing the YAML document.
+ """
+ try:
+ import yaml
+ except ImportError:
+ raise ImportError(
+ "pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. "
+ 'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"'
+ )
+ with open(file_path) as f:
+ return yaml.safe_load(f)
+
+
+def _load_yaml_from_ar_uri(
+ uri: str,
+ credentials: Optional[auth_credentials.Credentials] = None,
+) -> Dict[str, Any]:
+ """Loads data from a YAML document referenced by a Artifact Registry URI.
+
+ Args:
+        uri (str):
+            Required. The Artifact Registry URI of the YAML document.
+ credentials (auth_credentials.Credentials):
+ Optional. Credentials to use with Artifact Registry.
+
+ Returns:
+ A Dict object representing the YAML document.
+ """
+ try:
+ import yaml
+ except ImportError:
+ raise ImportError(
+ "pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. "
+ 'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"'
+ )
+ req = request.Request(uri)
+
+ if credentials:
+ if not credentials.valid:
+ credentials.refresh(transport.requests.Request())
+ if credentials.token:
+ req.add_header("Authorization", "Bearer " + credentials.token)
+ response = request.urlopen(req)
+
+ return yaml.safe_load(response.read().decode("utf-8"))
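
A sketch of the three dispatch paths of `load_yaml`; the paths and URLs are placeholders:

    from google.cloud.aiplatform.utils import yaml_utils

    local_spec = yaml_utils.load_yaml("pipeline.yaml")
    gcs_spec = yaml_utils.load_yaml("gs://my-bucket/pipeline.yaml")
    # Matches the `*-kfp.pkg.dev` Artifact Registry pattern defined above.
    ar_spec = yaml_utils.load_yaml(
        "https://us-central1-kfp.pkg.dev/my-project/my-repo/my-pipeline/latest"
    )
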
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance/__init__.py b/google/cloud/aiplatform/v1/schema/predict/instance/__init__.py
index 135e131a29..5d321a91dd 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance/__init__.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/__init__.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/__init__.py
index fdfe1ca46f..47708ddc7f 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/__init__.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/__init__.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/__init__.py
index 744852e8a3..c36f147d50 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/__init__.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,15 +13,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from .image_classification import ImageClassificationPredictionInstance
-from .image_object_detection import ImageObjectDetectionPredictionInstance
-from .image_segmentation import ImageSegmentationPredictionInstance
-from .text_classification import TextClassificationPredictionInstance
-from .text_extraction import TextExtractionPredictionInstance
-from .text_sentiment import TextSentimentPredictionInstance
-from .video_action_recognition import VideoActionRecognitionPredictionInstance
-from .video_classification import VideoClassificationPredictionInstance
-from .video_object_tracking import VideoObjectTrackingPredictionInstance
+from .image_classification import (
+ ImageClassificationPredictionInstance,
+)
+from .image_object_detection import (
+ ImageObjectDetectionPredictionInstance,
+)
+from .image_segmentation import (
+ ImageSegmentationPredictionInstance,
+)
+from .text_classification import (
+ TextClassificationPredictionInstance,
+)
+from .text_extraction import (
+ TextExtractionPredictionInstance,
+)
+from .text_sentiment import (
+ TextSentimentPredictionInstance,
+)
+from .video_action_recognition import (
+ VideoActionRecognitionPredictionInstance,
+)
+from .video_classification import (
+ VideoClassificationPredictionInstance,
+)
+from .video_object_tracking import (
+ VideoObjectTrackingPredictionInstance,
+)
__all__ = (
"ImageClassificationPredictionInstance",
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_classification.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_classification.py
index 04e7b841a5..0507c768a1 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_classification.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,16 +18,19 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.instance",
- manifest={"ImageClassificationPredictionInstance",},
+ manifest={
+ "ImageClassificationPredictionInstance",
+ },
)
class ImageClassificationPredictionInstance(proto.Message):
r"""Prediction input format for Image Classification.
+
Attributes:
content (str):
- The image bytes or GCS URI to make the
- prediction on.
+ The image bytes or Cloud Storage URI to make
+ the prediction on.
mime_type (str):
The MIME type of the content of the image.
Only the images in below listed MIME types are
@@ -40,8 +43,14 @@ class ImageClassificationPredictionInstance(proto.Message):
- image/vnd.microsoft.icon
"""
- content = proto.Field(proto.STRING, number=1,)
- mime_type = proto.Field(proto.STRING, number=2,)
+ content = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ mime_type = proto.Field(
+ proto.STRING,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
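
For reference, constructing this message; the URI is a placeholder and the import path assumes the generated `instance_v1` package re-exports its types:

    from google.cloud.aiplatform.v1.schema.predict import instance_v1

    instance = instance_v1.ImageClassificationPredictionInstance(
        content="gs://my-bucket/images/cat.jpeg",  # image bytes or Cloud Storage URI
        mime_type="image/jpeg",
    )
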
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_object_detection.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_object_detection.py
index 5180c12ece..4f45ab1745 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_object_detection.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_object_detection.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,16 +18,19 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.instance",
- manifest={"ImageObjectDetectionPredictionInstance",},
+ manifest={
+ "ImageObjectDetectionPredictionInstance",
+ },
)
class ImageObjectDetectionPredictionInstance(proto.Message):
r"""Prediction input format for Image Object Detection.
+
Attributes:
content (str):
- The image bytes or GCS URI to make the
- prediction on.
+ The image bytes or Cloud Storage URI to make
+ the prediction on.
mime_type (str):
The MIME type of the content of the image.
Only the images in below listed MIME types are
@@ -40,8 +43,14 @@ class ImageObjectDetectionPredictionInstance(proto.Message):
- image/vnd.microsoft.icon
"""
- content = proto.Field(proto.STRING, number=1,)
- mime_type = proto.Field(proto.STRING, number=2,)
+ content = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ mime_type = proto.Field(
+ proto.STRING,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_segmentation.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_segmentation.py
index 0591b17208..0f18a3e7bb 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_segmentation.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_segmentation.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.instance",
- manifest={"ImageSegmentationPredictionInstance",},
+ manifest={
+ "ImageSegmentationPredictionInstance",
+ },
)
class ImageSegmentationPredictionInstance(proto.Message):
r"""Prediction input format for Image Segmentation.
+
Attributes:
content (str):
The image bytes to make the predictions on.
@@ -34,8 +37,14 @@ class ImageSegmentationPredictionInstance(proto.Message):
- image/png
"""
- content = proto.Field(proto.STRING, number=1,)
- mime_type = proto.Field(proto.STRING, number=2,)
+ content = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ mime_type = proto.Field(
+ proto.STRING,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_classification.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_classification.py
index aafbcac3e7..27c3487fb2 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_classification.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.instance",
- manifest={"TextClassificationPredictionInstance",},
+ manifest={
+ "TextClassificationPredictionInstance",
+ },
)
class TextClassificationPredictionInstance(proto.Message):
r"""Prediction input format for Text Classification.
+
Attributes:
content (str):
The text snippet to make the predictions on.
@@ -33,8 +36,14 @@ class TextClassificationPredictionInstance(proto.Message):
- text/plain
"""
- content = proto.Field(proto.STRING, number=1,)
- mime_type = proto.Field(proto.STRING, number=2,)
+ content = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ mime_type = proto.Field(
+ proto.STRING,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_extraction.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_extraction.py
index ba1997ba05..88a9e61c67 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_extraction.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_extraction.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.instance",
- manifest={"TextExtractionPredictionInstance",},
+ manifest={
+ "TextExtractionPredictionInstance",
+ },
)
class TextExtractionPredictionInstance(proto.Message):
r"""Prediction input format for Text Extraction.
+
Attributes:
content (str):
The text snippet to make the predictions on.
@@ -36,15 +39,24 @@ class TextExtractionPredictionInstance(proto.Message):
If a key is provided, the batch prediction
result will be mapped to this key. If omitted,
then the batch prediction result will contain
- the entire input instance. AI Platform will not
+ the entire input instance. Vertex AI will not
check if keys in the request are duplicates, so
it is up to the caller to ensure the keys are
unique.
"""
- content = proto.Field(proto.STRING, number=1,)
- mime_type = proto.Field(proto.STRING, number=2,)
- key = proto.Field(proto.STRING, number=3,)
+ content = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ mime_type = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+ key = proto.Field(
+ proto.STRING,
+ number=3,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_sentiment.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_sentiment.py
index d86d58f40f..7907d6e31c 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_sentiment.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_sentiment.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.instance",
- manifest={"TextSentimentPredictionInstance",},
+ manifest={
+ "TextSentimentPredictionInstance",
+ },
)
class TextSentimentPredictionInstance(proto.Message):
r"""Prediction input format for Text Sentiment.
+
Attributes:
content (str):
The text snippet to make the predictions on.
@@ -33,8 +36,14 @@ class TextSentimentPredictionInstance(proto.Message):
- text/plain
"""
- content = proto.Field(proto.STRING, number=1,)
- mime_type = proto.Field(proto.STRING, number=2,)
+ content = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ mime_type = proto.Field(
+ proto.STRING,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_action_recognition.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_action_recognition.py
index d8db889408..c11836967d 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_action_recognition.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_action_recognition.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.instance",
- manifest={"VideoActionRecognitionPredictionInstance",},
+ manifest={
+ "VideoActionRecognitionPredictionInstance",
+ },
)
class VideoActionRecognitionPredictionInstance(proto.Message):
r"""Prediction input format for Video Action Recognition.
+
Attributes:
content (str):
The Google Cloud Storage location of the
@@ -49,10 +52,22 @@ class VideoActionRecognitionPredictionInstance(proto.Message):
is allowed, which means the end of the video.
"""
- content = proto.Field(proto.STRING, number=1,)
- mime_type = proto.Field(proto.STRING, number=2,)
- time_segment_start = proto.Field(proto.STRING, number=3,)
- time_segment_end = proto.Field(proto.STRING, number=4,)
+ content = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ mime_type = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+ time_segment_start = proto.Field(
+ proto.STRING,
+ number=3,
+ )
+ time_segment_end = proto.Field(
+ proto.STRING,
+ number=4,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_classification.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_classification.py
index f03e673f90..30dbf42ca5 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_classification.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.instance",
- manifest={"VideoClassificationPredictionInstance",},
+ manifest={
+ "VideoClassificationPredictionInstance",
+ },
)
class VideoClassificationPredictionInstance(proto.Message):
r"""Prediction input format for Video Classification.
+
Attributes:
content (str):
The Google Cloud Storage location of the
@@ -49,10 +52,22 @@ class VideoClassificationPredictionInstance(proto.Message):
is allowed, which means the end of the video.
"""
- content = proto.Field(proto.STRING, number=1,)
- mime_type = proto.Field(proto.STRING, number=2,)
- time_segment_start = proto.Field(proto.STRING, number=3,)
- time_segment_end = proto.Field(proto.STRING, number=4,)
+ content = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ mime_type = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+ time_segment_start = proto.Field(
+ proto.STRING,
+ number=3,
+ )
+ time_segment_end = proto.Field(
+ proto.STRING,
+ number=4,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_object_tracking.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_object_tracking.py
index 5df1e42eb5..3bf0c6a275 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_object_tracking.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_object_tracking.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.instance",
- manifest={"VideoObjectTrackingPredictionInstance",},
+ manifest={
+ "VideoObjectTrackingPredictionInstance",
+ },
)
class VideoObjectTrackingPredictionInstance(proto.Message):
r"""Prediction input format for Video Object Tracking.
+
Attributes:
content (str):
The Google Cloud Storage location of the
@@ -49,10 +52,22 @@ class VideoObjectTrackingPredictionInstance(proto.Message):
is allowed, which means the end of the video.
"""
- content = proto.Field(proto.STRING, number=1,)
- mime_type = proto.Field(proto.STRING, number=2,)
- time_segment_start = proto.Field(proto.STRING, number=3,)
- time_segment_end = proto.Field(proto.STRING, number=4,)
+ content = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ mime_type = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+ time_segment_start = proto.Field(
+ proto.STRING,
+ number=3,
+ )
+ time_segment_end = proto.Field(
+ proto.STRING,
+ number=4,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/predict/params/__init__.py b/google/cloud/aiplatform/v1/schema/predict/params/__init__.py
index a55ff6dc0f..803e6dd857 100644
--- a/google/cloud/aiplatform/v1/schema/predict/params/__init__.py
+++ b/google/cloud/aiplatform/v1/schema/predict/params/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/__init__.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/__init__.py
index dcf74bb7a0..fd80646afd 100644
--- a/google/cloud/aiplatform/v1/schema/predict/params_v1/__init__.py
+++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/__init__.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/__init__.py
index 26997a8d81..135f3bff54 100644
--- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/__init__.py
+++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,12 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from .image_classification import ImageClassificationPredictionParams
-from .image_object_detection import ImageObjectDetectionPredictionParams
-from .image_segmentation import ImageSegmentationPredictionParams
-from .video_action_recognition import VideoActionRecognitionPredictionParams
-from .video_classification import VideoClassificationPredictionParams
-from .video_object_tracking import VideoObjectTrackingPredictionParams
+from .image_classification import (
+ ImageClassificationPredictionParams,
+)
+from .image_object_detection import (
+ ImageObjectDetectionPredictionParams,
+)
+from .image_segmentation import (
+ ImageSegmentationPredictionParams,
+)
+from .video_action_recognition import (
+ VideoActionRecognitionPredictionParams,
+)
+from .video_classification import (
+ VideoClassificationPredictionParams,
+)
+from .video_object_tracking import (
+ VideoObjectTrackingPredictionParams,
+)
__all__ = (
"ImageClassificationPredictionParams",
diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_classification.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_classification.py
index e042f39854..dfa355995e 100644
--- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_classification.py
+++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.params",
- manifest={"ImageClassificationPredictionParams",},
+ manifest={
+ "ImageClassificationPredictionParams",
+ },
)
class ImageClassificationPredictionParams(proto.Message):
r"""Prediction model parameters for Image Classification.
+
Attributes:
confidence_threshold (float):
The Model only returns predictions with at
@@ -36,8 +39,14 @@ class ImageClassificationPredictionParams(proto.Message):
return fewer predictions. Default value is 10.
"""
- confidence_threshold = proto.Field(proto.FLOAT, number=1,)
- max_predictions = proto.Field(proto.INT32, number=2,)
+ confidence_threshold = proto.Field(
+ proto.FLOAT,
+ number=1,
+ )
+ max_predictions = proto.Field(
+ proto.INT32,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
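
For reference, a params message with the documented defaults overridden; the values are placeholders and the import path assumes the generated `params_v1` package re-exports its types:

    from google.cloud.aiplatform.v1.schema.predict import params_v1

    params = params_v1.ImageClassificationPredictionParams(
        confidence_threshold=0.5,  # only return predictions scoring >= 0.5
        max_predictions=5,         # cap the number of returned predictions
    )
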
diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_object_detection.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_object_detection.py
index 4ca8404d61..7ee9d4046f 100644
--- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_object_detection.py
+++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_object_detection.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.params",
- manifest={"ImageObjectDetectionPredictionParams",},
+ manifest={
+ "ImageObjectDetectionPredictionParams",
+ },
)
class ImageObjectDetectionPredictionParams(proto.Message):
r"""Prediction model parameters for Image Object Detection.
+
Attributes:
confidence_threshold (float):
The Model only returns predictions with at
@@ -37,8 +40,14 @@ class ImageObjectDetectionPredictionParams(proto.Message):
value is 10.
"""
- confidence_threshold = proto.Field(proto.FLOAT, number=1,)
- max_predictions = proto.Field(proto.INT32, number=2,)
+ confidence_threshold = proto.Field(
+ proto.FLOAT,
+ number=1,
+ )
+ max_predictions = proto.Field(
+ proto.INT32,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_segmentation.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_segmentation.py
index 6a2102b808..a346c8b185 100644
--- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_segmentation.py
+++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_segmentation.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.params",
- manifest={"ImageSegmentationPredictionParams",},
+ manifest={
+ "ImageSegmentationPredictionParams",
+ },
)
class ImageSegmentationPredictionParams(proto.Message):
r"""Prediction model parameters for Image Segmentation.
+
Attributes:
confidence_threshold (float):
When the model predicts category of pixels of
@@ -33,7 +36,10 @@ class ImageSegmentationPredictionParams(proto.Message):
background. Default value is 0.5.
"""
- confidence_threshold = proto.Field(proto.FLOAT, number=1,)
+ confidence_threshold = proto.Field(
+ proto.FLOAT,
+ number=1,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_action_recognition.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_action_recognition.py
index f09d2058e3..4b85f006ae 100644
--- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_action_recognition.py
+++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_action_recognition.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.params",
- manifest={"VideoActionRecognitionPredictionParams",},
+ manifest={
+ "VideoActionRecognitionPredictionParams",
+ },
)
class VideoActionRecognitionPredictionParams(proto.Message):
r"""Prediction model parameters for Video Action Recognition.
+
Attributes:
confidence_threshold (float):
The Model only returns predictions with at
@@ -37,8 +40,14 @@ class VideoActionRecognitionPredictionParams(proto.Message):
Default value is 50.
"""
- confidence_threshold = proto.Field(proto.FLOAT, number=1,)
- max_predictions = proto.Field(proto.INT32, number=2,)
+ confidence_threshold = proto.Field(
+ proto.FLOAT,
+ number=1,
+ )
+ max_predictions = proto.Field(
+ proto.INT32,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_classification.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_classification.py
index 1ab180bbe2..ea1748191d 100644
--- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_classification.py
+++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.params",
- manifest={"VideoClassificationPredictionParams",},
+ manifest={
+ "VideoClassificationPredictionParams",
+ },
)
class VideoClassificationPredictionParams(proto.Message):
r"""Prediction model parameters for Video Classification.
+
Attributes:
confidence_threshold (float):
The Model only returns predictions with at
@@ -37,16 +40,16 @@ class VideoClassificationPredictionParams(proto.Message):
10,000.
segment_classification (bool):
Set to true to request segment-level
- classification. AI Platform returns labels and
+ classification. Vertex AI returns labels and
their confidence scores for the entire time
segment of the video that user specified in the
input instance. Default value is true
shot_classification (bool):
Set to true to request shot-level
- classification. AI Platform determines the
+ classification. Vertex AI determines the
boundaries for each camera shot in the entire
time segment of the video that user specified in
- the input instance. AI Platform then returns
+ the input instance. Vertex AI then returns
labels and their confidence scores for each
detected shot, along with the start and end time
of the shot.
@@ -57,22 +60,36 @@ class VideoClassificationPredictionParams(proto.Message):
Default value is false
one_sec_interval_classification (bool):
Set to true to request classification for a
- video at one-second intervals. AI Platform
- returns labels and their confidence scores for
- each second of the entire time segment of the
- video that user specified in the input WARNING:
- Model evaluation is not done for this
- classification type, the quality of it depends
- on the training data, but there are no metrics
- provided to describe that quality. Default value
- is false
+ video at one-second intervals. Vertex AI returns
+ labels and their confidence scores for each
+ second of the entire time segment of the video
+            that the user specified in the input. WARNING: Model
+ evaluation is not done for this classification
+ type, the quality of it depends on the training
+ data, but there are no metrics provided to
+ describe that quality. Default value is false
"""
- confidence_threshold = proto.Field(proto.FLOAT, number=1,)
- max_predictions = proto.Field(proto.INT32, number=2,)
- segment_classification = proto.Field(proto.BOOL, number=3,)
- shot_classification = proto.Field(proto.BOOL, number=4,)
- one_sec_interval_classification = proto.Field(proto.BOOL, number=5,)
+ confidence_threshold = proto.Field(
+ proto.FLOAT,
+ number=1,
+ )
+ max_predictions = proto.Field(
+ proto.INT32,
+ number=2,
+ )
+ segment_classification = proto.Field(
+ proto.BOOL,
+ number=3,
+ )
+ shot_classification = proto.Field(
+ proto.BOOL,
+ number=4,
+ )
+ one_sec_interval_classification = proto.Field(
+ proto.BOOL,
+ number=5,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_object_tracking.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_object_tracking.py
index 83dedee1d9..a4cc867a44 100644
--- a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_object_tracking.py
+++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_object_tracking.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.params",
- manifest={"VideoObjectTrackingPredictionParams",},
+ manifest={
+ "VideoObjectTrackingPredictionParams",
+ },
)
class VideoObjectTrackingPredictionParams(proto.Message):
r"""Prediction model parameters for Video Object Tracking.
+
Attributes:
confidence_threshold (float):
The Model only returns predictions with at
@@ -41,9 +44,18 @@ class VideoObjectTrackingPredictionParams(proto.Message):
frame size are returned. Default value is 0.0.
"""
- confidence_threshold = proto.Field(proto.FLOAT, number=1,)
- max_predictions = proto.Field(proto.INT32, number=2,)
- min_bounding_box_size = proto.Field(proto.FLOAT, number=3,)
+ confidence_threshold = proto.Field(
+ proto.FLOAT,
+ number=1,
+ )
+ max_predictions = proto.Field(
+ proto.INT32,
+ number=2,
+ )
+ min_bounding_box_size = proto.Field(
+ proto.FLOAT,
+ number=3,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction/__init__.py b/google/cloud/aiplatform/v1/schema/predict/prediction/__init__.py
index a39dd71937..0bd9e1fc3b 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction/__init__.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/__init__.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/__init__.py
index 866cade4d0..73c5336994 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/__init__.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/__init__.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/__init__.py
index 0bb99636b3..12fc5b9a2f 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/__init__.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,16 +13,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from .classification import ClassificationPredictionResult
-from .image_object_detection import ImageObjectDetectionPredictionResult
-from .image_segmentation import ImageSegmentationPredictionResult
-from .tabular_classification import TabularClassificationPredictionResult
-from .tabular_regression import TabularRegressionPredictionResult
-from .text_extraction import TextExtractionPredictionResult
-from .text_sentiment import TextSentimentPredictionResult
-from .video_action_recognition import VideoActionRecognitionPredictionResult
-from .video_classification import VideoClassificationPredictionResult
-from .video_object_tracking import VideoObjectTrackingPredictionResult
+from .classification import (
+ ClassificationPredictionResult,
+)
+from .image_object_detection import (
+ ImageObjectDetectionPredictionResult,
+)
+from .image_segmentation import (
+ ImageSegmentationPredictionResult,
+)
+from .tabular_classification import (
+ TabularClassificationPredictionResult,
+)
+from .tabular_regression import (
+ TabularRegressionPredictionResult,
+)
+from .text_extraction import (
+ TextExtractionPredictionResult,
+)
+from .text_sentiment import (
+ TextSentimentPredictionResult,
+)
+from .video_action_recognition import (
+ VideoActionRecognitionPredictionResult,
+)
+from .video_classification import (
+ VideoClassificationPredictionResult,
+)
+from .video_object_tracking import (
+ VideoObjectTrackingPredictionResult,
+)
__all__ = (
"ClassificationPredictionResult",
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/classification.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/classification.py
index 490d81e91d..0cfd533f16 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/classification.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,17 +18,19 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.prediction",
- manifest={"ClassificationPredictionResult",},
+ manifest={
+ "ClassificationPredictionResult",
+ },
)
class ClassificationPredictionResult(proto.Message):
r"""Prediction output format for Image and Text Classification.
+
Attributes:
ids (Sequence[int]):
The resource IDs of the AnnotationSpecs that
- had been identified, ordered by the confidence
- score descendingly.
+ had been identified.
display_names (Sequence[str]):
The display names of the AnnotationSpecs that
had been identified, order matches the IDs.
@@ -38,9 +40,18 @@ class ClassificationPredictionResult(proto.Message):
confidence. Order matches the Ids.
"""
- ids = proto.RepeatedField(proto.INT64, number=1,)
- display_names = proto.RepeatedField(proto.STRING, number=2,)
- confidences = proto.RepeatedField(proto.FLOAT, number=3,)
+ ids = proto.RepeatedField(
+ proto.INT64,
+ number=1,
+ )
+ display_names = proto.RepeatedField(
+ proto.STRING,
+ number=2,
+ )
+ confidences = proto.RepeatedField(
+ proto.FLOAT,
+ number=3,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
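As a quick orientation for the reformatted message above: the three repeated fields are parallel arrays, so consumers typically zip them back together. A minimal sketch, with invented IDs and labels:

```python
from google.cloud.aiplatform.v1.schema.predict import prediction_v1

result = prediction_v1.ClassificationPredictionResult(
    ids=[123, 456],
    display_names=["daisy", "tulip"],
    confidences=[0.92, 0.08],
)
# The repeated fields share ordering; zip pairs labels with scores.
for id_, name, conf in zip(result.ids, result.display_names, result.confidences):
    print(f"{id_} {name}: {conf:.2f}")  # e.g. "123 daisy: 0.92"
```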
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_object_detection.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_object_detection.py
index c44d4744a3..c1459a3e3a 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_object_detection.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_object_detection.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,12 +20,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.prediction",
- manifest={"ImageObjectDetectionPredictionResult",},
+ manifest={
+ "ImageObjectDetectionPredictionResult",
+ },
)
class ImageObjectDetectionPredictionResult(proto.Message):
r"""Prediction output format for Image Object Detection.
+
Attributes:
ids (Sequence[int]):
The resource IDs of the AnnotationSpecs that
@@ -48,10 +51,23 @@ class ImageObjectDetectionPredictionResult(proto.Message):
image.
"""
- ids = proto.RepeatedField(proto.INT64, number=1,)
- display_names = proto.RepeatedField(proto.STRING, number=2,)
- confidences = proto.RepeatedField(proto.FLOAT, number=3,)
- bboxes = proto.RepeatedField(proto.MESSAGE, number=4, message=struct_pb2.ListValue,)
+ ids = proto.RepeatedField(
+ proto.INT64,
+ number=1,
+ )
+ display_names = proto.RepeatedField(
+ proto.STRING,
+ number=2,
+ )
+ confidences = proto.RepeatedField(
+ proto.FLOAT,
+ number=3,
+ )
+ bboxes = proto.RepeatedField(
+ proto.MESSAGE,
+ number=4,
+ message=struct_pb2.ListValue,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
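The `bboxes` field expanded above is a repeated `ListValue`; the field documentation gives the four entries as xMin, xMax, yMin, yMax, an ordering assumption worth double-checking against the service reference. A small unpacking sketch:

```python
def iter_boxes(result):
    # Yields (label, confidence, (x_min, x_max, y_min, y_max)) tuples.
    for name, conf, bbox in zip(result.display_names, result.confidences, result.bboxes):
        x_min, x_max, y_min, y_max = list(bbox)  # ListValue unwraps to plain floats
        yield name, conf, (x_min, x_max, y_min, y_max)
```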
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_segmentation.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_segmentation.py
index 4608baeaf6..6c49da5628 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_segmentation.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_segmentation.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.prediction",
- manifest={"ImageSegmentationPredictionResult",},
+ manifest={
+ "ImageSegmentationPredictionResult",
+ },
)
class ImageSegmentationPredictionResult(proto.Message):
r"""Prediction output format for Image Segmentation.
+
Attributes:
category_mask (str):
A PNG image where each pixel in the mask
@@ -46,8 +49,14 @@ class ImageSegmentationPredictionResult(proto.Message):
confidence and white means complete confidence.
"""
- category_mask = proto.Field(proto.STRING, number=1,)
- confidence_mask = proto.Field(proto.STRING, number=2,)
+ category_mask = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ confidence_mask = proto.Field(
+ proto.STRING,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
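Both mask fields in the message above are documented as PNG images carried in string fields; in JSON responses they arrive base64-encoded. A minimal decoding sketch, with illustrative file names:

```python
import base64

def save_masks(result, prefix="mask"):
    # Assumes the masks are base64-encoded PNG bytes, per the JSON wire format.
    with open(f"{prefix}_category.png", "wb") as f:
        f.write(base64.b64decode(result.category_mask))
    with open(f"{prefix}_confidence.png", "wb") as f:
        f.write(base64.b64decode(result.confidence_mask))
```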
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_classification.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_classification.py
index 295fd13983..1899462c99 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_classification.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.prediction",
- manifest={"TabularClassificationPredictionResult",},
+ manifest={
+ "TabularClassificationPredictionResult",
+ },
)
class TabularClassificationPredictionResult(proto.Message):
r"""Prediction output format for Tabular Classification.
+
Attributes:
classes (Sequence[str]):
The name of the classes being classified,
@@ -36,8 +39,14 @@ class TabularClassificationPredictionResult(proto.Message):
classes.
"""
- classes = proto.RepeatedField(proto.STRING, number=1,)
- scores = proto.RepeatedField(proto.FLOAT, number=2,)
+ classes = proto.RepeatedField(
+ proto.STRING,
+ number=1,
+ )
+ scores = proto.RepeatedField(
+ proto.FLOAT,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
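`classes` and `scores` above are parallel arrays, so the top prediction is an argmax over `scores`. A small helper sketch:

```python
def top_class(result):
    # Index of the highest score; classes and scores share ordering.
    best = max(range(len(result.scores)), key=lambda i: result.scores[i])
    return result.classes[best], result.scores[best]
```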
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_regression.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_regression.py
index 76be0023f1..41f4f657b4 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_regression.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_regression.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.prediction",
- manifest={"TabularRegressionPredictionResult",},
+ manifest={
+ "TabularRegressionPredictionResult",
+ },
)
class TabularRegressionPredictionResult(proto.Message):
r"""Prediction output format for Tabular Regression.
+
Attributes:
value (float):
The regression value.
@@ -33,9 +36,18 @@ class TabularRegressionPredictionResult(proto.Message):
The upper bound of the prediction interval.
"""
- value = proto.Field(proto.FLOAT, number=1,)
- lower_bound = proto.Field(proto.FLOAT, number=2,)
- upper_bound = proto.Field(proto.FLOAT, number=3,)
+ value = proto.Field(
+ proto.FLOAT,
+ number=1,
+ )
+ lower_bound = proto.Field(
+ proto.FLOAT,
+ number=2,
+ )
+ upper_bound = proto.Field(
+ proto.FLOAT,
+ number=3,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
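A short sketch of how the three fields above combine into a human-readable prediction interval:

```python
def describe(result):
    # lower_bound/upper_bound bracket the point estimate in value.
    return (f"predicted {result.value:.3f} in "
            f"[{result.lower_bound:.3f}, {result.upper_bound:.3f}]")
```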
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_extraction.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_extraction.py
index 601509934a..abff499d75 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_extraction.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_extraction.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.prediction",
- manifest={"TextExtractionPredictionResult",},
+ manifest={
+ "TextExtractionPredictionResult",
+ },
)
class TextExtractionPredictionResult(proto.Message):
r"""Prediction output format for Text Extraction.
+
Attributes:
ids (Sequence[int]):
The resource IDs of the AnnotationSpecs that
@@ -50,11 +53,26 @@ class TextExtractionPredictionResult(proto.Message):
confidence. Order matches the Ids.
"""
- ids = proto.RepeatedField(proto.INT64, number=1,)
- display_names = proto.RepeatedField(proto.STRING, number=2,)
- text_segment_start_offsets = proto.RepeatedField(proto.INT64, number=3,)
- text_segment_end_offsets = proto.RepeatedField(proto.INT64, number=4,)
- confidences = proto.RepeatedField(proto.FLOAT, number=5,)
+ ids = proto.RepeatedField(
+ proto.INT64,
+ number=1,
+ )
+ display_names = proto.RepeatedField(
+ proto.STRING,
+ number=2,
+ )
+ text_segment_start_offsets = proto.RepeatedField(
+ proto.INT64,
+ number=3,
+ )
+ text_segment_end_offsets = proto.RepeatedField(
+ proto.INT64,
+ number=4,
+ )
+ confidences = proto.RepeatedField(
+ proto.FLOAT,
+ number=5,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
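The offset arrays above index into the original input text. This sketch assumes an inclusive start and exclusive end offset; whether offsets count bytes or characters is a service-level detail to confirm in the docs:

```python
def iter_spans(result, text):
    # The four repeated fields are parallel; order matches the ids.
    for name, start, end, conf in zip(
        result.display_names,
        result.text_segment_start_offsets,
        result.text_segment_end_offsets,
        result.confidences,
    ):
        yield name, text[start:end], conf
```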
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_sentiment.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_sentiment.py
index 663a40ce7c..821ebef472 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_sentiment.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_sentiment.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.prediction",
- manifest={"TextSentimentPredictionResult",},
+ manifest={
+ "TextSentimentPredictionResult",
+ },
)
class TextSentimentPredictionResult(proto.Message):
r"""Prediction output format for Text Sentiment
+
Attributes:
sentiment (int):
The integer sentiment labels between 0
@@ -36,7 +39,10 @@ class TextSentimentPredictionResult(proto.Message):
(inclusive) and 10 (inclusive).
"""
- sentiment = proto.Field(proto.INT32, number=1,)
+ sentiment = proto.Field(
+ proto.INT32,
+ number=1,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
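Since `sentiment` above is an integer on a closed scale ending at the model's sentiment maximum, a common post-processing step is normalization. A sketch; the `sentiment_max` parameter comes from the training inputs, not from this message:

```python
def normalized_sentiment(result, sentiment_max):
    # Maps the integer label onto [0.0, 1.0].
    return result.sentiment / sentiment_max if sentiment_max else 0.0
```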
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_action_recognition.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_action_recognition.py
index c23c8b8e07..277d3e5bea 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_action_recognition.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_action_recognition.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,12 +21,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.prediction",
- manifest={"VideoActionRecognitionPredictionResult",},
+ manifest={
+ "VideoActionRecognitionPredictionResult",
+ },
)
class VideoActionRecognitionPredictionResult(proto.Message):
r"""Prediction output format for Video Action Recognition.
+
Attributes:
id (str):
The resource ID of the AnnotationSpec that
@@ -54,15 +57,29 @@ class VideoActionRecognitionPredictionResult(proto.Message):
confidence.
"""
- id = proto.Field(proto.STRING, number=1,)
- display_name = proto.Field(proto.STRING, number=2,)
+ id = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ display_name = proto.Field(
+ proto.STRING,
+ number=2,
+ )
time_segment_start = proto.Field(
- proto.MESSAGE, number=4, message=duration_pb2.Duration,
+ proto.MESSAGE,
+ number=4,
+ message=duration_pb2.Duration,
)
time_segment_end = proto.Field(
- proto.MESSAGE, number=5, message=duration_pb2.Duration,
+ proto.MESSAGE,
+ number=5,
+ message=duration_pb2.Duration,
+ )
+ confidence = proto.Field(
+ proto.MESSAGE,
+ number=6,
+ message=wrappers_pb2.FloatValue,
)
- confidence = proto.Field(proto.MESSAGE, number=6, message=wrappers_pb2.FloatValue,)
__all__ = tuple(sorted(__protobuf__.manifest))
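Note that the `Duration` and `FloatValue` fields above are well-known types: proto-plus typically surfaces them as `datetime.timedelta` and plain `float` (or `None` when unset). A consumption sketch under that marshaling assumption:

```python
def action_window(result):
    # timedelta exposes total_seconds(); confidence is already a float or None.
    start = result.time_segment_start.total_seconds()
    end = result.time_segment_end.total_seconds()
    return result.display_name, start, end, result.confidence
```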
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_classification.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_classification.py
index 5edacfb81c..5686aa3571 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_classification.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,12 +21,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.prediction",
- manifest={"VideoClassificationPredictionResult",},
+ manifest={
+ "VideoClassificationPredictionResult",
+ },
)
class VideoClassificationPredictionResult(proto.Message):
r"""Prediction output format for Video Classification.
+
Attributes:
id (str):
The resource ID of the AnnotationSpec that
@@ -68,16 +71,33 @@ class VideoClassificationPredictionResult(proto.Message):
confidence.
"""
- id = proto.Field(proto.STRING, number=1,)
- display_name = proto.Field(proto.STRING, number=2,)
- type_ = proto.Field(proto.STRING, number=3,)
+ id = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ display_name = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+ type_ = proto.Field(
+ proto.STRING,
+ number=3,
+ )
time_segment_start = proto.Field(
- proto.MESSAGE, number=4, message=duration_pb2.Duration,
+ proto.MESSAGE,
+ number=4,
+ message=duration_pb2.Duration,
)
time_segment_end = proto.Field(
- proto.MESSAGE, number=5, message=duration_pb2.Duration,
+ proto.MESSAGE,
+ number=5,
+ message=duration_pb2.Duration,
+ )
+ confidence = proto.Field(
+ proto.MESSAGE,
+ number=6,
+ message=wrappers_pb2.FloatValue,
)
- confidence = proto.Field(proto.MESSAGE, number=6, message=wrappers_pb2.FloatValue,)
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_object_tracking.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_object_tracking.py
index b103c70546..7651895e91 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_object_tracking.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_object_tracking.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,12 +21,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.predict.prediction",
- manifest={"VideoObjectTrackingPredictionResult",},
+ manifest={
+ "VideoObjectTrackingPredictionResult",
+ },
)
class VideoObjectTrackingPredictionResult(proto.Message):
r"""Prediction output format for Video Object Tracking.
+
Attributes:
id (str):
The resource ID of the AnnotationSpec that
@@ -84,23 +87,59 @@ class Frame(proto.Message):
"""
time_offset = proto.Field(
- proto.MESSAGE, number=1, message=duration_pb2.Duration,
+ proto.MESSAGE,
+ number=1,
+ message=duration_pb2.Duration,
+ )
+ x_min = proto.Field(
+ proto.MESSAGE,
+ number=2,
+ message=wrappers_pb2.FloatValue,
+ )
+ x_max = proto.Field(
+ proto.MESSAGE,
+ number=3,
+ message=wrappers_pb2.FloatValue,
+ )
+ y_min = proto.Field(
+ proto.MESSAGE,
+ number=4,
+ message=wrappers_pb2.FloatValue,
+ )
+ y_max = proto.Field(
+ proto.MESSAGE,
+ number=5,
+ message=wrappers_pb2.FloatValue,
)
- x_min = proto.Field(proto.MESSAGE, number=2, message=wrappers_pb2.FloatValue,)
- x_max = proto.Field(proto.MESSAGE, number=3, message=wrappers_pb2.FloatValue,)
- y_min = proto.Field(proto.MESSAGE, number=4, message=wrappers_pb2.FloatValue,)
- y_max = proto.Field(proto.MESSAGE, number=5, message=wrappers_pb2.FloatValue,)
- id = proto.Field(proto.STRING, number=1,)
- display_name = proto.Field(proto.STRING, number=2,)
+ id = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ display_name = proto.Field(
+ proto.STRING,
+ number=2,
+ )
time_segment_start = proto.Field(
- proto.MESSAGE, number=3, message=duration_pb2.Duration,
+ proto.MESSAGE,
+ number=3,
+ message=duration_pb2.Duration,
)
time_segment_end = proto.Field(
- proto.MESSAGE, number=4, message=duration_pb2.Duration,
+ proto.MESSAGE,
+ number=4,
+ message=duration_pb2.Duration,
+ )
+ confidence = proto.Field(
+ proto.MESSAGE,
+ number=5,
+ message=wrappers_pb2.FloatValue,
+ )
+ frames = proto.RepeatedField(
+ proto.MESSAGE,
+ number=6,
+ message=Frame,
)
- confidence = proto.Field(proto.MESSAGE, number=5, message=wrappers_pb2.FloatValue,)
- frames = proto.RepeatedField(proto.MESSAGE, number=6, message=Frame,)
__all__ = tuple(sorted(__protobuf__.manifest))
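The nested `Frame` message above follows the same well-known-type pattern; assuming proto-plus's usual marshaling (timedelta for `Duration`, float for `FloatValue`), the per-frame boxes read like this:

```python
def iter_frames(result):
    for frame in result.frames:
        yield (
            frame.time_offset.total_seconds(),
            (frame.x_min, frame.x_max, frame.y_min, frame.y_max),
        )
```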
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition/__init__.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition/__init__.py
index bd4624d83b..7235407e1c 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition/__init__.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/__init__.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/__init__.py
index 16b66c2fb6..20be708c73 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/__init__.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py
index d70e297826..f85b4686a6 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -57,7 +57,9 @@
AutoMlVideoObjectTracking,
AutoMlVideoObjectTrackingInputs,
)
-from .export_evaluated_data_items_config import ExportEvaluatedDataItemsConfig
+from .export_evaluated_data_items_config import (
+ ExportEvaluatedDataItemsConfig,
+)
__all__ = (
"AutoMlImageClassification",
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_classification.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_classification.py
index d8732f8865..5660b377f4 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_classification.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -38,15 +38,20 @@ class AutoMlImageClassification(proto.Message):
"""
inputs = proto.Field(
- proto.MESSAGE, number=1, message="AutoMlImageClassificationInputs",
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlImageClassificationInputs",
)
metadata = proto.Field(
- proto.MESSAGE, number=2, message="AutoMlImageClassificationMetadata",
+ proto.MESSAGE,
+ number=2,
+ message="AutoMlImageClassificationMetadata",
)
class AutoMlImageClassificationInputs(proto.Message):
r"""
+
Attributes:
model_type (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlImageClassificationInputs.ModelType):
@@ -97,15 +102,32 @@ class ModelType(proto.Enum):
MOBILE_TF_VERSATILE_1 = 3
MOBILE_TF_HIGH_ACCURACY_1 = 4
- model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,)
- base_model_id = proto.Field(proto.STRING, number=2,)
- budget_milli_node_hours = proto.Field(proto.INT64, number=3,)
- disable_early_stopping = proto.Field(proto.BOOL, number=4,)
- multi_label = proto.Field(proto.BOOL, number=5,)
+ model_type = proto.Field(
+ proto.ENUM,
+ number=1,
+ enum=ModelType,
+ )
+ base_model_id = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+ budget_milli_node_hours = proto.Field(
+ proto.INT64,
+ number=3,
+ )
+ disable_early_stopping = proto.Field(
+ proto.BOOL,
+ number=4,
+ )
+ multi_label = proto.Field(
+ proto.BOOL,
+ number=5,
+ )
class AutoMlImageClassificationMetadata(proto.Message):
r"""
+
Attributes:
cost_milli_node_hours (int):
The actual training cost of creating this
@@ -124,9 +146,14 @@ class SuccessfulStopReason(proto.Enum):
BUDGET_REACHED = 1
MODEL_CONVERGED = 2
- cost_milli_node_hours = proto.Field(proto.INT64, number=1,)
+ cost_milli_node_hours = proto.Field(
+ proto.INT64,
+ number=1,
+ )
successful_stop_reason = proto.Field(
- proto.ENUM, number=2, enum=SuccessfulStopReason,
+ proto.ENUM,
+ number=2,
+ enum=SuccessfulStopReason,
)
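To show the reformatted inputs message in context: a sketch that builds `AutoMlImageClassificationInputs` and converts it to a plain dict (e.g., for a `TrainingPipeline.training_task_inputs` payload). The budget value is illustrative:

```python
from google.cloud.aiplatform.v1.schema.trainingjob import definition_v1

inputs = definition_v1.AutoMlImageClassificationInputs(
    model_type=definition_v1.AutoMlImageClassificationInputs.ModelType.CLOUD,
    budget_milli_node_hours=8000,  # 8 node hours
    multi_label=False,
)
# proto-plus messages convert to plain dicts via the to_dict classmethod.
inputs_dict = definition_v1.AutoMlImageClassificationInputs.to_dict(inputs)
```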
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_object_detection.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_object_detection.py
index c9284686fd..5e7d6efddd 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_object_detection.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_object_detection.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -38,15 +38,20 @@ class AutoMlImageObjectDetection(proto.Message):
"""
inputs = proto.Field(
- proto.MESSAGE, number=1, message="AutoMlImageObjectDetectionInputs",
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlImageObjectDetectionInputs",
)
metadata = proto.Field(
- proto.MESSAGE, number=2, message="AutoMlImageObjectDetectionMetadata",
+ proto.MESSAGE,
+ number=2,
+ message="AutoMlImageObjectDetectionMetadata",
)
class AutoMlImageObjectDetectionInputs(proto.Message):
r"""
+
Attributes:
model_type (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlImageObjectDetectionInputs.ModelType):
@@ -86,13 +91,24 @@ class ModelType(proto.Enum):
MOBILE_TF_VERSATILE_1 = 4
MOBILE_TF_HIGH_ACCURACY_1 = 5
- model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,)
- budget_milli_node_hours = proto.Field(proto.INT64, number=2,)
- disable_early_stopping = proto.Field(proto.BOOL, number=3,)
+ model_type = proto.Field(
+ proto.ENUM,
+ number=1,
+ enum=ModelType,
+ )
+ budget_milli_node_hours = proto.Field(
+ proto.INT64,
+ number=2,
+ )
+ disable_early_stopping = proto.Field(
+ proto.BOOL,
+ number=3,
+ )
class AutoMlImageObjectDetectionMetadata(proto.Message):
r"""
+
Attributes:
cost_milli_node_hours (int):
The actual training cost of creating this
@@ -111,9 +127,14 @@ class SuccessfulStopReason(proto.Enum):
BUDGET_REACHED = 1
MODEL_CONVERGED = 2
- cost_milli_node_hours = proto.Field(proto.INT64, number=1,)
+ cost_milli_node_hours = proto.Field(
+ proto.INT64,
+ number=1,
+ )
successful_stop_reason = proto.Field(
- proto.ENUM, number=2, enum=SuccessfulStopReason,
+ proto.ENUM,
+ number=2,
+ enum=SuccessfulStopReason,
)
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_segmentation.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_segmentation.py
index ccd2449ccd..c9cd437303 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_segmentation.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_segmentation.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -38,15 +38,20 @@ class AutoMlImageSegmentation(proto.Message):
"""
inputs = proto.Field(
- proto.MESSAGE, number=1, message="AutoMlImageSegmentationInputs",
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlImageSegmentationInputs",
)
metadata = proto.Field(
- proto.MESSAGE, number=2, message="AutoMlImageSegmentationMetadata",
+ proto.MESSAGE,
+ number=2,
+ message="AutoMlImageSegmentationMetadata",
)
class AutoMlImageSegmentationInputs(proto.Message):
r"""
+
Attributes:
model_type (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlImageSegmentationInputs.ModelType):
@@ -80,13 +85,24 @@ class ModelType(proto.Enum):
CLOUD_LOW_ACCURACY_1 = 2
MOBILE_TF_LOW_LATENCY_1 = 3
- model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,)
- budget_milli_node_hours = proto.Field(proto.INT64, number=2,)
- base_model_id = proto.Field(proto.STRING, number=3,)
+ model_type = proto.Field(
+ proto.ENUM,
+ number=1,
+ enum=ModelType,
+ )
+ budget_milli_node_hours = proto.Field(
+ proto.INT64,
+ number=2,
+ )
+ base_model_id = proto.Field(
+ proto.STRING,
+ number=3,
+ )
class AutoMlImageSegmentationMetadata(proto.Message):
r"""
+
Attributes:
cost_milli_node_hours (int):
The actual training cost of creating this
@@ -105,9 +121,14 @@ class SuccessfulStopReason(proto.Enum):
BUDGET_REACHED = 1
MODEL_CONVERGED = 2
- cost_milli_node_hours = proto.Field(proto.INT64, number=1,)
+ cost_milli_node_hours = proto.Field(
+ proto.INT64,
+ number=1,
+ )
successful_stop_reason = proto.Field(
- proto.ENUM, number=2, enum=SuccessfulStopReason,
+ proto.ENUM,
+ number=2,
+ enum=SuccessfulStopReason,
)
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py
index f05b633c87..bbcfabac75 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,12 +22,17 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.trainingjob.definition",
- manifest={"AutoMlTables", "AutoMlTablesInputs", "AutoMlTablesMetadata",},
+ manifest={
+ "AutoMlTables",
+ "AutoMlTablesInputs",
+ "AutoMlTablesMetadata",
+ },
)
class AutoMlTables(proto.Message):
r"""A TrainingJob that trains and uploads an AutoML Tables Model.
+
Attributes:
inputs (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs):
The input parameters of this TrainingJob.
@@ -35,21 +40,41 @@ class AutoMlTables(proto.Message):
The metadata information.
"""
- inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTablesInputs",)
- metadata = proto.Field(proto.MESSAGE, number=2, message="AutoMlTablesMetadata",)
+ inputs = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlTablesInputs",
+ )
+ metadata = proto.Field(
+ proto.MESSAGE,
+ number=2,
+ message="AutoMlTablesMetadata",
+ )
class AutoMlTablesInputs(proto.Message):
r"""
+
+ This message has `oneof`_ fields (mutually exclusive fields).
+ For each oneof, at most one member field can be set at the same time.
+ Setting any member of the oneof automatically clears all other
+ members.
+
+ .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
Attributes:
optimization_objective_recall_value (float):
Required when optimization_objective is
"maximize-precision-at-recall". Must be between 0 and 1,
inclusive.
+
+ This field is a member of `oneof`_ ``additional_optimization_objective_config``.
optimization_objective_precision_value (float):
Required when optimization_objective is
"maximize-recall-at-precision". Must be between 0 and 1,
inclusive.
+
+ This field is a member of `oneof`_ ``additional_optimization_objective_config``.
prediction_type (str):
The type of prediction the Model is to
produce. "classification" - Predict one out of
@@ -86,9 +111,9 @@ class AutoMlTablesInputs(proto.Message):
operating characteristic (ROC) curve.
"minimize-log-loss" - Minimize log loss.
"maximize-au-prc" - Maximize the area under
- the precision-recall curve. "maximize-
- precision-at-recall" - Maximize precision for a
- specified
+ the precision-recall curve.
+ "maximize-precision-at-recall" - Maximize
+ precision for a specified
recall value. "maximize-recall-at-precision" -
Maximize recall for a specified
precision value.
@@ -96,11 +121,11 @@ class AutoMlTablesInputs(proto.Message):
"minimize-log-loss" (default) - Minimize log
loss.
regression:
- "minimize-rmse" (default) - Minimize root-
- mean-squared error (RMSE). "minimize-mae" -
- Minimize mean-absolute error (MAE). "minimize-
- rmsle" - Minimize root-mean-squared log error
- (RMSLE).
+ "minimize-rmse" (default) - Minimize
+ root-mean-squared error (RMSE). "minimize-mae"
+ - Minimize mean-absolute error (MAE).
+ "minimize-rmsle" - Minimize root-mean-squared
+ log error (RMSLE).
train_budget_milli_node_hours (int):
Required. The train budget of creating this
model, expressed in milli node hours i.e. 1,000
@@ -139,27 +164,46 @@ class AutoMlTablesInputs(proto.Message):
predictions to a BigQuery table. If this
configuration is absent, then the export is not
performed.
+ additional_experiments (Sequence[str]):
+ Additional experiment flags for the Tables
+ training pipeline.
"""
class Transformation(proto.Message):
r"""
+
+ This message has `oneof`_ fields (mutually exclusive fields).
+ For each oneof, at most one member field can be set at the same time.
+ Setting any member of the oneof automatically clears all other
+ members.
+
+ .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
Attributes:
auto (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs.Transformation.AutoTransformation):
+ This field is a member of `oneof`_ ``transformation_detail``.
numeric (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs.Transformation.NumericTransformation):
+ This field is a member of `oneof`_ ``transformation_detail``.
categorical (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs.Transformation.CategoricalTransformation):
+ This field is a member of `oneof`_ ``transformation_detail``.
timestamp (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs.Transformation.TimestampTransformation):
+ This field is a member of `oneof`_ ``transformation_detail``.
text (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs.Transformation.TextTransformation):
+ This field is a member of `oneof`_ ``transformation_detail``.
repeated_numeric (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs.Transformation.NumericArrayTransformation):
+ This field is a member of `oneof`_ ``transformation_detail``.
repeated_categorical (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs.Transformation.CategoricalArrayTransformation):
+ This field is a member of `oneof`_ ``transformation_detail``.
repeated_text (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs.Transformation.TextArrayTransformation):
+ This field is a member of `oneof`_ ``transformation_detail``.
"""
class AutoTransformation(proto.Message):
@@ -171,7 +215,10 @@ class AutoTransformation(proto.Message):
"""
- column_name = proto.Field(proto.STRING, number=1,)
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
class NumericTransformation(proto.Message):
r"""Training pipeline will perform following transformation functions.
@@ -197,8 +244,14 @@ class NumericTransformation(proto.Message):
from training data.
"""
- column_name = proto.Field(proto.STRING, number=1,)
- invalid_values_allowed = proto.Field(proto.BOOL, number=2,)
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ invalid_values_allowed = proto.Field(
+ proto.BOOL,
+ number=2,
+ )
class CategoricalTransformation(proto.Message):
r"""Training pipeline will perform following transformation functions.
@@ -216,7 +269,10 @@ class CategoricalTransformation(proto.Message):
"""
- column_name = proto.Field(proto.STRING, number=1,)
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
class TimestampTransformation(proto.Message):
r"""Training pipeline will perform following transformation functions.
@@ -253,9 +309,18 @@ class TimestampTransformation(proto.Message):
from training data.
"""
- column_name = proto.Field(proto.STRING, number=1,)
- time_format = proto.Field(proto.STRING, number=2,)
- invalid_values_allowed = proto.Field(proto.BOOL, number=3,)
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ time_format = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+ invalid_values_allowed = proto.Field(
+ proto.BOOL,
+ number=3,
+ )
class TextTransformation(proto.Message):
r"""Training pipeline will perform following transformation functions.
@@ -275,7 +340,10 @@ class TextTransformation(proto.Message):
"""
- column_name = proto.Field(proto.STRING, number=1,)
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
class NumericArrayTransformation(proto.Message):
r"""Treats the column as numerical array and performs following
@@ -296,8 +364,14 @@ class NumericArrayTransformation(proto.Message):
from training data.
"""
- column_name = proto.Field(proto.STRING, number=1,)
- invalid_values_allowed = proto.Field(proto.BOOL, number=2,)
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ invalid_values_allowed = proto.Field(
+ proto.BOOL,
+ number=2,
+ )
class CategoricalArrayTransformation(proto.Message):
r"""Treats the column as categorical array and performs following
@@ -314,7 +388,10 @@ class CategoricalArrayTransformation(proto.Message):
"""
- column_name = proto.Field(proto.STRING, number=1,)
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
class TextArrayTransformation(proto.Message):
r"""Treats the column as text array and performs following
@@ -330,7 +407,10 @@ class TextArrayTransformation(proto.Message):
"""
- column_name = proto.Field(proto.STRING, number=1,)
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
auto = proto.Field(
proto.MESSAGE,
@@ -382,29 +462,58 @@ class TextArrayTransformation(proto.Message):
)
optimization_objective_recall_value = proto.Field(
- proto.FLOAT, number=5, oneof="additional_optimization_objective_config",
+ proto.FLOAT,
+ number=5,
+ oneof="additional_optimization_objective_config",
)
optimization_objective_precision_value = proto.Field(
- proto.FLOAT, number=6, oneof="additional_optimization_objective_config",
+ proto.FLOAT,
+ number=6,
+ oneof="additional_optimization_objective_config",
+ )
+ prediction_type = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ target_column = proto.Field(
+ proto.STRING,
+ number=2,
)
- prediction_type = proto.Field(proto.STRING, number=1,)
- target_column = proto.Field(proto.STRING, number=2,)
transformations = proto.RepeatedField(
- proto.MESSAGE, number=3, message=Transformation,
+ proto.MESSAGE,
+ number=3,
+ message=Transformation,
+ )
+ optimization_objective = proto.Field(
+ proto.STRING,
+ number=4,
+ )
+ train_budget_milli_node_hours = proto.Field(
+ proto.INT64,
+ number=7,
+ )
+ disable_early_stopping = proto.Field(
+ proto.BOOL,
+ number=8,
+ )
+ weight_column_name = proto.Field(
+ proto.STRING,
+ number=9,
)
- optimization_objective = proto.Field(proto.STRING, number=4,)
- train_budget_milli_node_hours = proto.Field(proto.INT64, number=7,)
- disable_early_stopping = proto.Field(proto.BOOL, number=8,)
- weight_column_name = proto.Field(proto.STRING, number=9,)
export_evaluated_data_items_config = proto.Field(
proto.MESSAGE,
number=10,
message=gcastd_export_evaluated_data_items_config.ExportEvaluatedDataItemsConfig,
)
+ additional_experiments = proto.RepeatedField(
+ proto.STRING,
+ number=11,
+ )
class AutoMlTablesMetadata(proto.Message):
r"""Model metadata specific to AutoML Tables.
+
Attributes:
train_cost_milli_node_hours (int):
Output only. The actual training cost of the
@@ -413,7 +522,10 @@ class AutoMlTablesMetadata(proto.Message):
Guaranteed to not exceed the train budget.
"""
- train_cost_milli_node_hours = proto.Field(proto.INT64, number=1,)
+ train_cost_milli_node_hours = proto.Field(
+ proto.INT64,
+ number=1,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
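The new `oneof` documentation above is easy to verify in code: the two optimization-objective values share `additional_optimization_objective_config`, so setting one clears the other. A sketch with invented column names:

```python
from google.cloud.aiplatform.v1.schema.trainingjob import definition_v1

inputs = definition_v1.AutoMlTablesInputs(
    prediction_type="classification",
    target_column="label",  # hypothetical column
    optimization_objective_recall_value=0.8,
)
inputs.optimization_objective_precision_value = 0.9  # clears the recall value

# WhichOneof on the raw protobuf reports which oneof member is active.
pb = definition_v1.AutoMlTablesInputs.pb(inputs)
assert (
    pb.WhichOneof("additional_optimization_objective_config")
    == "optimization_objective_precision_value"
)
```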
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_classification.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_classification.py
index 21014e1b0a..805d83c1e2 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_classification.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,10 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.trainingjob.definition",
- manifest={"AutoMlTextClassification", "AutoMlTextClassificationInputs",},
+ manifest={
+ "AutoMlTextClassification",
+ "AutoMlTextClassificationInputs",
+ },
)
@@ -32,18 +35,24 @@ class AutoMlTextClassification(proto.Message):
"""
inputs = proto.Field(
- proto.MESSAGE, number=1, message="AutoMlTextClassificationInputs",
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlTextClassificationInputs",
)
class AutoMlTextClassificationInputs(proto.Message):
r"""
+
Attributes:
multi_label (bool):
"""
- multi_label = proto.Field(proto.BOOL, number=1,)
+ multi_label = proto.Field(
+ proto.BOOL,
+ number=1,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_extraction.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_extraction.py
index e475b1989b..d08f678fca 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_extraction.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_extraction.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,10 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.trainingjob.definition",
- manifest={"AutoMlTextExtraction", "AutoMlTextExtractionInputs",},
+ manifest={
+ "AutoMlTextExtraction",
+ "AutoMlTextExtractionInputs",
+ },
)
@@ -31,11 +34,15 @@ class AutoMlTextExtraction(proto.Message):
The input parameters of this TrainingJob.
"""
- inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTextExtractionInputs",)
+ inputs = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlTextExtractionInputs",
+ )
class AutoMlTextExtractionInputs(proto.Message):
- r""" """
+ r""" """
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_sentiment.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_sentiment.py
index 373ea85902..bf73cf5144 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_sentiment.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_sentiment.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,10 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.trainingjob.definition",
- manifest={"AutoMlTextSentiment", "AutoMlTextSentimentInputs",},
+ manifest={
+ "AutoMlTextSentiment",
+ "AutoMlTextSentimentInputs",
+ },
)
@@ -31,11 +34,16 @@ class AutoMlTextSentiment(proto.Message):
The input parameters of this TrainingJob.
"""
- inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTextSentimentInputs",)
+ inputs = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlTextSentimentInputs",
+ )
class AutoMlTextSentimentInputs(proto.Message):
r"""
+
Attributes:
sentiment_max (int):
A sentiment is expressed as an integer
@@ -50,7 +58,10 @@ class AutoMlTextSentimentInputs(proto.Message):
between 1 and 10 (inclusive).
"""
- sentiment_max = proto.Field(proto.INT32, number=1,)
+ sentiment_max = proto.Field(
+ proto.INT32,
+ number=1,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_action_recognition.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_action_recognition.py
index f9eefb8c4d..8b97dfc426 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_action_recognition.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_action_recognition.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,10 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.trainingjob.definition",
- manifest={"AutoMlVideoActionRecognition", "AutoMlVideoActionRecognitionInputs",},
+ manifest={
+ "AutoMlVideoActionRecognition",
+ "AutoMlVideoActionRecognitionInputs",
+ },
)
@@ -32,12 +35,15 @@ class AutoMlVideoActionRecognition(proto.Message):
"""
inputs = proto.Field(
- proto.MESSAGE, number=1, message="AutoMlVideoActionRecognitionInputs",
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlVideoActionRecognitionInputs",
)
class AutoMlVideoActionRecognitionInputs(proto.Message):
r"""
+
Attributes:
model_type (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlVideoActionRecognitionInputs.ModelType):
@@ -48,8 +54,14 @@ class ModelType(proto.Enum):
MODEL_TYPE_UNSPECIFIED = 0
CLOUD = 1
MOBILE_VERSATILE_1 = 2
+ MOBILE_JETSON_VERSATILE_1 = 3
+ MOBILE_CORAL_VERSATILE_1 = 4
- model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,)
+ model_type = proto.Field(
+ proto.ENUM,
+ number=1,
+ enum=ModelType,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_classification.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_classification.py
index a0a4e88195..49b139a3ba 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_classification.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,10 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.trainingjob.definition",
- manifest={"AutoMlVideoClassification", "AutoMlVideoClassificationInputs",},
+ manifest={
+ "AutoMlVideoClassification",
+ "AutoMlVideoClassificationInputs",
+ },
)
@@ -32,12 +35,15 @@ class AutoMlVideoClassification(proto.Message):
"""
inputs = proto.Field(
- proto.MESSAGE, number=1, message="AutoMlVideoClassificationInputs",
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlVideoClassificationInputs",
)
class AutoMlVideoClassificationInputs(proto.Message):
r"""
+
Attributes:
model_type (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlVideoClassificationInputs.ModelType):
@@ -50,7 +56,11 @@ class ModelType(proto.Enum):
MOBILE_VERSATILE_1 = 2
MOBILE_JETSON_VERSATILE_1 = 3
- model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,)
+ model_type = proto.Field(
+ proto.ENUM,
+ number=1,
+ enum=ModelType,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_object_tracking.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_object_tracking.py
index 4db3a783cf..c29fcb0161 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_object_tracking.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_object_tracking.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,10 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.trainingjob.definition",
- manifest={"AutoMlVideoObjectTracking", "AutoMlVideoObjectTrackingInputs",},
+ manifest={
+ "AutoMlVideoObjectTracking",
+ "AutoMlVideoObjectTrackingInputs",
+ },
)
@@ -32,12 +35,15 @@ class AutoMlVideoObjectTracking(proto.Message):
"""
inputs = proto.Field(
- proto.MESSAGE, number=1, message="AutoMlVideoObjectTrackingInputs",
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlVideoObjectTrackingInputs",
)
class AutoMlVideoObjectTrackingInputs(proto.Message):
r"""
+
Attributes:
model_type (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlVideoObjectTrackingInputs.ModelType):
@@ -53,7 +59,11 @@ class ModelType(proto.Enum):
MOBILE_JETSON_VERSATILE_1 = 5
MOBILE_JETSON_LOW_LATENCY_1 = 6
- model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,)
+ model_type = proto.Field(
+ proto.ENUM,
+ number=1,
+ enum=ModelType,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/export_evaluated_data_items_config.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/export_evaluated_data_items_config.py
index 47d910fefb..d93085ab54 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/export_evaluated_data_items_config.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/export_evaluated_data_items_config.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,9 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1.schema.trainingjob.definition",
- manifest={"ExportEvaluatedDataItemsConfig",},
+ manifest={
+ "ExportEvaluatedDataItemsConfig",
+ },
)
@@ -33,7 +35,6 @@ class ExportEvaluatedDataItemsConfig(proto.Message):
If not specified, then results are exported to the following
auto-created BigQuery table:
-
:export_evaluated_examples__.evaluated_examples
override_existing_table (bool):
If true and an export destination is
@@ -43,8 +44,14 @@ class ExportEvaluatedDataItemsConfig(proto.Message):
operation fails.
"""
- destination_bigquery_uri = proto.Field(proto.STRING, number=1,)
- override_existing_table = proto.Field(proto.BOOL, number=2,)
+ destination_bigquery_uri = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ override_existing_table = proto.Field(
+ proto.BOOL,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
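For context on the config above: it hangs off `AutoMlTablesInputs.export_evaluated_data_items_config`. A sketch with a made-up BigQuery URI:

```python
from google.cloud.aiplatform.v1.schema.trainingjob import definition_v1

export_cfg = definition_v1.ExportEvaluatedDataItemsConfig(
    destination_bigquery_uri="bq://my-project.my_dataset.eval_items",
    override_existing_table=True,  # replace the table instead of failing
)
inputs = definition_v1.AutoMlTablesInputs(
    prediction_type="regression",
    target_column="price",  # hypothetical column
    export_evaluated_data_items_config=export_cfg,
)
```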
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance/__init__.py
index 4ddd6e1439..ae404a7377 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance/__init__.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/__init__.py
index fdfe1ca46f..47708ddc7f 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/__init__.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/__init__.py
index 744852e8a3..c36f147d50 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/__init__.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,15 +13,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from .image_classification import ImageClassificationPredictionInstance
-from .image_object_detection import ImageObjectDetectionPredictionInstance
-from .image_segmentation import ImageSegmentationPredictionInstance
-from .text_classification import TextClassificationPredictionInstance
-from .text_extraction import TextExtractionPredictionInstance
-from .text_sentiment import TextSentimentPredictionInstance
-from .video_action_recognition import VideoActionRecognitionPredictionInstance
-from .video_classification import VideoClassificationPredictionInstance
-from .video_object_tracking import VideoObjectTrackingPredictionInstance
+from .image_classification import (
+ ImageClassificationPredictionInstance,
+)
+from .image_object_detection import (
+ ImageObjectDetectionPredictionInstance,
+)
+from .image_segmentation import (
+ ImageSegmentationPredictionInstance,
+)
+from .text_classification import (
+ TextClassificationPredictionInstance,
+)
+from .text_extraction import (
+ TextExtractionPredictionInstance,
+)
+from .text_sentiment import (
+ TextSentimentPredictionInstance,
+)
+from .video_action_recognition import (
+ VideoActionRecognitionPredictionInstance,
+)
+from .video_classification import (
+ VideoClassificationPredictionInstance,
+)
+from .video_object_tracking import (
+ VideoObjectTrackingPredictionInstance,
+)
__all__ = (
"ImageClassificationPredictionInstance",
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_classification.py
index 4c2154dd90..008a5ee3d9 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_classification.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,16 +18,19 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.instance",
- manifest={"ImageClassificationPredictionInstance",},
+ manifest={
+ "ImageClassificationPredictionInstance",
+ },
)
class ImageClassificationPredictionInstance(proto.Message):
r"""Prediction input format for Image Classification.
+
Attributes:
content (str):
- The image bytes or GCS URI to make the
- prediction on.
+ The image bytes or Cloud Storage URI to make
+ the prediction on.
mime_type (str):
The MIME type of the content of the image.
Only the images in below listed MIME types are
@@ -40,8 +43,14 @@ class ImageClassificationPredictionInstance(proto.Message):
- image/vnd.microsoft.icon
"""
- content = proto.Field(proto.STRING, number=1,)
- mime_type = proto.Field(proto.STRING, number=2,)
+ content = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ mime_type = proto.Field(
+ proto.STRING,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
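Wording fix above (GCS URI → Cloud Storage URI) aside, the `content`/`mime_type` pattern is shared by most instance messages in this package. A sketch that inlines local image bytes; the file name is illustrative, and base64 encoding matches the JSON wire format:

```python
import base64

from google.cloud.aiplatform.v1beta1.schema.predict import instance_v1beta1

with open("flower.jpg", "rb") as f:
    instance = instance_v1beta1.ImageClassificationPredictionInstance(
        content=base64.b64encode(f.read()).decode("utf-8"),
        mime_type="image/jpeg",
    )
```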
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_object_detection.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_object_detection.py
index d7b41623aa..7ab83116b4 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_object_detection.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_object_detection.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,16 +18,19 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.instance",
- manifest={"ImageObjectDetectionPredictionInstance",},
+ manifest={
+ "ImageObjectDetectionPredictionInstance",
+ },
)
class ImageObjectDetectionPredictionInstance(proto.Message):
r"""Prediction input format for Image Object Detection.
+
Attributes:
content (str):
- The image bytes or GCS URI to make the
- prediction on.
+ The image bytes or Cloud Storage URI to make
+ the prediction on.
mime_type (str):
The MIME type of the content of the image.
Only the images in below listed MIME types are
@@ -40,8 +43,14 @@ class ImageObjectDetectionPredictionInstance(proto.Message):
- image/vnd.microsoft.icon
"""
- content = proto.Field(proto.STRING, number=1,)
- mime_type = proto.Field(proto.STRING, number=2,)
+ content = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ mime_type = proto.Field(
+ proto.STRING,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_segmentation.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_segmentation.py
index 13c96535a1..600bfdc427 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_segmentation.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/image_segmentation.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.instance",
- manifest={"ImageSegmentationPredictionInstance",},
+ manifest={
+ "ImageSegmentationPredictionInstance",
+ },
)
class ImageSegmentationPredictionInstance(proto.Message):
r"""Prediction input format for Image Segmentation.
+
Attributes:
content (str):
The image bytes to make the predictions on.
@@ -34,8 +37,14 @@ class ImageSegmentationPredictionInstance(proto.Message):
- image/png
"""
- content = proto.Field(proto.STRING, number=1,)
- mime_type = proto.Field(proto.STRING, number=2,)
+ content = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ mime_type = proto.Field(
+ proto.STRING,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_classification.py
index 141b031701..e5e505b153 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_classification.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.instance",
- manifest={"TextClassificationPredictionInstance",},
+ manifest={
+ "TextClassificationPredictionInstance",
+ },
)
class TextClassificationPredictionInstance(proto.Message):
r"""Prediction input format for Text Classification.
+
Attributes:
content (str):
The text snippet to make the predictions on.
@@ -33,8 +36,14 @@ class TextClassificationPredictionInstance(proto.Message):
- text/plain
"""
- content = proto.Field(proto.STRING, number=1,)
- mime_type = proto.Field(proto.STRING, number=2,)
+ content = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ mime_type = proto.Field(
+ proto.STRING,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_extraction.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_extraction.py
index 9c393faa73..7ced4e00e0 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_extraction.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_extraction.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.instance",
- manifest={"TextExtractionPredictionInstance",},
+ manifest={
+ "TextExtractionPredictionInstance",
+ },
)
class TextExtractionPredictionInstance(proto.Message):
r"""Prediction input format for Text Extraction.
+
Attributes:
content (str):
The text snippet to make the predictions on.
@@ -36,15 +39,24 @@ class TextExtractionPredictionInstance(proto.Message):
If a key is provided, the batch prediction
result will be mapped to this key. If omitted,
then the batch prediction result will contain
- the entire input instance. AI Platform will not
+ the entire input instance. Vertex AI will not
check if keys in the request are duplicates, so
it is up to the caller to ensure the keys are
unique.
"""
- content = proto.Field(proto.STRING, number=1,)
- mime_type = proto.Field(proto.STRING, number=2,)
- key = proto.Field(proto.STRING, number=3,)
+ content = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ mime_type = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+ key = proto.Field(
+ proto.STRING,
+ number=3,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
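A short sketch (outside this diff) of the `key` field documented above: when a batch prediction input carries a key, the result is mapped back to it, and, as the docstring notes, Vertex AI does not de-duplicate keys, so the caller keeps them unique. The key scheme below is an assumption for illustration.

```python
from google.cloud.aiplatform.v1beta1.schema.predict import instance_v1beta1

instances = [
    instance_v1beta1.TextExtractionPredictionInstance(
        content=text,
        mime_type="text/plain",
        key=f"doc-{i}",  # caller-chosen unique key (illustrative)
    )
    for i, text in enumerate(["first snippet", "second snippet"])
]
```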
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_sentiment.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_sentiment.py
index cc530e26b9..694e0a5dec 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_sentiment.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/text_sentiment.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.instance",
- manifest={"TextSentimentPredictionInstance",},
+ manifest={
+ "TextSentimentPredictionInstance",
+ },
)
class TextSentimentPredictionInstance(proto.Message):
r"""Prediction input format for Text Sentiment.
+
Attributes:
content (str):
The text snippet to make the predictions on.
@@ -33,8 +36,14 @@ class TextSentimentPredictionInstance(proto.Message):
- text/plain
"""
- content = proto.Field(proto.STRING, number=1,)
- mime_type = proto.Field(proto.STRING, number=2,)
+ content = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ mime_type = proto.Field(
+ proto.STRING,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_action_recognition.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_action_recognition.py
index 921f17b892..f85e5cf17b 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_action_recognition.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_action_recognition.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.instance",
- manifest={"VideoActionRecognitionPredictionInstance",},
+ manifest={
+ "VideoActionRecognitionPredictionInstance",
+ },
)
class VideoActionRecognitionPredictionInstance(proto.Message):
r"""Prediction input format for Video Action Recognition.
+
Attributes:
content (str):
The Google Cloud Storage location of the
@@ -49,10 +52,22 @@ class VideoActionRecognitionPredictionInstance(proto.Message):
is allowed, which means the end of the video.
"""
- content = proto.Field(proto.STRING, number=1,)
- mime_type = proto.Field(proto.STRING, number=2,)
- time_segment_start = proto.Field(proto.STRING, number=3,)
- time_segment_end = proto.Field(proto.STRING, number=4,)
+ content = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ mime_type = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+ time_segment_start = proto.Field(
+ proto.STRING,
+ number=3,
+ )
+ time_segment_end = proto.Field(
+ proto.STRING,
+ number=4,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
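A hedged sketch of the video instance reformatted above. The time_segment_* fields are plain strings in seconds with an "s" suffix; "Infinity" for the end marker follows the "end of the video" wording in the docstring, and the gs:// URI is a placeholder.

```python
from google.cloud.aiplatform.v1beta1.schema.predict import instance_v1beta1

instance = instance_v1beta1.VideoActionRecognitionPredictionInstance(
    content="gs://my-bucket/clip.mp4",  # hypothetical Cloud Storage URI
    mime_type="video/mp4",
    time_segment_start="0.0s",
    time_segment_end="Infinity",  # i.e. run to the end of the video
)
```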
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_classification.py
index f7c58db248..3d64c39aac 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_classification.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.instance",
- manifest={"VideoClassificationPredictionInstance",},
+ manifest={
+ "VideoClassificationPredictionInstance",
+ },
)
class VideoClassificationPredictionInstance(proto.Message):
r"""Prediction input format for Video Classification.
+
Attributes:
content (str):
The Google Cloud Storage location of the
@@ -49,10 +52,22 @@ class VideoClassificationPredictionInstance(proto.Message):
is allowed, which means the end of the video.
"""
- content = proto.Field(proto.STRING, number=1,)
- mime_type = proto.Field(proto.STRING, number=2,)
- time_segment_start = proto.Field(proto.STRING, number=3,)
- time_segment_end = proto.Field(proto.STRING, number=4,)
+ content = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ mime_type = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+ time_segment_start = proto.Field(
+ proto.STRING,
+ number=3,
+ )
+ time_segment_end = proto.Field(
+ proto.STRING,
+ number=4,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_object_tracking.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_object_tracking.py
index 8fd28ed924..4476b8af60 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_object_tracking.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_object_tracking.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.instance",
- manifest={"VideoObjectTrackingPredictionInstance",},
+ manifest={
+ "VideoObjectTrackingPredictionInstance",
+ },
)
class VideoObjectTrackingPredictionInstance(proto.Message):
r"""Prediction input format for Video Object Tracking.
+
Attributes:
content (str):
The Google Cloud Storage location of the
@@ -49,10 +52,22 @@ class VideoObjectTrackingPredictionInstance(proto.Message):
is allowed, which means the end of the video.
"""
- content = proto.Field(proto.STRING, number=1,)
- mime_type = proto.Field(proto.STRING, number=2,)
- time_segment_start = proto.Field(proto.STRING, number=3,)
- time_segment_end = proto.Field(proto.STRING, number=4,)
+ content = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ mime_type = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+ time_segment_start = proto.Field(
+ proto.STRING,
+ number=3,
+ )
+ time_segment_end = proto.Field(
+ proto.STRING,
+ number=4,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/params/__init__.py
index 4a5b144b93..1f64628545 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/params/__init__.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/params/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/__init__.py
index dcf74bb7a0..fd80646afd 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/__init__.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/__init__.py
index 26997a8d81..135f3bff54 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/__init__.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,12 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from .image_classification import ImageClassificationPredictionParams
-from .image_object_detection import ImageObjectDetectionPredictionParams
-from .image_segmentation import ImageSegmentationPredictionParams
-from .video_action_recognition import VideoActionRecognitionPredictionParams
-from .video_classification import VideoClassificationPredictionParams
-from .video_object_tracking import VideoObjectTrackingPredictionParams
+from .image_classification import (
+ ImageClassificationPredictionParams,
+)
+from .image_object_detection import (
+ ImageObjectDetectionPredictionParams,
+)
+from .image_segmentation import (
+ ImageSegmentationPredictionParams,
+)
+from .video_action_recognition import (
+ VideoActionRecognitionPredictionParams,
+)
+from .video_classification import (
+ VideoClassificationPredictionParams,
+)
+from .video_object_tracking import (
+ VideoObjectTrackingPredictionParams,
+)
__all__ = (
"ImageClassificationPredictionParams",
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_classification.py
index ada760e415..9282610720 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_classification.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.params",
- manifest={"ImageClassificationPredictionParams",},
+ manifest={
+ "ImageClassificationPredictionParams",
+ },
)
class ImageClassificationPredictionParams(proto.Message):
r"""Prediction model parameters for Image Classification.
+
Attributes:
confidence_threshold (float):
The Model only returns predictions with at
@@ -36,8 +39,14 @@ class ImageClassificationPredictionParams(proto.Message):
return fewer predictions. Default value is 10.
"""
- confidence_threshold = proto.Field(proto.FLOAT, number=1,)
- max_predictions = proto.Field(proto.INT32, number=2,)
+ confidence_threshold = proto.Field(
+ proto.FLOAT,
+ number=1,
+ )
+ max_predictions = proto.Field(
+ proto.INT32,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
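A minimal sketch of the params message above: only predictions at or above confidence_threshold are returned, up to max_predictions (the docstring's default for the latter is 10; the values here are illustrative).

```python
from google.cloud.aiplatform.v1beta1.schema.predict import params_v1beta1

params = params_v1beta1.ImageClassificationPredictionParams(
    confidence_threshold=0.5,
    max_predictions=5,
)
```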
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_object_detection.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_object_detection.py
index b160fc8400..381517a2a1 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_object_detection.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_object_detection.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.params",
- manifest={"ImageObjectDetectionPredictionParams",},
+ manifest={
+ "ImageObjectDetectionPredictionParams",
+ },
)
class ImageObjectDetectionPredictionParams(proto.Message):
r"""Prediction model parameters for Image Object Detection.
+
Attributes:
confidence_threshold (float):
The Model only returns predictions with at
@@ -37,8 +40,14 @@ class ImageObjectDetectionPredictionParams(proto.Message):
value is 10.
"""
- confidence_threshold = proto.Field(proto.FLOAT, number=1,)
- max_predictions = proto.Field(proto.INT32, number=2,)
+ confidence_threshold = proto.Field(
+ proto.FLOAT,
+ number=1,
+ )
+ max_predictions = proto.Field(
+ proto.INT32,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_segmentation.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_segmentation.py
index 1c1e3cdb2e..92e69a2441 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_segmentation.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/image_segmentation.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.params",
- manifest={"ImageSegmentationPredictionParams",},
+ manifest={
+ "ImageSegmentationPredictionParams",
+ },
)
class ImageSegmentationPredictionParams(proto.Message):
r"""Prediction model parameters for Image Segmentation.
+
Attributes:
confidence_threshold (float):
When the model predicts category of pixels of
@@ -33,7 +36,10 @@ class ImageSegmentationPredictionParams(proto.Message):
background. Default value is 0.5.
"""
- confidence_threshold = proto.Field(proto.FLOAT, number=1,)
+ confidence_threshold = proto.Field(
+ proto.FLOAT,
+ number=1,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_action_recognition.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_action_recognition.py
index 86afdac15f..3f6e7deca7 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_action_recognition.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_action_recognition.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.params",
- manifest={"VideoActionRecognitionPredictionParams",},
+ manifest={
+ "VideoActionRecognitionPredictionParams",
+ },
)
class VideoActionRecognitionPredictionParams(proto.Message):
r"""Prediction model parameters for Video Action Recognition.
+
Attributes:
confidence_threshold (float):
The Model only returns predictions with at
@@ -37,8 +40,14 @@ class VideoActionRecognitionPredictionParams(proto.Message):
Default value is 50.
"""
- confidence_threshold = proto.Field(proto.FLOAT, number=1,)
- max_predictions = proto.Field(proto.INT32, number=2,)
+ confidence_threshold = proto.Field(
+ proto.FLOAT,
+ number=1,
+ )
+ max_predictions = proto.Field(
+ proto.INT32,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_classification.py
index 35ad2ca0ee..6edd0f2402 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_classification.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.params",
- manifest={"VideoClassificationPredictionParams",},
+ manifest={
+ "VideoClassificationPredictionParams",
+ },
)
class VideoClassificationPredictionParams(proto.Message):
r"""Prediction model parameters for Video Classification.
+
Attributes:
confidence_threshold (float):
The Model only returns predictions with at
@@ -37,16 +40,16 @@ class VideoClassificationPredictionParams(proto.Message):
10,000.
segment_classification (bool):
Set to true to request segment-level
- classification. AI Platform returns labels and
+ classification. Vertex AI returns labels and
their confidence scores for the entire time
segment of the video that the user specified in
input instance. Default value is true
shot_classification (bool):
Set to true to request shot-level
- classification. AI Platform determines the
+ classification. Vertex AI determines the
boundaries for each camera shot in the entire
the time segment of the video that the user
- the input instance. AI Platform then returns
+ the input instance. Vertex AI then returns
labels and their confidence scores for each
detected shot, along with the start and end time
of the shot.
@@ -57,22 +60,36 @@ class VideoClassificationPredictionParams(proto.Message):
Default value is false
one_sec_interval_classification (bool):
Set to true to request classification for a
- video at one-second intervals. AI Platform
- returns labels and their confidence scores for
- each second of the entire time segment of the
- video that user specified in the input WARNING:
- Model evaluation is not done for this
- classification type, the quality of it depends
- on the training data, but there are no metrics
- provided to describe that quality. Default value
- is false
+ video at one-second intervals. Vertex AI returns
+ labels and their confidence scores for each
+ second of the entire time segment of the video
+ that the user specified in the input instance.
+ WARNING: Model evaluation is not done for this
+ classification type; its quality depends on the
+ training data, and no metrics are provided to
+ describe that quality. Default value is false.
"""
- confidence_threshold = proto.Field(proto.FLOAT, number=1,)
- max_predictions = proto.Field(proto.INT32, number=2,)
- segment_classification = proto.Field(proto.BOOL, number=3,)
- shot_classification = proto.Field(proto.BOOL, number=4,)
- one_sec_interval_classification = proto.Field(proto.BOOL, number=5,)
+ confidence_threshold = proto.Field(
+ proto.FLOAT,
+ number=1,
+ )
+ max_predictions = proto.Field(
+ proto.INT32,
+ number=2,
+ )
+ segment_classification = proto.Field(
+ proto.BOOL,
+ number=3,
+ )
+ shot_classification = proto.Field(
+ proto.BOOL,
+ number=4,
+ )
+ one_sec_interval_classification = proto.Field(
+ proto.BOOL,
+ number=5,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
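A sketch of the video classification params touched above, showing the three granularity flags the docstring describes; the values are illustrative, not defaults.

```python
from google.cloud.aiplatform.v1beta1.schema.predict import params_v1beta1

params = params_v1beta1.VideoClassificationPredictionParams(
    confidence_threshold=0.4,
    max_predictions=100,
    segment_classification=True,   # labels for the whole requested segment
    shot_classification=True,      # labels per detected camera shot
    one_sec_interval_classification=False,  # per-second labels (unevaluated)
)
```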
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_object_tracking.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_object_tracking.py
index b4cd10b795..f81f4738b1 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_object_tracking.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/video_object_tracking.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.params",
- manifest={"VideoObjectTrackingPredictionParams",},
+ manifest={
+ "VideoObjectTrackingPredictionParams",
+ },
)
class VideoObjectTrackingPredictionParams(proto.Message):
r"""Prediction model parameters for Video Object Tracking.
+
Attributes:
confidence_threshold (float):
The Model only returns predictions with at
@@ -41,9 +44,18 @@ class VideoObjectTrackingPredictionParams(proto.Message):
frame size are returned. Default value is 0.0.
"""
- confidence_threshold = proto.Field(proto.FLOAT, number=1,)
- max_predictions = proto.Field(proto.INT32, number=2,)
- min_bounding_box_size = proto.Field(proto.FLOAT, number=3,)
+ confidence_threshold = proto.Field(
+ proto.FLOAT,
+ number=1,
+ )
+ max_predictions = proto.Field(
+ proto.INT32,
+ number=2,
+ )
+ min_bounding_box_size = proto.Field(
+ proto.FLOAT,
+ number=3,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/__init__.py
index df5f78f60c..dcb271f09c 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/__init__.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,6 +36,9 @@
from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.text_sentiment import (
TextSentimentPredictionResult,
)
+from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.time_series_forecasting import (
+ TimeSeriesForecastingPredictionResult,
+)
from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.video_action_recognition import (
VideoActionRecognitionPredictionResult,
)
@@ -54,6 +57,7 @@
"TabularRegressionPredictionResult",
"TextExtractionPredictionResult",
"TextSentimentPredictionResult",
+ "TimeSeriesForecastingPredictionResult",
"VideoActionRecognitionPredictionResult",
"VideoClassificationPredictionResult",
"VideoObjectTrackingPredictionResult",
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/__init__.py
index 866cade4d0..be90457ec5 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/__init__.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,6 +22,7 @@
from .types.tabular_regression import TabularRegressionPredictionResult
from .types.text_extraction import TextExtractionPredictionResult
from .types.text_sentiment import TextSentimentPredictionResult
+from .types.time_series_forecasting import TimeSeriesForecastingPredictionResult
from .types.video_action_recognition import VideoActionRecognitionPredictionResult
from .types.video_classification import VideoClassificationPredictionResult
from .types.video_object_tracking import VideoObjectTrackingPredictionResult
@@ -34,6 +35,7 @@
"TabularRegressionPredictionResult",
"TextExtractionPredictionResult",
"TextSentimentPredictionResult",
+ "TimeSeriesForecastingPredictionResult",
"VideoActionRecognitionPredictionResult",
"VideoClassificationPredictionResult",
"VideoObjectTrackingPredictionResult",
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py
index 0bb99636b3..582c0bbe12 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,16 +13,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from .classification import ClassificationPredictionResult
-from .image_object_detection import ImageObjectDetectionPredictionResult
-from .image_segmentation import ImageSegmentationPredictionResult
-from .tabular_classification import TabularClassificationPredictionResult
-from .tabular_regression import TabularRegressionPredictionResult
-from .text_extraction import TextExtractionPredictionResult
-from .text_sentiment import TextSentimentPredictionResult
-from .video_action_recognition import VideoActionRecognitionPredictionResult
-from .video_classification import VideoClassificationPredictionResult
-from .video_object_tracking import VideoObjectTrackingPredictionResult
+from .classification import (
+ ClassificationPredictionResult,
+)
+from .image_object_detection import (
+ ImageObjectDetectionPredictionResult,
+)
+from .image_segmentation import (
+ ImageSegmentationPredictionResult,
+)
+from .tabular_classification import (
+ TabularClassificationPredictionResult,
+)
+from .tabular_regression import (
+ TabularRegressionPredictionResult,
+)
+from .text_extraction import (
+ TextExtractionPredictionResult,
+)
+from .text_sentiment import (
+ TextSentimentPredictionResult,
+)
+from .time_series_forecasting import (
+ TimeSeriesForecastingPredictionResult,
+)
+from .video_action_recognition import (
+ VideoActionRecognitionPredictionResult,
+)
+from .video_classification import (
+ VideoClassificationPredictionResult,
+)
+from .video_object_tracking import (
+ VideoObjectTrackingPredictionResult,
+)
__all__ = (
"ClassificationPredictionResult",
@@ -32,6 +55,7 @@
"TabularRegressionPredictionResult",
"TextExtractionPredictionResult",
"TextSentimentPredictionResult",
+ "TimeSeriesForecastingPredictionResult",
"VideoActionRecognitionPredictionResult",
"VideoClassificationPredictionResult",
"VideoObjectTrackingPredictionResult",
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/classification.py
index d37236a5cc..ce41c4e442 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/classification.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,17 +18,19 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.prediction",
- manifest={"ClassificationPredictionResult",},
+ manifest={
+ "ClassificationPredictionResult",
+ },
)
class ClassificationPredictionResult(proto.Message):
r"""Prediction output format for Image and Text Classification.
+
Attributes:
ids (Sequence[int]):
The resource IDs of the AnnotationSpecs that
- had been identified, ordered by the confidence
- score descendingly.
+ had been identified.
display_names (Sequence[str]):
The display names of the AnnotationSpecs that
had been identified, order matches the IDs.
@@ -38,9 +40,18 @@ class ClassificationPredictionResult(proto.Message):
confidence. Order matches the Ids.
"""
- ids = proto.RepeatedField(proto.INT64, number=1,)
- display_names = proto.RepeatedField(proto.STRING, number=2,)
- confidences = proto.RepeatedField(proto.FLOAT, number=3,)
+ ids = proto.RepeatedField(
+ proto.INT64,
+ number=1,
+ )
+ display_names = proto.RepeatedField(
+ proto.STRING,
+ number=2,
+ )
+ confidences = proto.RepeatedField(
+ proto.FLOAT,
+ number=3,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
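A hedged sketch of consuming the result message above. proto-plus messages accept a plain mapping, so a prediction dict returned by an endpoint can be lifted into the typed result; the dict below is fabricated for illustration.

```python
from google.cloud.aiplatform.v1beta1.schema.predict import prediction_v1beta1

result = prediction_v1beta1.ClassificationPredictionResult(
    {"ids": [1, 2], "display_names": ["cat", "dog"], "confidences": [0.9, 0.1]}
)
# The three repeated fields are index-aligned, per the docstring.
for id_, name, conf in zip(result.ids, result.display_names, result.confidences):
    print(id_, name, conf)
```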
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_object_detection.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_object_detection.py
index e1ed4f5c1e..069e92b2b6 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_object_detection.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_object_detection.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,12 +20,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.prediction",
- manifest={"ImageObjectDetectionPredictionResult",},
+ manifest={
+ "ImageObjectDetectionPredictionResult",
+ },
)
class ImageObjectDetectionPredictionResult(proto.Message):
r"""Prediction output format for Image Object Detection.
+
Attributes:
ids (Sequence[int]):
The resource IDs of the AnnotationSpecs that
@@ -48,10 +51,23 @@ class ImageObjectDetectionPredictionResult(proto.Message):
image.
"""
- ids = proto.RepeatedField(proto.INT64, number=1,)
- display_names = proto.RepeatedField(proto.STRING, number=2,)
- confidences = proto.RepeatedField(proto.FLOAT, number=3,)
- bboxes = proto.RepeatedField(proto.MESSAGE, number=4, message=struct_pb2.ListValue,)
+ ids = proto.RepeatedField(
+ proto.INT64,
+ number=1,
+ )
+ display_names = proto.RepeatedField(
+ proto.STRING,
+ number=2,
+ )
+ confidences = proto.RepeatedField(
+ proto.FLOAT,
+ number=3,
+ )
+ bboxes = proto.RepeatedField(
+ proto.MESSAGE,
+ number=4,
+ message=struct_pb2.ListValue,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_segmentation.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_segmentation.py
index 538de9f561..7aa51cf2a0 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_segmentation.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_segmentation.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.prediction",
- manifest={"ImageSegmentationPredictionResult",},
+ manifest={
+ "ImageSegmentationPredictionResult",
+ },
)
class ImageSegmentationPredictionResult(proto.Message):
r"""Prediction output format for Image Segmentation.
+
Attributes:
category_mask (str):
A PNG image where each pixel in the mask
@@ -46,8 +49,14 @@ class ImageSegmentationPredictionResult(proto.Message):
confidence and white means complete confidence.
"""
- category_mask = proto.Field(proto.STRING, number=1,)
- confidence_mask = proto.Field(proto.STRING, number=2,)
+ category_mask = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ confidence_mask = proto.Field(
+ proto.STRING,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_classification.py
index e6673fe360..019c2d74b1 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_classification.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.prediction",
- manifest={"TabularClassificationPredictionResult",},
+ manifest={
+ "TabularClassificationPredictionResult",
+ },
)
class TabularClassificationPredictionResult(proto.Message):
r"""Prediction output format for Tabular Classification.
+
Attributes:
classes (Sequence[str]):
The name of the classes being classified,
@@ -36,8 +39,14 @@ class TabularClassificationPredictionResult(proto.Message):
classes.
"""
- classes = proto.RepeatedField(proto.STRING, number=1,)
- scores = proto.RepeatedField(proto.FLOAT, number=2,)
+ classes = proto.RepeatedField(
+ proto.STRING,
+ number=1,
+ )
+ scores = proto.RepeatedField(
+ proto.FLOAT,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
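A minimal sketch pairing the two index-aligned repeated fields of the tabular classification result above; the values are made up.

```python
from google.cloud.aiplatform.v1beta1.schema.predict import prediction_v1beta1

result = prediction_v1beta1.TabularClassificationPredictionResult(
    classes=["approve", "deny"],
    scores=[0.72, 0.28],
)
best = max(zip(result.classes, result.scores), key=lambda pair: pair[1])
print(best)  # ('approve', 0.72)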
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_regression.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_regression.py
index f8273be054..ffa086d406 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_regression.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/tabular_regression.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.prediction",
- manifest={"TabularRegressionPredictionResult",},
+ manifest={
+ "TabularRegressionPredictionResult",
+ },
)
class TabularRegressionPredictionResult(proto.Message):
r"""Prediction output format for Tabular Regression.
+
Attributes:
value (float):
The regression value.
@@ -33,9 +36,18 @@ class TabularRegressionPredictionResult(proto.Message):
The upper bound of the prediction interval.
"""
- value = proto.Field(proto.FLOAT, number=1,)
- lower_bound = proto.Field(proto.FLOAT, number=2,)
- upper_bound = proto.Field(proto.FLOAT, number=3,)
+ value = proto.Field(
+ proto.FLOAT,
+ number=1,
+ )
+ lower_bound = proto.Field(
+ proto.FLOAT,
+ number=2,
+ )
+ upper_bound = proto.Field(
+ proto.FLOAT,
+ number=3,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_extraction.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_extraction.py
index 1c70ab440b..8f9ca9cf7f 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_extraction.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_extraction.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.prediction",
- manifest={"TextExtractionPredictionResult",},
+ manifest={
+ "TextExtractionPredictionResult",
+ },
)
class TextExtractionPredictionResult(proto.Message):
r"""Prediction output format for Text Extraction.
+
Attributes:
ids (Sequence[int]):
The resource IDs of the AnnotationSpecs that
@@ -50,11 +53,26 @@ class TextExtractionPredictionResult(proto.Message):
confidence. Order matches the Ids.
"""
- ids = proto.RepeatedField(proto.INT64, number=1,)
- display_names = proto.RepeatedField(proto.STRING, number=2,)
- text_segment_start_offsets = proto.RepeatedField(proto.INT64, number=3,)
- text_segment_end_offsets = proto.RepeatedField(proto.INT64, number=4,)
- confidences = proto.RepeatedField(proto.FLOAT, number=5,)
+ ids = proto.RepeatedField(
+ proto.INT64,
+ number=1,
+ )
+ display_names = proto.RepeatedField(
+ proto.STRING,
+ number=2,
+ )
+ text_segment_start_offsets = proto.RepeatedField(
+ proto.INT64,
+ number=3,
+ )
+ text_segment_end_offsets = proto.RepeatedField(
+ proto.INT64,
+ number=4,
+ )
+ confidences = proto.RepeatedField(
+ proto.FLOAT,
+ number=5,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
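A sketch (assumed usage, not from the diff) that maps the extraction result's offset pairs back onto the input text; treating the offsets as half-open [start, end) slices is an assumption to verify against the field docs.

```python
from google.cloud.aiplatform.v1beta1.schema.predict import prediction_v1beta1

text = "Ada Lovelace wrote the first program."
result = prediction_v1beta1.TextExtractionPredictionResult(
    ids=[7],
    display_names=["person"],
    text_segment_start_offsets=[0],
    text_segment_end_offsets=[12],
    confidences=[0.98],
)
for name, start, end in zip(
    result.display_names,
    result.text_segment_start_offsets,
    result.text_segment_end_offsets,
):
    print(name, text[start:end])  # person Ada Lovelace
```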
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py
index 76ac7392aa..a897d06c2f 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.prediction",
- manifest={"TextSentimentPredictionResult",},
+ manifest={
+ "TextSentimentPredictionResult",
+ },
)
class TextSentimentPredictionResult(proto.Message):
r"""Prediction output format for Text Sentiment
+
Attributes:
sentiment (int):
The integer sentiment labels between 0
@@ -36,7 +39,10 @@ class TextSentimentPredictionResult(proto.Message):
(inclusive) and 10 (inclusive).
"""
- sentiment = proto.Field(proto.INT32, number=1,)
+ sentiment = proto.Field(
+ proto.INT32,
+ number=1,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/time_series_forecasting.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/time_series_forecasting.py
index 38bd8e3c85..eccda58c18 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/time_series_forecasting.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/time_series_forecasting.py
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,13 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
import proto # type: ignore
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.prediction",
- manifest={"TimeSeriesForecastingPredictionResult",},
+ manifest={
+ "TimeSeriesForecastingPredictionResult",
+ },
)
@@ -30,17 +30,12 @@ class TimeSeriesForecastingPredictionResult(proto.Message):
Attributes:
value (float):
The regression value.
- lower_bound (float):
- The lower bound of the prediction interval.
- upper_bound (float):
- The upper bound of the prediction interval.
"""
- value = proto.Field(proto.FLOAT, number=1)
-
- lower_bound = proto.Field(proto.FLOAT, number=2)
-
- upper_bound = proto.Field(proto.FLOAT, number=3)
+ value = proto.Field(
+ proto.FLOAT,
+ number=1,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
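Note that beyond reformatting, the hunk above removes lower_bound and upper_bound from the forecasting result, so downstream code should read only value; a minimal sketch with a fabricated number:

```python
from google.cloud.aiplatform.v1beta1.schema.predict import prediction_v1beta1

result = prediction_v1beta1.TimeSeriesForecastingPredictionResult(value=42.5)
print(result.value)  # the interval bounds no longer exist on this message
```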
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_action_recognition.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_action_recognition.py
index b33184277e..b435c3c0af 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_action_recognition.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_action_recognition.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,12 +21,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.prediction",
- manifest={"VideoActionRecognitionPredictionResult",},
+ manifest={
+ "VideoActionRecognitionPredictionResult",
+ },
)
class VideoActionRecognitionPredictionResult(proto.Message):
r"""Prediction output format for Video Action Recognition.
+
Attributes:
id (str):
The resource ID of the AnnotationSpec that
@@ -54,15 +57,29 @@ class VideoActionRecognitionPredictionResult(proto.Message):
confidence.
"""
- id = proto.Field(proto.STRING, number=1,)
- display_name = proto.Field(proto.STRING, number=2,)
+ id = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ display_name = proto.Field(
+ proto.STRING,
+ number=2,
+ )
time_segment_start = proto.Field(
- proto.MESSAGE, number=4, message=duration_pb2.Duration,
+ proto.MESSAGE,
+ number=4,
+ message=duration_pb2.Duration,
)
time_segment_end = proto.Field(
- proto.MESSAGE, number=5, message=duration_pb2.Duration,
+ proto.MESSAGE,
+ number=5,
+ message=duration_pb2.Duration,
+ )
+ confidence = proto.Field(
+ proto.MESSAGE,
+ number=6,
+ message=wrappers_pb2.FloatValue,
)
- confidence = proto.Field(proto.MESSAGE, number=6, message=wrappers_pb2.FloatValue,)
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_classification.py
index 3d4abadd6a..38bb28ea11 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_classification.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,12 +21,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.prediction",
- manifest={"VideoClassificationPredictionResult",},
+ manifest={
+ "VideoClassificationPredictionResult",
+ },
)
class VideoClassificationPredictionResult(proto.Message):
r"""Prediction output format for Video Classification.
+
Attributes:
id (str):
The resource ID of the AnnotationSpec that
@@ -68,16 +71,33 @@ class VideoClassificationPredictionResult(proto.Message):
confidence.
"""
- id = proto.Field(proto.STRING, number=1,)
- display_name = proto.Field(proto.STRING, number=2,)
- type_ = proto.Field(proto.STRING, number=3,)
+ id = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ display_name = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+ type_ = proto.Field(
+ proto.STRING,
+ number=3,
+ )
time_segment_start = proto.Field(
- proto.MESSAGE, number=4, message=duration_pb2.Duration,
+ proto.MESSAGE,
+ number=4,
+ message=duration_pb2.Duration,
)
time_segment_end = proto.Field(
- proto.MESSAGE, number=5, message=duration_pb2.Duration,
+ proto.MESSAGE,
+ number=5,
+ message=duration_pb2.Duration,
+ )
+ confidence = proto.Field(
+ proto.MESSAGE,
+ number=6,
+ message=wrappers_pb2.FloatValue,
)
- confidence = proto.Field(proto.MESSAGE, number=6, message=wrappers_pb2.FloatValue,)
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_object_tracking.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_object_tracking.py
index 9b085f2309..690952e808 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_object_tracking.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_object_tracking.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,12 +21,15 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.predict.prediction",
- manifest={"VideoObjectTrackingPredictionResult",},
+ manifest={
+ "VideoObjectTrackingPredictionResult",
+ },
)
class VideoObjectTrackingPredictionResult(proto.Message):
r"""Prediction output format for Video Object Tracking.
+
Attributes:
id (str):
The resource ID of the AnnotationSpec that
@@ -84,23 +87,59 @@ class Frame(proto.Message):
"""
time_offset = proto.Field(
- proto.MESSAGE, number=1, message=duration_pb2.Duration,
+ proto.MESSAGE,
+ number=1,
+ message=duration_pb2.Duration,
+ )
+ x_min = proto.Field(
+ proto.MESSAGE,
+ number=2,
+ message=wrappers_pb2.FloatValue,
+ )
+ x_max = proto.Field(
+ proto.MESSAGE,
+ number=3,
+ message=wrappers_pb2.FloatValue,
+ )
+ y_min = proto.Field(
+ proto.MESSAGE,
+ number=4,
+ message=wrappers_pb2.FloatValue,
+ )
+ y_max = proto.Field(
+ proto.MESSAGE,
+ number=5,
+ message=wrappers_pb2.FloatValue,
)
- x_min = proto.Field(proto.MESSAGE, number=2, message=wrappers_pb2.FloatValue,)
- x_max = proto.Field(proto.MESSAGE, number=3, message=wrappers_pb2.FloatValue,)
- y_min = proto.Field(proto.MESSAGE, number=4, message=wrappers_pb2.FloatValue,)
- y_max = proto.Field(proto.MESSAGE, number=5, message=wrappers_pb2.FloatValue,)
- id = proto.Field(proto.STRING, number=1,)
- display_name = proto.Field(proto.STRING, number=2,)
+ id = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ display_name = proto.Field(
+ proto.STRING,
+ number=2,
+ )
time_segment_start = proto.Field(
- proto.MESSAGE, number=3, message=duration_pb2.Duration,
+ proto.MESSAGE,
+ number=3,
+ message=duration_pb2.Duration,
)
time_segment_end = proto.Field(
- proto.MESSAGE, number=4, message=duration_pb2.Duration,
+ proto.MESSAGE,
+ number=4,
+ message=duration_pb2.Duration,
+ )
+ confidence = proto.Field(
+ proto.MESSAGE,
+ number=5,
+ message=wrappers_pb2.FloatValue,
+ )
+ frames = proto.RepeatedField(
+ proto.MESSAGE,
+ number=6,
+ message=Frame,
)
- confidence = proto.Field(proto.MESSAGE, number=5, message=wrappers_pb2.FloatValue,)
- frames = proto.RepeatedField(proto.MESSAGE, number=6, message=Frame,)
__all__ = tuple(sorted(__protobuf__.manifest))
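A hedged sketch iterating the nested Frame messages restructured above. Under proto-plus marshaling, Duration fields surface as datetime.timedelta and FloatValue fields as optional floats; treat that as an assumption to confirm against the installed version.

```python
import datetime

from google.cloud.aiplatform.v1beta1.schema.predict import prediction_v1beta1

Result = prediction_v1beta1.VideoObjectTrackingPredictionResult
result = Result(
    id="123",
    display_name="car",
    frames=[
        Result.Frame(
            time_offset=datetime.timedelta(seconds=1),
            x_min=0.1, x_max=0.4, y_min=0.2, y_max=0.5,
        )
    ],
)
for frame in result.frames:
    # Each frame is one bounding box at one time offset.
    print(frame.time_offset, frame.x_min, frame.y_max)
```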
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/__init__.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/__init__.py
index eae6c5d2fa..3e9dbeaae4 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/__init__.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -69,6 +69,15 @@
from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_text_sentiment import (
AutoMlTextSentimentInputs,
)
+from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_time_series_forecasting import (
+ AutoMlForecasting,
+)
+from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_time_series_forecasting import (
+ AutoMlForecastingInputs,
+)
+from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_time_series_forecasting import (
+ AutoMlForecastingMetadata,
+)
from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_video_action_recognition import (
AutoMlVideoActionRecognition,
)
@@ -110,6 +119,9 @@
"AutoMlTextExtractionInputs",
"AutoMlTextSentiment",
"AutoMlTextSentimentInputs",
+ "AutoMlForecasting",
+ "AutoMlForecastingInputs",
+ "AutoMlForecastingMetadata",
"AutoMlVideoActionRecognition",
"AutoMlVideoActionRecognitionInputs",
"AutoMlVideoClassification",
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/__init__.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/__init__.py
index 16b66c2fb6..c35353f39c 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/__init__.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,6 +33,9 @@
from .types.automl_text_extraction import AutoMlTextExtractionInputs
from .types.automl_text_sentiment import AutoMlTextSentiment
from .types.automl_text_sentiment import AutoMlTextSentimentInputs
+from .types.automl_time_series_forecasting import AutoMlForecasting
+from .types.automl_time_series_forecasting import AutoMlForecastingInputs
+from .types.automl_time_series_forecasting import AutoMlForecastingMetadata
from .types.automl_video_action_recognition import AutoMlVideoActionRecognition
from .types.automl_video_action_recognition import AutoMlVideoActionRecognitionInputs
from .types.automl_video_classification import AutoMlVideoClassification
@@ -42,6 +45,9 @@
from .types.export_evaluated_data_items_config import ExportEvaluatedDataItemsConfig
__all__ = (
+ "AutoMlForecasting",
+ "AutoMlForecastingInputs",
+ "AutoMlForecastingMetadata",
"AutoMlImageClassification",
"AutoMlImageClassificationInputs",
"AutoMlImageClassificationMetadata",
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py
index d70e297826..7de288bc76 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -45,6 +45,11 @@
AutoMlTextSentiment,
AutoMlTextSentimentInputs,
)
+from .automl_time_series_forecasting import (
+ AutoMlForecasting,
+ AutoMlForecastingInputs,
+ AutoMlForecastingMetadata,
+)
from .automl_video_action_recognition import (
AutoMlVideoActionRecognition,
AutoMlVideoActionRecognitionInputs,
@@ -57,7 +62,9 @@
AutoMlVideoObjectTracking,
AutoMlVideoObjectTrackingInputs,
)
-from .export_evaluated_data_items_config import ExportEvaluatedDataItemsConfig
+from .export_evaluated_data_items_config import (
+ ExportEvaluatedDataItemsConfig,
+)
__all__ = (
"AutoMlImageClassification",
@@ -78,6 +85,9 @@
"AutoMlTextExtractionInputs",
"AutoMlTextSentiment",
"AutoMlTextSentimentInputs",
+ "AutoMlForecasting",
+ "AutoMlForecastingInputs",
+ "AutoMlForecastingMetadata",
"AutoMlVideoActionRecognition",
"AutoMlVideoActionRecognitionInputs",
"AutoMlVideoClassification",
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_forecasting.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_forecasting.py
index 34f700f8af..3571368e5f 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_forecasting.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_forecasting.py
@@ -44,10 +44,16 @@ class AutoMlForecasting(proto.Message):
The metadata information.
"""
- inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlForecastingInputs",)
+ inputs = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlForecastingInputs",
+ )
metadata = proto.Field(
- proto.MESSAGE, number=2, message="AutoMlForecastingMetadata",
+ proto.MESSAGE,
+ number=2,
+ message="AutoMlForecastingMetadata",
)
@@ -439,7 +445,9 @@ class Period(proto.Message):
time_column = proto.Field(proto.STRING, number=3)
transformations = proto.RepeatedField(
- proto.MESSAGE, number=4, message=Transformation,
+ proto.MESSAGE,
+ number=4,
+ message=Transformation,
)
optimization_objective = proto.Field(proto.STRING, number=5)
@@ -454,7 +462,11 @@ class Period(proto.Message):
time_variant_past_and_future_columns = proto.RepeatedField(proto.STRING, number=10)
- period = proto.Field(proto.MESSAGE, number=11, message=Period,)
+ period = proto.Field(
+ proto.MESSAGE,
+ number=11,
+ message=Period,
+ )
forecast_window_start = proto.Field(proto.INT64, number=12)
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_classification.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_classification.py
index 945962bb50..4c2e019747 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_classification.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -38,15 +38,20 @@ class AutoMlImageClassification(proto.Message):
"""
inputs = proto.Field(
- proto.MESSAGE, number=1, message="AutoMlImageClassificationInputs",
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlImageClassificationInputs",
)
metadata = proto.Field(
- proto.MESSAGE, number=2, message="AutoMlImageClassificationMetadata",
+ proto.MESSAGE,
+ number=2,
+ message="AutoMlImageClassificationMetadata",
)
class AutoMlImageClassificationInputs(proto.Message):
r"""
+
Attributes:
model_type (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlImageClassificationInputs.ModelType):
@@ -97,15 +102,32 @@ class ModelType(proto.Enum):
MOBILE_TF_VERSATILE_1 = 3
MOBILE_TF_HIGH_ACCURACY_1 = 4
- model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,)
- base_model_id = proto.Field(proto.STRING, number=2,)
- budget_milli_node_hours = proto.Field(proto.INT64, number=3,)
- disable_early_stopping = proto.Field(proto.BOOL, number=4,)
- multi_label = proto.Field(proto.BOOL, number=5,)
+ model_type = proto.Field(
+ proto.ENUM,
+ number=1,
+ enum=ModelType,
+ )
+ base_model_id = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+ budget_milli_node_hours = proto.Field(
+ proto.INT64,
+ number=3,
+ )
+ disable_early_stopping = proto.Field(
+ proto.BOOL,
+ number=4,
+ )
+ multi_label = proto.Field(
+ proto.BOOL,
+ number=5,
+ )
class AutoMlImageClassificationMetadata(proto.Message):
r"""
+
Attributes:
cost_milli_node_hours (int):
The actual training cost of creating this
@@ -124,9 +146,14 @@ class SuccessfulStopReason(proto.Enum):
BUDGET_REACHED = 1
MODEL_CONVERGED = 2
- cost_milli_node_hours = proto.Field(proto.INT64, number=1,)
+ cost_milli_node_hours = proto.Field(
+ proto.INT64,
+ number=1,
+ )
successful_stop_reason = proto.Field(
- proto.ENUM, number=2, enum=SuccessfulStopReason,
+ proto.ENUM,
+ number=2,
+ enum=SuccessfulStopReason,
)
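
Everything in this file's diff is a behavior-neutral reflow (one field argument per line, with a trailing comma) plus the copyright-year bump. For orientation, a short sketch of the message these fields define, using only names visible in this hunk; the import path is assumed:

# Sketch only: the reformatted fields keep their numbers and types, so
# construction is unchanged. Values below are illustrative.
from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types import (
    automl_image_classification as aic,
)

inputs = aic.AutoMlImageClassificationInputs(
    model_type=aic.AutoMlImageClassificationInputs.ModelType.MOBILE_TF_VERSATILE_1,
    budget_milli_node_hours=8000,  # 8 node hours
    multi_label=False,
)
assert inputs.budget_milli_node_hours == 8000
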
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_object_detection.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_object_detection.py
index 1d95b93970..1fa70c468e 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_object_detection.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_object_detection.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -38,15 +38,20 @@ class AutoMlImageObjectDetection(proto.Message):
"""
inputs = proto.Field(
- proto.MESSAGE, number=1, message="AutoMlImageObjectDetectionInputs",
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlImageObjectDetectionInputs",
)
metadata = proto.Field(
- proto.MESSAGE, number=2, message="AutoMlImageObjectDetectionMetadata",
+ proto.MESSAGE,
+ number=2,
+ message="AutoMlImageObjectDetectionMetadata",
)
class AutoMlImageObjectDetectionInputs(proto.Message):
r"""
+
Attributes:
model_type (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlImageObjectDetectionInputs.ModelType):
@@ -86,13 +91,24 @@ class ModelType(proto.Enum):
MOBILE_TF_VERSATILE_1 = 4
MOBILE_TF_HIGH_ACCURACY_1 = 5
- model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,)
- budget_milli_node_hours = proto.Field(proto.INT64, number=2,)
- disable_early_stopping = proto.Field(proto.BOOL, number=3,)
+ model_type = proto.Field(
+ proto.ENUM,
+ number=1,
+ enum=ModelType,
+ )
+ budget_milli_node_hours = proto.Field(
+ proto.INT64,
+ number=2,
+ )
+ disable_early_stopping = proto.Field(
+ proto.BOOL,
+ number=3,
+ )
class AutoMlImageObjectDetectionMetadata(proto.Message):
r"""
+
Attributes:
cost_milli_node_hours (int):
The actual training cost of creating this
@@ -111,9 +127,14 @@ class SuccessfulStopReason(proto.Enum):
BUDGET_REACHED = 1
MODEL_CONVERGED = 2
- cost_milli_node_hours = proto.Field(proto.INT64, number=1,)
+ cost_milli_node_hours = proto.Field(
+ proto.INT64,
+ number=1,
+ )
successful_stop_reason = proto.Field(
- proto.ENUM, number=2, enum=SuccessfulStopReason,
+ proto.ENUM,
+ number=2,
+ enum=SuccessfulStopReason,
)
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py
index 4b47874f37..5cdf7b69b5 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -38,15 +38,20 @@ class AutoMlImageSegmentation(proto.Message):
"""
inputs = proto.Field(
- proto.MESSAGE, number=1, message="AutoMlImageSegmentationInputs",
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlImageSegmentationInputs",
)
metadata = proto.Field(
- proto.MESSAGE, number=2, message="AutoMlImageSegmentationMetadata",
+ proto.MESSAGE,
+ number=2,
+ message="AutoMlImageSegmentationMetadata",
)
class AutoMlImageSegmentationInputs(proto.Message):
r"""
+
Attributes:
model_type (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlImageSegmentationInputs.ModelType):
@@ -80,13 +85,24 @@ class ModelType(proto.Enum):
CLOUD_LOW_ACCURACY_1 = 2
MOBILE_TF_LOW_LATENCY_1 = 3
- model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,)
- budget_milli_node_hours = proto.Field(proto.INT64, number=2,)
- base_model_id = proto.Field(proto.STRING, number=3,)
+ model_type = proto.Field(
+ proto.ENUM,
+ number=1,
+ enum=ModelType,
+ )
+ budget_milli_node_hours = proto.Field(
+ proto.INT64,
+ number=2,
+ )
+ base_model_id = proto.Field(
+ proto.STRING,
+ number=3,
+ )
class AutoMlImageSegmentationMetadata(proto.Message):
r"""
+
Attributes:
cost_milli_node_hours (int):
The actual training cost of creating this
@@ -105,9 +121,14 @@ class SuccessfulStopReason(proto.Enum):
BUDGET_REACHED = 1
MODEL_CONVERGED = 2
- cost_milli_node_hours = proto.Field(proto.INT64, number=1,)
+ cost_milli_node_hours = proto.Field(
+ proto.INT64,
+ number=1,
+ )
successful_stop_reason = proto.Field(
- proto.ENUM, number=2, enum=SuccessfulStopReason,
+ proto.ENUM,
+ number=2,
+ enum=SuccessfulStopReason,
)
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py
index 3531ec74f6..f1b1bf9aba 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,12 +22,17 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition",
- manifest={"AutoMlTables", "AutoMlTablesInputs", "AutoMlTablesMetadata",},
+ manifest={
+ "AutoMlTables",
+ "AutoMlTablesInputs",
+ "AutoMlTablesMetadata",
+ },
)
class AutoMlTables(proto.Message):
r"""A TrainingJob that trains and uploads an AutoML Tables Model.
+
Attributes:
inputs (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs):
The input parameters of this TrainingJob.
@@ -35,21 +40,41 @@ class AutoMlTables(proto.Message):
The metadata information.
"""
- inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTablesInputs",)
- metadata = proto.Field(proto.MESSAGE, number=2, message="AutoMlTablesMetadata",)
+ inputs = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlTablesInputs",
+ )
+ metadata = proto.Field(
+ proto.MESSAGE,
+ number=2,
+ message="AutoMlTablesMetadata",
+ )
class AutoMlTablesInputs(proto.Message):
r"""
+
+ This message has `oneof`_ fields (mutually exclusive fields).
+ For each oneof, at most one member field can be set at the same time.
+ Setting any member of the oneof automatically clears all other
+ members.
+
+ .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
Attributes:
optimization_objective_recall_value (float):
Required when optimization_objective is
"maximize-precision-at-recall". Must be between 0 and 1,
inclusive.
+
+ This field is a member of `oneof`_ ``additional_optimization_objective_config``.
optimization_objective_precision_value (float):
Required when optimization_objective is
"maximize-recall-at-precision". Must be between 0 and 1,
inclusive.
+
+ This field is a member of `oneof`_ ``additional_optimization_objective_config``.
prediction_type (str):
The type of prediction the Model is to
produce. "classification" - Predict one out of
@@ -86,9 +111,9 @@ class AutoMlTablesInputs(proto.Message):
operating characteristic (ROC) curve.
"minimize-log-loss" - Minimize log loss.
"maximize-au-prc" - Maximize the area under
- the precision-recall curve. "maximize-
- precision-at-recall" - Maximize precision for a
- specified
+ the precision-recall curve.
+ "maximize-precision-at-recall" - Maximize
+ precision for a specified
recall value. "maximize-recall-at-precision" -
Maximize recall for a specified
precision value.
@@ -96,11 +121,11 @@ class AutoMlTablesInputs(proto.Message):
"minimize-log-loss" (default) - Minimize log
loss.
regression:
- "minimize-rmse" (default) - Minimize root-
- mean-squared error (RMSE). "minimize-mae" -
- Minimize mean-absolute error (MAE). "minimize-
- rmsle" - Minimize root-mean-squared log error
- (RMSLE).
+ "minimize-rmse" (default) - Minimize
+ root-mean-squared error (RMSE). "minimize-mae"
+ - Minimize mean-absolute error (MAE).
+ "minimize-rmsle" - Minimize root-mean-squared
+ log error (RMSLE).
train_budget_milli_node_hours (int):
Required. The train budget of creating this
model, expressed in milli node hours i.e. 1,000
@@ -139,27 +164,46 @@ class AutoMlTablesInputs(proto.Message):
predictions to a BigQuery table. If this
configuration is absent, then the export is not
performed.
+ additional_experiments (Sequence[str]):
+ Additional experiment flags for the Tables
+ training pipeline.
"""
class Transformation(proto.Message):
r"""
+
+ This message has `oneof`_ fields (mutually exclusive fields).
+ For each oneof, at most one member field can be set at the same time.
+ Setting any member of the oneof automatically clears all other
+ members.
+
+ .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
Attributes:
auto (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs.Transformation.AutoTransformation):
+ This field is a member of `oneof`_ ``transformation_detail``.
numeric (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs.Transformation.NumericTransformation):
+ This field is a member of `oneof`_ ``transformation_detail``.
categorical (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs.Transformation.CategoricalTransformation):
+ This field is a member of `oneof`_ ``transformation_detail``.
timestamp (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs.Transformation.TimestampTransformation):
+ This field is a member of `oneof`_ ``transformation_detail``.
text (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs.Transformation.TextTransformation):
+ This field is a member of `oneof`_ ``transformation_detail``.
repeated_numeric (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs.Transformation.NumericArrayTransformation):
+ This field is a member of `oneof`_ ``transformation_detail``.
repeated_categorical (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs.Transformation.CategoricalArrayTransformation):
+ This field is a member of `oneof`_ ``transformation_detail``.
repeated_text (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs.Transformation.TextArrayTransformation):
+ This field is a member of `oneof`_ ``transformation_detail``.
"""
class AutoTransformation(proto.Message):
@@ -171,7 +215,10 @@ class AutoTransformation(proto.Message):
"""
- column_name = proto.Field(proto.STRING, number=1,)
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
class NumericTransformation(proto.Message):
r"""Training pipeline will perform following transformation functions.
@@ -197,8 +244,14 @@ class NumericTransformation(proto.Message):
                from training data.
"""
- column_name = proto.Field(proto.STRING, number=1,)
- invalid_values_allowed = proto.Field(proto.BOOL, number=2,)
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ invalid_values_allowed = proto.Field(
+ proto.BOOL,
+ number=2,
+ )
class CategoricalTransformation(proto.Message):
r"""Training pipeline will perform following transformation functions.
@@ -216,7 +269,10 @@ class CategoricalTransformation(proto.Message):
"""
- column_name = proto.Field(proto.STRING, number=1,)
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
class TimestampTransformation(proto.Message):
r"""Training pipeline will perform following transformation functions.
@@ -253,9 +309,18 @@ class TimestampTransformation(proto.Message):
                from training data.
"""
- column_name = proto.Field(proto.STRING, number=1,)
- time_format = proto.Field(proto.STRING, number=2,)
- invalid_values_allowed = proto.Field(proto.BOOL, number=3,)
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ time_format = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+ invalid_values_allowed = proto.Field(
+ proto.BOOL,
+ number=3,
+ )
class TextTransformation(proto.Message):
r"""Training pipeline will perform following transformation functions.
@@ -275,7 +340,10 @@ class TextTransformation(proto.Message):
"""
- column_name = proto.Field(proto.STRING, number=1,)
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
class NumericArrayTransformation(proto.Message):
r"""Treats the column as numerical array and performs following
@@ -296,8 +364,14 @@ class NumericArrayTransformation(proto.Message):
                from training data.
"""
- column_name = proto.Field(proto.STRING, number=1,)
- invalid_values_allowed = proto.Field(proto.BOOL, number=2,)
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ invalid_values_allowed = proto.Field(
+ proto.BOOL,
+ number=2,
+ )
class CategoricalArrayTransformation(proto.Message):
r"""Treats the column as categorical array and performs following
@@ -314,7 +388,10 @@ class CategoricalArrayTransformation(proto.Message):
"""
- column_name = proto.Field(proto.STRING, number=1,)
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
class TextArrayTransformation(proto.Message):
r"""Treats the column as text array and performs following
@@ -330,7 +407,10 @@ class TextArrayTransformation(proto.Message):
"""
- column_name = proto.Field(proto.STRING, number=1,)
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
auto = proto.Field(
proto.MESSAGE,
@@ -382,29 +462,58 @@ class TextArrayTransformation(proto.Message):
)
optimization_objective_recall_value = proto.Field(
- proto.FLOAT, number=5, oneof="additional_optimization_objective_config",
+ proto.FLOAT,
+ number=5,
+ oneof="additional_optimization_objective_config",
)
optimization_objective_precision_value = proto.Field(
- proto.FLOAT, number=6, oneof="additional_optimization_objective_config",
+ proto.FLOAT,
+ number=6,
+ oneof="additional_optimization_objective_config",
+ )
+ prediction_type = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ target_column = proto.Field(
+ proto.STRING,
+ number=2,
)
- prediction_type = proto.Field(proto.STRING, number=1,)
- target_column = proto.Field(proto.STRING, number=2,)
transformations = proto.RepeatedField(
- proto.MESSAGE, number=3, message=Transformation,
+ proto.MESSAGE,
+ number=3,
+ message=Transformation,
+ )
+ optimization_objective = proto.Field(
+ proto.STRING,
+ number=4,
+ )
+ train_budget_milli_node_hours = proto.Field(
+ proto.INT64,
+ number=7,
+ )
+ disable_early_stopping = proto.Field(
+ proto.BOOL,
+ number=8,
+ )
+ weight_column_name = proto.Field(
+ proto.STRING,
+ number=9,
)
- optimization_objective = proto.Field(proto.STRING, number=4,)
- train_budget_milli_node_hours = proto.Field(proto.INT64, number=7,)
- disable_early_stopping = proto.Field(proto.BOOL, number=8,)
- weight_column_name = proto.Field(proto.STRING, number=9,)
export_evaluated_data_items_config = proto.Field(
proto.MESSAGE,
number=10,
message=gcastd_export_evaluated_data_items_config.ExportEvaluatedDataItemsConfig,
)
+ additional_experiments = proto.RepeatedField(
+ proto.STRING,
+ number=11,
+ )
class AutoMlTablesMetadata(proto.Message):
r"""Model metadata specific to AutoML Tables.
+
Attributes:
train_cost_milli_node_hours (int):
Output only. The actual training cost of the
@@ -413,7 +522,10 @@ class AutoMlTablesMetadata(proto.Message):
Guaranteed to not exceed the train budget.
"""
- train_cost_milli_node_hours = proto.Field(proto.INT64, number=1,)
+ train_cost_milli_node_hours = proto.Field(
+ proto.INT64,
+ number=1,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
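
The oneof documentation added above is the substantive part of this file's diff: optimization_objective_recall_value and optimization_objective_precision_value now both document their membership in the additional_optimization_objective_config oneof, so setting one clears the other. A sketch of that behavior, with the import path assumed:

# Sketch only: proto3 oneof semantics surfaced by proto-plus, as the new
# docstring text above describes. Field values are illustrative.
from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types import (
    automl_tables,
)

inputs = automl_tables.AutoMlTablesInputs(
    prediction_type="classification",
    target_column="label",
    optimization_objective="maximize-precision-at-recall",
    optimization_objective_recall_value=0.8,
)
# Assigning the sibling oneof member clears the recall value.
inputs.optimization_objective_precision_value = 0.9
assert inputs.optimization_objective_recall_value == 0.0
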
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_classification.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_classification.py
index bd52a0e808..ca1a163c1a 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_classification.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,10 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition",
- manifest={"AutoMlTextClassification", "AutoMlTextClassificationInputs",},
+ manifest={
+ "AutoMlTextClassification",
+ "AutoMlTextClassificationInputs",
+ },
)
@@ -32,18 +35,24 @@ class AutoMlTextClassification(proto.Message):
"""
inputs = proto.Field(
- proto.MESSAGE, number=1, message="AutoMlTextClassificationInputs",
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlTextClassificationInputs",
)
class AutoMlTextClassificationInputs(proto.Message):
r"""
+
Attributes:
multi_label (bool):
"""
- multi_label = proto.Field(proto.BOOL, number=1,)
+ multi_label = proto.Field(
+ proto.BOOL,
+ number=1,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_extraction.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_extraction.py
index ba838e0ccc..89ae0aeeb9 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_extraction.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_extraction.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,10 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition",
- manifest={"AutoMlTextExtraction", "AutoMlTextExtractionInputs",},
+ manifest={
+ "AutoMlTextExtraction",
+ "AutoMlTextExtractionInputs",
+ },
)
@@ -31,11 +34,15 @@ class AutoMlTextExtraction(proto.Message):
The input parameters of this TrainingJob.
"""
- inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTextExtractionInputs",)
+ inputs = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlTextExtractionInputs",
+ )
class AutoMlTextExtractionInputs(proto.Message):
- r""" """
+ r""" """
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_sentiment.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_sentiment.py
index 4439db4bcc..9d9b6d9e31 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_sentiment.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_sentiment.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,10 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition",
- manifest={"AutoMlTextSentiment", "AutoMlTextSentimentInputs",},
+ manifest={
+ "AutoMlTextSentiment",
+ "AutoMlTextSentimentInputs",
+ },
)
@@ -31,11 +34,16 @@ class AutoMlTextSentiment(proto.Message):
The input parameters of this TrainingJob.
"""
- inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTextSentimentInputs",)
+ inputs = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlTextSentimentInputs",
+ )
class AutoMlTextSentimentInputs(proto.Message):
r"""
+
Attributes:
sentiment_max (int):
A sentiment is expressed as an integer
@@ -50,7 +58,10 @@ class AutoMlTextSentimentInputs(proto.Message):
between 1 and 10 (inclusive).
"""
- sentiment_max = proto.Field(proto.INT32, number=1,)
+ sentiment_max = proto.Field(
+ proto.INT32,
+ number=1,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
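
Per the docstring above, sentiment_max is the sole input and must be between 1 and 10 inclusive. A one-line sketch (import path assumed):

# Sketch only: a dataset annotated on a 0..4 scale would use sentiment_max=4.
from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types import (
    automl_text_sentiment,
)

inputs = automl_text_sentiment.AutoMlTextSentimentInputs(sentiment_max=4)
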
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_time_series_forecasting.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_time_series_forecasting.py
new file mode 100644
index 0000000000..3f1e03c914
--- /dev/null
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_time_series_forecasting.py
@@ -0,0 +1,495 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import proto # type: ignore
+
+from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types import (
+ export_evaluated_data_items_config as gcastd_export_evaluated_data_items_config,
+)
+
+
+__protobuf__ = proto.module(
+ package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition",
+ manifest={
+ "AutoMlForecasting",
+ "AutoMlForecastingInputs",
+ "AutoMlForecastingMetadata",
+ },
+)
+
+
+class AutoMlForecasting(proto.Message):
+ r"""A TrainingJob that trains and uploads an AutoML Forecasting
+ Model.
+
+ Attributes:
+ inputs (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingInputs):
+ The input parameters of this TrainingJob.
+ metadata (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingMetadata):
+ The metadata information.
+ """
+
+ inputs = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlForecastingInputs",
+ )
+ metadata = proto.Field(
+ proto.MESSAGE,
+ number=2,
+ message="AutoMlForecastingMetadata",
+ )
+
+
+class AutoMlForecastingInputs(proto.Message):
+ r"""
+
+ Attributes:
+ target_column (str):
+ The name of the column that the model is to
+ predict.
+ time_series_identifier_column (str):
+ The name of the column that identifies the
+ time series.
+ time_column (str):
+ The name of the column that identifies time
+ order in the time series.
+ transformations (Sequence[google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingInputs.Transformation]):
+            Each transformation will apply a transform
+            function to a given input column, and the
+            result will be used for training. When
+            creating a transformation for a BigQuery
+            Struct column, the column should be flattened
+            using "." as the delimiter.
+ optimization_objective (str):
+ Objective function the model is optimizing towards. The
+ training process creates a model that optimizes the value of
+ the objective function over the validation set.
+
+ The supported optimization objectives:
+
+ - "minimize-rmse" (default) - Minimize root-mean-squared
+ error (RMSE).
+
+ - "minimize-mae" - Minimize mean-absolute error (MAE).
+
+ - "minimize-rmsle" - Minimize root-mean-squared log error
+ (RMSLE).
+
+ - "minimize-rmspe" - Minimize root-mean-squared percentage
+ error (RMSPE).
+
+ - "minimize-wape-mae" - Minimize the combination of
+ weighted absolute percentage error (WAPE) and
+ mean-absolute-error (MAE).
+
+ - "minimize-quantile-loss" - Minimize the quantile loss at
+ the quantiles defined in ``quantiles``.
+ train_budget_milli_node_hours (int):
+ Required. The train budget of creating this
+ model, expressed in milli node hours i.e. 1,000
+ value in this field means 1 node hour.
+ The training cost of the model will not exceed
+            this budget. The backend will attempt to keep the
+            final cost close to the budget, though it may end
+            up (even) noticeably smaller, at the backend's
+            discretion. This may happen especially when
+ further model training ceases to provide any
+ improvements.
+ If the budget is set to a value known to be
+ insufficient to train a model for the given
+ dataset, the training won't be attempted and
+ will error.
+
+ The train budget must be between 1,000 and
+ 72,000 milli node hours, inclusive.
+ weight_column (str):
+ Column name that should be used as the weight
+ column. Higher values in this column give more
+ importance to the row during model training. The
+ column must have numeric values between 0 and
+ 10000 inclusively; 0 means the row is ignored
+            for training. If the weight column field is not set,
+ then all rows are assumed to have equal weight
+ of 1.
+ time_series_attribute_columns (Sequence[str]):
+ Column names that should be used as attribute
+ columns. The value of these columns does not
+ vary as a function of time. For example, store
+ ID or item color.
+ unavailable_at_forecast_columns (Sequence[str]):
+ Names of columns that are unavailable when a forecast is
+            requested. These columns contain information for the given
+            entity (identified by the time_series_identifier_column)
+            that is unknown before the forecast. For example, actual
+ weather on a given day.
+ available_at_forecast_columns (Sequence[str]):
+ Names of columns that are available and provided when a
+ forecast is requested. These columns contain information for
+ the given entity (identified by the
+ time_series_identifier_column column) that is known at
+ forecast. For example, predicted weather for a specific day.
+ data_granularity (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingInputs.Granularity):
+ Expected difference in time granularity
+ between rows in the data.
+ forecast_horizon (int):
+ The amount of time into the future for which forecasted
+ values for the target are returned. Expressed in number of
+ units defined by the ``data_granularity`` field.
+ context_window (int):
+ The amount of time into the past training and prediction
+ data is used for model training and prediction respectively.
+ Expressed in number of units defined by the
+ ``data_granularity`` field.
+ export_evaluated_data_items_config (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.ExportEvaluatedDataItemsConfig):
+ Configuration for exporting test set
+ predictions to a BigQuery table. If this
+ configuration is absent, then the export is not
+ performed.
+ quantiles (Sequence[float]):
+ Quantiles to use for minimize-quantile-loss
+ ``optimization_objective``. Up to 5 quantiles are allowed of
+ values between 0 and 1, exclusive. Required if the value of
+ optimization_objective is minimize-quantile-loss. Represents
+ the percent quantiles to use for that objective. Quantiles
+ must be unique.
+ validation_options (str):
+ Validation options for the data validation component. The
+ available options are:
+
+        - "fail-pipeline" - default, will validate against the
+          validation set and fail the pipeline if validation fails.
+
+ - "ignore-validation" - ignore the results of the
+          validation and continue.
+ additional_experiments (Sequence[str]):
+ Additional experiment flags for the time
+            series forecasting training.
+ """
+
+ class Transformation(proto.Message):
+ r"""
+
+ This message has `oneof`_ fields (mutually exclusive fields).
+ For each oneof, at most one member field can be set at the same time.
+ Setting any member of the oneof automatically clears all other
+ members.
+
+ .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+ Attributes:
+ auto (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingInputs.Transformation.AutoTransformation):
+
+ This field is a member of `oneof`_ ``transformation_detail``.
+ numeric (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingInputs.Transformation.NumericTransformation):
+
+ This field is a member of `oneof`_ ``transformation_detail``.
+ categorical (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingInputs.Transformation.CategoricalTransformation):
+
+ This field is a member of `oneof`_ ``transformation_detail``.
+ timestamp (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingInputs.Transformation.TimestampTransformation):
+
+ This field is a member of `oneof`_ ``transformation_detail``.
+ text (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingInputs.Transformation.TextTransformation):
+
+ This field is a member of `oneof`_ ``transformation_detail``.
+ """
+
+ class AutoTransformation(proto.Message):
+ r"""Training pipeline will infer the proper transformation based
+        on the statistics of the dataset.
+
+ Attributes:
+ column_name (str):
+
+ """
+
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+
+ class NumericTransformation(proto.Message):
+ r"""Training pipeline will perform following transformation functions.
+
+ - The value converted to float32.
+
+ - The z_score of the value.
+
+ - log(value+1) when the value is greater than or equal to 0.
+ Otherwise, this transformation is not applied and the value is
+ considered a missing value.
+
+ - z_score of log(value+1) when the value is greater than or equal
+ to 0. Otherwise, this transformation is not applied and the value
+ is considered a missing value.
+
+ - A boolean value that indicates whether the value is valid.
+
+ Attributes:
+ column_name (str):
+
+ """
+
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+
+ class CategoricalTransformation(proto.Message):
+ r"""Training pipeline will perform following transformation functions.
+
+ - The categorical string as is--no change to case, punctuation,
+ spelling, tense, and so on.
+
+ - Convert the category name to a dictionary lookup index and
+ generate an embedding for each index.
+
+ - Categories that appear less than 5 times in the training dataset
+ are treated as the "unknown" category. The "unknown" category
+ gets its own special lookup index and resulting embedding.
+
+ Attributes:
+ column_name (str):
+
+ """
+
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+
+ class TimestampTransformation(proto.Message):
+ r"""Training pipeline will perform following transformation functions.
+
+ - Apply the transformation functions for Numerical columns.
+
+        - Determine the year, month, day, and weekday. Treat each value from
+ the timestamp as a Categorical column.
+
+ - Invalid numerical values (for example, values that fall outside
+ of a typical timestamp range, or are extreme values) receive no
+ special treatment and are not removed.
+
+ Attributes:
+ column_name (str):
+
+ time_format (str):
+ The format in which that time field is expressed. The
+ time_format must either be one of:
+
+ - ``unix-seconds``
+
+ - ``unix-milliseconds``
+
+ - ``unix-microseconds``
+
+ - ``unix-nanoseconds``
+
+        (respectively, the number of seconds, milliseconds,
+        microseconds, and nanoseconds since the start of the Unix
+        epoch);
+
+ or be written in ``strftime`` syntax.
+
+ If time_format is not set, then the default format is RFC
+ 3339 ``date-time`` format, where ``time-offset`` = ``"Z"``
+ (e.g. 1985-04-12T23:20:50.52Z)
+ """
+
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ time_format = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+
+ class TextTransformation(proto.Message):
+ r"""Training pipeline will perform following transformation functions.
+
+ - The text as is--no change to case, punctuation, spelling, tense,
+ and so on.
+
+ - Convert the category name to a dictionary lookup index and
+ generate an embedding for each index.
+
+ Attributes:
+ column_name (str):
+
+ """
+
+ column_name = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+
+ auto = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ oneof="transformation_detail",
+ message="AutoMlForecastingInputs.Transformation.AutoTransformation",
+ )
+ numeric = proto.Field(
+ proto.MESSAGE,
+ number=2,
+ oneof="transformation_detail",
+ message="AutoMlForecastingInputs.Transformation.NumericTransformation",
+ )
+ categorical = proto.Field(
+ proto.MESSAGE,
+ number=3,
+ oneof="transformation_detail",
+ message="AutoMlForecastingInputs.Transformation.CategoricalTransformation",
+ )
+ timestamp = proto.Field(
+ proto.MESSAGE,
+ number=4,
+ oneof="transformation_detail",
+ message="AutoMlForecastingInputs.Transformation.TimestampTransformation",
+ )
+ text = proto.Field(
+ proto.MESSAGE,
+ number=5,
+ oneof="transformation_detail",
+ message="AutoMlForecastingInputs.Transformation.TextTransformation",
+ )
+
+ class Granularity(proto.Message):
+ r"""A duration of time expressed in time granularity units.
+
+ Attributes:
+ unit (str):
+ The time granularity unit of this time period. The supported
+ units are:
+
+ - "minute"
+
+ - "hour"
+
+ - "day"
+
+ - "week"
+
+ - "month"
+
+ - "year".
+ quantity (int):
+ The number of granularity_units between data points in the
+ training data. If ``granularity_unit`` is ``minute``, can be
+ 1, 5, 10, 15, or 30. For all other values of
+ ``granularity_unit``, must be 1.
+ """
+
+ unit = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ quantity = proto.Field(
+ proto.INT64,
+ number=2,
+ )
+
+ target_column = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ time_series_identifier_column = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+ time_column = proto.Field(
+ proto.STRING,
+ number=3,
+ )
+ transformations = proto.RepeatedField(
+ proto.MESSAGE,
+ number=4,
+ message=Transformation,
+ )
+ optimization_objective = proto.Field(
+ proto.STRING,
+ number=5,
+ )
+ train_budget_milli_node_hours = proto.Field(
+ proto.INT64,
+ number=6,
+ )
+ weight_column = proto.Field(
+ proto.STRING,
+ number=7,
+ )
+ time_series_attribute_columns = proto.RepeatedField(
+ proto.STRING,
+ number=19,
+ )
+ unavailable_at_forecast_columns = proto.RepeatedField(
+ proto.STRING,
+ number=20,
+ )
+ available_at_forecast_columns = proto.RepeatedField(
+ proto.STRING,
+ number=21,
+ )
+ data_granularity = proto.Field(
+ proto.MESSAGE,
+ number=22,
+ message=Granularity,
+ )
+ forecast_horizon = proto.Field(
+ proto.INT64,
+ number=23,
+ )
+ context_window = proto.Field(
+ proto.INT64,
+ number=24,
+ )
+ export_evaluated_data_items_config = proto.Field(
+ proto.MESSAGE,
+ number=15,
+ message=gcastd_export_evaluated_data_items_config.ExportEvaluatedDataItemsConfig,
+ )
+ quantiles = proto.RepeatedField(
+ proto.DOUBLE,
+ number=16,
+ )
+ validation_options = proto.Field(
+ proto.STRING,
+ number=17,
+ )
+ additional_experiments = proto.RepeatedField(
+ proto.STRING,
+ number=25,
+ )
+
+
+class AutoMlForecastingMetadata(proto.Message):
+ r"""Model metadata specific to AutoML Forecasting.
+
+ Attributes:
+ train_cost_milli_node_hours (int):
+ Output only. The actual training cost of the
+ model, expressed in milli node hours, i.e. 1,000
+ value in this field means 1 node hour.
+ Guaranteed to not exceed the train budget.
+ """
+
+ train_cost_milli_node_hours = proto.Field(
+ proto.INT64,
+ number=1,
+ )
+
+
+__all__ = tuple(sorted(__protobuf__.manifest))
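
Pulling the new module together: the sketch below constructs AutoMlForecastingInputs end to end. Every type, field name, and field number involved appears in the file above; the column names, horizon, and budget are illustrative, and the import path is assumed:

# Sketch only: exercising the new forecasting inputs message.
from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types import (
    automl_time_series_forecasting as forecasting,
)

Inputs = forecasting.AutoMlForecastingInputs

inputs = Inputs(
    target_column="sales",
    time_series_identifier_column="store_id",
    time_column="date",
    transformations=[
        # One Transformation per column; each sets exactly one member of
        # the ``transformation_detail`` oneof documented above.
        Inputs.Transformation(
            timestamp=Inputs.Transformation.TimestampTransformation(
                column_name="date",
            ),
        ),
        Inputs.Transformation(
            numeric=Inputs.Transformation.NumericTransformation(
                column_name="sales",
            ),
        ),
    ],
    optimization_objective="minimize-quantile-loss",
    quantiles=[0.25, 0.5, 0.75],  # required for minimize-quantile-loss
    data_granularity=Inputs.Granularity(unit="day", quantity=1),
    forecast_horizon=30,  # 30 days ahead, in data_granularity units
    context_window=90,  # 90 days of history
    train_budget_milli_node_hours=1000,  # 1 node hour, the documented minimum
)
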
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_action_recognition.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_action_recognition.py
index 4132a92bdc..d20cd00015 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_action_recognition.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_action_recognition.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,10 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition",
- manifest={"AutoMlVideoActionRecognition", "AutoMlVideoActionRecognitionInputs",},
+ manifest={
+ "AutoMlVideoActionRecognition",
+ "AutoMlVideoActionRecognitionInputs",
+ },
)
@@ -32,12 +35,15 @@ class AutoMlVideoActionRecognition(proto.Message):
"""
inputs = proto.Field(
- proto.MESSAGE, number=1, message="AutoMlVideoActionRecognitionInputs",
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlVideoActionRecognitionInputs",
)
class AutoMlVideoActionRecognitionInputs(proto.Message):
r"""
+
Attributes:
model_type (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlVideoActionRecognitionInputs.ModelType):
@@ -48,8 +54,14 @@ class ModelType(proto.Enum):
MODEL_TYPE_UNSPECIFIED = 0
CLOUD = 1
MOBILE_VERSATILE_1 = 2
+ MOBILE_JETSON_VERSATILE_1 = 3
+ MOBILE_CORAL_VERSATILE_1 = 4
- model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,)
+ model_type = proto.Field(
+ proto.ENUM,
+ number=1,
+ enum=ModelType,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
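
Beyond the reflow, this hunk adds two edge-device model types (MOBILE_JETSON_VERSATILE_1 and MOBILE_CORAL_VERSATILE_1). A sketch showing they behave as ordinary enum members (import path assumed):

# Sketch only: the new enum values are usable like the existing ones.
from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types import (
    automl_video_action_recognition as avar,
)

Inputs = avar.AutoMlVideoActionRecognitionInputs
inputs = Inputs(model_type=Inputs.ModelType.MOBILE_CORAL_VERSATILE_1)
assert inputs.model_type == 4  # proto-plus enums compare equal to their ints
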
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py
index f5860b0d16..767d27cbc3 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,10 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition",
- manifest={"AutoMlVideoClassification", "AutoMlVideoClassificationInputs",},
+ manifest={
+ "AutoMlVideoClassification",
+ "AutoMlVideoClassificationInputs",
+ },
)
@@ -32,12 +35,15 @@ class AutoMlVideoClassification(proto.Message):
"""
inputs = proto.Field(
- proto.MESSAGE, number=1, message="AutoMlVideoClassificationInputs",
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlVideoClassificationInputs",
)
class AutoMlVideoClassificationInputs(proto.Message):
r"""
+
Attributes:
model_type (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlVideoClassificationInputs.ModelType):
@@ -50,7 +56,11 @@ class ModelType(proto.Enum):
MOBILE_VERSATILE_1 = 2
MOBILE_JETSON_VERSATILE_1 = 3
- model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,)
+ model_type = proto.Field(
+ proto.ENUM,
+ number=1,
+ enum=ModelType,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_object_tracking.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_object_tracking.py
index ea684c9977..bb000644df 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_object_tracking.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_object_tracking.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,10 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition",
- manifest={"AutoMlVideoObjectTracking", "AutoMlVideoObjectTrackingInputs",},
+ manifest={
+ "AutoMlVideoObjectTracking",
+ "AutoMlVideoObjectTrackingInputs",
+ },
)
@@ -32,12 +35,15 @@ class AutoMlVideoObjectTracking(proto.Message):
"""
inputs = proto.Field(
- proto.MESSAGE, number=1, message="AutoMlVideoObjectTrackingInputs",
+ proto.MESSAGE,
+ number=1,
+ message="AutoMlVideoObjectTrackingInputs",
)
class AutoMlVideoObjectTrackingInputs(proto.Message):
r"""
+
Attributes:
model_type (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlVideoObjectTrackingInputs.ModelType):
@@ -53,7 +59,11 @@ class ModelType(proto.Enum):
MOBILE_JETSON_VERSATILE_1 = 5
MOBILE_JETSON_LOW_LATENCY_1 = 6
- model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,)
+ model_type = proto.Field(
+ proto.ENUM,
+ number=1,
+ enum=ModelType,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/export_evaluated_data_items_config.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/export_evaluated_data_items_config.py
index 15046f72c1..b8820db522 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/export_evaluated_data_items_config.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/export_evaluated_data_items_config.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,9 @@
__protobuf__ = proto.module(
package="google.cloud.aiplatform.v1beta1.schema.trainingjob.definition",
- manifest={"ExportEvaluatedDataItemsConfig",},
+ manifest={
+ "ExportEvaluatedDataItemsConfig",
+ },
)
@@ -33,7 +35,6 @@ class ExportEvaluatedDataItemsConfig(proto.Message):
If not specified, then results are exported to the following
auto-created BigQuery table:
-
:export_evaluated_examples__.evaluated_examples
override_existing_table (bool):
If true and an export destination is
@@ -43,8 +44,14 @@ class ExportEvaluatedDataItemsConfig(proto.Message):
operation fails.
"""
- destination_bigquery_uri = proto.Field(proto.STRING, number=1,)
- override_existing_table = proto.Field(proto.BOOL, number=2,)
+ destination_bigquery_uri = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ override_existing_table = proto.Field(
+ proto.BOOL,
+ number=2,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
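
ExportEvaluatedDataItemsConfig is the message that both the tables and forecasting inputs above embed for exporting test-set predictions. A sketch of wiring it in; the import path and BigQuery URI are illustrative:

# Sketch only: an export destination plus the overwrite flag documented above.
from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types import (
    export_evaluated_data_items_config as eedic,
)

export_config = eedic.ExportEvaluatedDataItemsConfig(
    destination_bigquery_uri="bq://my-project.my_dataset.eval_items",
    override_existing_table=False,  # fail rather than overwrite an existing table
)
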
diff --git a/google/cloud/aiplatform/version.py b/google/cloud/aiplatform/version.py
new file mode 100644
index 0000000000..21995202bd
--- /dev/null
+++ b/google/cloud/aiplatform/version.py
@@ -0,0 +1,18 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+__version__ = "1.14.0"
diff --git a/google/cloud/aiplatform_v1/__init__.py b/google/cloud/aiplatform_v1/__init__.py
index d765cc599d..d4e36a264c 100644
--- a/google/cloud/aiplatform_v1/__init__.py
+++ b/google/cloud/aiplatform_v1/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,8 +18,22 @@
from .services.dataset_service import DatasetServiceAsyncClient
from .services.endpoint_service import EndpointServiceClient
from .services.endpoint_service import EndpointServiceAsyncClient
+from .services.featurestore_online_serving_service import (
+ FeaturestoreOnlineServingServiceClient,
+)
+from .services.featurestore_online_serving_service import (
+ FeaturestoreOnlineServingServiceAsyncClient,
+)
+from .services.featurestore_service import FeaturestoreServiceClient
+from .services.featurestore_service import FeaturestoreServiceAsyncClient
+from .services.index_endpoint_service import IndexEndpointServiceClient
+from .services.index_endpoint_service import IndexEndpointServiceAsyncClient
+from .services.index_service import IndexServiceClient
+from .services.index_service import IndexServiceAsyncClient
from .services.job_service import JobServiceClient
from .services.job_service import JobServiceAsyncClient
+from .services.metadata_service import MetadataServiceClient
+from .services.metadata_service import MetadataServiceAsyncClient
from .services.migration_service import MigrationServiceClient
from .services.migration_service import MigrationServiceAsyncClient
from .services.model_service import ModelServiceClient
@@ -30,12 +44,18 @@
from .services.prediction_service import PredictionServiceAsyncClient
from .services.specialist_pool_service import SpecialistPoolServiceClient
from .services.specialist_pool_service import SpecialistPoolServiceAsyncClient
+from .services.tensorboard_service import TensorboardServiceClient
+from .services.tensorboard_service import TensorboardServiceAsyncClient
+from .services.vizier_service import VizierServiceClient
+from .services.vizier_service import VizierServiceAsyncClient
from .types.accelerator_type import AcceleratorType
from .types.annotation import Annotation
from .types.annotation_spec import AnnotationSpec
+from .types.artifact import Artifact
from .types.batch_prediction_job import BatchPredictionJob
from .types.completion_stats import CompletionStats
+from .types.context import Context
from .types.custom_job import ContainerSpec
from .types.custom_job import CustomJob
from .types.custom_job import CustomJobSpec
@@ -68,10 +88,13 @@
from .types.dataset_service import ListDatasetsRequest
from .types.dataset_service import ListDatasetsResponse
from .types.dataset_service import UpdateDatasetRequest
+from .types.deployed_index_ref import DeployedIndexRef
from .types.deployed_model_ref import DeployedModelRef
from .types.encryption_spec import EncryptionSpec
from .types.endpoint import DeployedModel
from .types.endpoint import Endpoint
+from .types.endpoint import PredictRequestResponseLoggingConfig
+from .types.endpoint import PrivateEndpoints
from .types.endpoint_service import CreateEndpointOperationMetadata
from .types.endpoint_service import CreateEndpointRequest
from .types.endpoint_service import DeleteEndpointRequest
@@ -85,13 +108,116 @@
from .types.endpoint_service import UndeployModelRequest
from .types.endpoint_service import UndeployModelResponse
from .types.endpoint_service import UpdateEndpointRequest
+from .types.entity_type import EntityType
from .types.env_var import EnvVar
+from .types.event import Event
+from .types.execution import Execution
+from .types.explanation import Attribution
+from .types.explanation import BlurBaselineConfig
+from .types.explanation import ExamplesOverride
+from .types.explanation import ExamplesRestrictionsNamespace
+from .types.explanation import Explanation
+from .types.explanation import ExplanationMetadataOverride
+from .types.explanation import ExplanationParameters
+from .types.explanation import ExplanationSpec
+from .types.explanation import ExplanationSpecOverride
+from .types.explanation import FeatureNoiseSigma
+from .types.explanation import IntegratedGradientsAttribution
+from .types.explanation import ModelExplanation
+from .types.explanation import Neighbor
+from .types.explanation import SampledShapleyAttribution
+from .types.explanation import SmoothGradConfig
+from .types.explanation import XraiAttribution
+from .types.explanation_metadata import ExplanationMetadata
+from .types.feature import Feature
+from .types.feature_monitoring_stats import FeatureStatsAnomaly
+from .types.feature_selector import FeatureSelector
+from .types.feature_selector import IdMatcher
+from .types.featurestore import Featurestore
+from .types.featurestore_monitoring import FeaturestoreMonitoringConfig
+from .types.featurestore_online_service import FeatureValue
+from .types.featurestore_online_service import FeatureValueList
+from .types.featurestore_online_service import ReadFeatureValuesRequest
+from .types.featurestore_online_service import ReadFeatureValuesResponse
+from .types.featurestore_online_service import StreamingReadFeatureValuesRequest
+from .types.featurestore_service import BatchCreateFeaturesOperationMetadata
+from .types.featurestore_service import BatchCreateFeaturesRequest
+from .types.featurestore_service import BatchCreateFeaturesResponse
+from .types.featurestore_service import BatchReadFeatureValuesOperationMetadata
+from .types.featurestore_service import BatchReadFeatureValuesRequest
+from .types.featurestore_service import BatchReadFeatureValuesResponse
+from .types.featurestore_service import CreateEntityTypeOperationMetadata
+from .types.featurestore_service import CreateEntityTypeRequest
+from .types.featurestore_service import CreateFeatureOperationMetadata
+from .types.featurestore_service import CreateFeatureRequest
+from .types.featurestore_service import CreateFeaturestoreOperationMetadata
+from .types.featurestore_service import CreateFeaturestoreRequest
+from .types.featurestore_service import DeleteEntityTypeRequest
+from .types.featurestore_service import DeleteFeatureRequest
+from .types.featurestore_service import DeleteFeaturestoreRequest
+from .types.featurestore_service import DestinationFeatureSetting
+from .types.featurestore_service import ExportFeatureValuesOperationMetadata
+from .types.featurestore_service import ExportFeatureValuesRequest
+from .types.featurestore_service import ExportFeatureValuesResponse
+from .types.featurestore_service import FeatureValueDestination
+from .types.featurestore_service import GetEntityTypeRequest
+from .types.featurestore_service import GetFeatureRequest
+from .types.featurestore_service import GetFeaturestoreRequest
+from .types.featurestore_service import ImportFeatureValuesOperationMetadata
+from .types.featurestore_service import ImportFeatureValuesRequest
+from .types.featurestore_service import ImportFeatureValuesResponse
+from .types.featurestore_service import ListEntityTypesRequest
+from .types.featurestore_service import ListEntityTypesResponse
+from .types.featurestore_service import ListFeaturesRequest
+from .types.featurestore_service import ListFeaturesResponse
+from .types.featurestore_service import ListFeaturestoresRequest
+from .types.featurestore_service import ListFeaturestoresResponse
+from .types.featurestore_service import SearchFeaturesRequest
+from .types.featurestore_service import SearchFeaturesResponse
+from .types.featurestore_service import UpdateEntityTypeRequest
+from .types.featurestore_service import UpdateFeatureRequest
+from .types.featurestore_service import UpdateFeaturestoreOperationMetadata
+from .types.featurestore_service import UpdateFeaturestoreRequest
from .types.hyperparameter_tuning_job import HyperparameterTuningJob
+from .types.index import Index
+from .types.index_endpoint import DeployedIndex
+from .types.index_endpoint import DeployedIndexAuthConfig
+from .types.index_endpoint import IndexEndpoint
+from .types.index_endpoint import IndexPrivateEndpoints
+from .types.index_endpoint_service import CreateIndexEndpointOperationMetadata
+from .types.index_endpoint_service import CreateIndexEndpointRequest
+from .types.index_endpoint_service import DeleteIndexEndpointRequest
+from .types.index_endpoint_service import DeployIndexOperationMetadata
+from .types.index_endpoint_service import DeployIndexRequest
+from .types.index_endpoint_service import DeployIndexResponse
+from .types.index_endpoint_service import GetIndexEndpointRequest
+from .types.index_endpoint_service import ListIndexEndpointsRequest
+from .types.index_endpoint_service import ListIndexEndpointsResponse
+from .types.index_endpoint_service import MutateDeployedIndexOperationMetadata
+from .types.index_endpoint_service import MutateDeployedIndexRequest
+from .types.index_endpoint_service import MutateDeployedIndexResponse
+from .types.index_endpoint_service import UndeployIndexOperationMetadata
+from .types.index_endpoint_service import UndeployIndexRequest
+from .types.index_endpoint_service import UndeployIndexResponse
+from .types.index_endpoint_service import UpdateIndexEndpointRequest
+from .types.index_service import CreateIndexOperationMetadata
+from .types.index_service import CreateIndexRequest
+from .types.index_service import DeleteIndexRequest
+from .types.index_service import GetIndexRequest
+from .types.index_service import ListIndexesRequest
+from .types.index_service import ListIndexesResponse
+from .types.index_service import NearestNeighborSearchOperationMetadata
+from .types.index_service import UpdateIndexOperationMetadata
+from .types.index_service import UpdateIndexRequest
+from .types.io import AvroSource
from .types.io import BigQueryDestination
from .types.io import BigQuerySource
from .types.io import ContainerRegistryDestination
+from .types.io import CsvDestination
+from .types.io import CsvSource
from .types.io import GcsDestination
from .types.io import GcsSource
+from .types.io import TFRecordDestination
from .types.job_service import CancelBatchPredictionJobRequest
from .types.job_service import CancelCustomJobRequest
from .types.job_service import CancelDataLabelingJobRequest
@@ -100,14 +226,17 @@
from .types.job_service import CreateCustomJobRequest
from .types.job_service import CreateDataLabelingJobRequest
from .types.job_service import CreateHyperparameterTuningJobRequest
+from .types.job_service import CreateModelDeploymentMonitoringJobRequest
from .types.job_service import DeleteBatchPredictionJobRequest
from .types.job_service import DeleteCustomJobRequest
from .types.job_service import DeleteDataLabelingJobRequest
from .types.job_service import DeleteHyperparameterTuningJobRequest
+from .types.job_service import DeleteModelDeploymentMonitoringJobRequest
from .types.job_service import GetBatchPredictionJobRequest
from .types.job_service import GetCustomJobRequest
from .types.job_service import GetDataLabelingJobRequest
from .types.job_service import GetHyperparameterTuningJobRequest
+from .types.job_service import GetModelDeploymentMonitoringJobRequest
from .types.job_service import ListBatchPredictionJobsRequest
from .types.job_service import ListBatchPredictionJobsResponse
from .types.job_service import ListCustomJobsRequest
@@ -116,14 +245,74 @@
from .types.job_service import ListDataLabelingJobsResponse
from .types.job_service import ListHyperparameterTuningJobsRequest
from .types.job_service import ListHyperparameterTuningJobsResponse
+from .types.job_service import ListModelDeploymentMonitoringJobsRequest
+from .types.job_service import ListModelDeploymentMonitoringJobsResponse
+from .types.job_service import PauseModelDeploymentMonitoringJobRequest
+from .types.job_service import ResumeModelDeploymentMonitoringJobRequest
+from .types.job_service import SearchModelDeploymentMonitoringStatsAnomaliesRequest
+from .types.job_service import SearchModelDeploymentMonitoringStatsAnomaliesResponse
+from .types.job_service import UpdateModelDeploymentMonitoringJobOperationMetadata
+from .types.job_service import UpdateModelDeploymentMonitoringJobRequest
from .types.job_state import JobState
+from .types.lineage_subgraph import LineageSubgraph
from .types.machine_resources import AutomaticResources
+from .types.machine_resources import AutoscalingMetricSpec
from .types.machine_resources import BatchDedicatedResources
from .types.machine_resources import DedicatedResources
from .types.machine_resources import DiskSpec
from .types.machine_resources import MachineSpec
+from .types.machine_resources import NfsMount
from .types.machine_resources import ResourcesConsumed
from .types.manual_batch_tuning_parameters import ManualBatchTuningParameters
+from .types.metadata_schema import MetadataSchema
+from .types.metadata_service import AddContextArtifactsAndExecutionsRequest
+from .types.metadata_service import AddContextArtifactsAndExecutionsResponse
+from .types.metadata_service import AddContextChildrenRequest
+from .types.metadata_service import AddContextChildrenResponse
+from .types.metadata_service import AddExecutionEventsRequest
+from .types.metadata_service import AddExecutionEventsResponse
+from .types.metadata_service import CreateArtifactRequest
+from .types.metadata_service import CreateContextRequest
+from .types.metadata_service import CreateExecutionRequest
+from .types.metadata_service import CreateMetadataSchemaRequest
+from .types.metadata_service import CreateMetadataStoreOperationMetadata
+from .types.metadata_service import CreateMetadataStoreRequest
+from .types.metadata_service import DeleteArtifactRequest
+from .types.metadata_service import DeleteContextRequest
+from .types.metadata_service import DeleteExecutionRequest
+from .types.metadata_service import DeleteMetadataStoreOperationMetadata
+from .types.metadata_service import DeleteMetadataStoreRequest
+from .types.metadata_service import GetArtifactRequest
+from .types.metadata_service import GetContextRequest
+from .types.metadata_service import GetExecutionRequest
+from .types.metadata_service import GetMetadataSchemaRequest
+from .types.metadata_service import GetMetadataStoreRequest
+from .types.metadata_service import ListArtifactsRequest
+from .types.metadata_service import ListArtifactsResponse
+from .types.metadata_service import ListContextsRequest
+from .types.metadata_service import ListContextsResponse
+from .types.metadata_service import ListExecutionsRequest
+from .types.metadata_service import ListExecutionsResponse
+from .types.metadata_service import ListMetadataSchemasRequest
+from .types.metadata_service import ListMetadataSchemasResponse
+from .types.metadata_service import ListMetadataStoresRequest
+from .types.metadata_service import ListMetadataStoresResponse
+from .types.metadata_service import PurgeArtifactsMetadata
+from .types.metadata_service import PurgeArtifactsRequest
+from .types.metadata_service import PurgeArtifactsResponse
+from .types.metadata_service import PurgeContextsMetadata
+from .types.metadata_service import PurgeContextsRequest
+from .types.metadata_service import PurgeContextsResponse
+from .types.metadata_service import PurgeExecutionsMetadata
+from .types.metadata_service import PurgeExecutionsRequest
+from .types.metadata_service import PurgeExecutionsResponse
+from .types.metadata_service import QueryArtifactLineageSubgraphRequest
+from .types.metadata_service import QueryContextLineageSubgraphRequest
+from .types.metadata_service import QueryExecutionInputsAndOutputsRequest
+from .types.metadata_service import UpdateArtifactRequest
+from .types.metadata_service import UpdateContextRequest
+from .types.metadata_service import UpdateExecutionRequest
+from .types.metadata_store import MetadataStore
from .types.migratable_resource import MigratableResource
from .types.migration_service import BatchMigrateResourcesOperationMetadata
from .types.migration_service import BatchMigrateResourcesRequest
@@ -136,36 +325,74 @@
from .types.model import ModelContainerSpec
from .types.model import Port
from .types.model import PredictSchemata
+from .types.model_deployment_monitoring_job import (
+ ModelDeploymentMonitoringBigQueryTable,
+)
+from .types.model_deployment_monitoring_job import ModelDeploymentMonitoringJob
+from .types.model_deployment_monitoring_job import (
+ ModelDeploymentMonitoringObjectiveConfig,
+)
+from .types.model_deployment_monitoring_job import (
+ ModelDeploymentMonitoringScheduleConfig,
+)
+from .types.model_deployment_monitoring_job import ModelMonitoringStatsAnomalies
+from .types.model_deployment_monitoring_job import (
+ ModelDeploymentMonitoringObjectiveType,
+)
from .types.model_evaluation import ModelEvaluation
from .types.model_evaluation_slice import ModelEvaluationSlice
+from .types.model_monitoring import ModelMonitoringAlertConfig
+from .types.model_monitoring import ModelMonitoringObjectiveConfig
+from .types.model_monitoring import SamplingStrategy
+from .types.model_monitoring import ThresholdConfig
from .types.model_service import DeleteModelRequest
+from .types.model_service import DeleteModelVersionRequest
from .types.model_service import ExportModelOperationMetadata
from .types.model_service import ExportModelRequest
from .types.model_service import ExportModelResponse
from .types.model_service import GetModelEvaluationRequest
from .types.model_service import GetModelEvaluationSliceRequest
from .types.model_service import GetModelRequest
+from .types.model_service import ImportModelEvaluationRequest
from .types.model_service import ListModelEvaluationSlicesRequest
from .types.model_service import ListModelEvaluationSlicesResponse
from .types.model_service import ListModelEvaluationsRequest
from .types.model_service import ListModelEvaluationsResponse
from .types.model_service import ListModelsRequest
from .types.model_service import ListModelsResponse
+from .types.model_service import ListModelVersionsRequest
+from .types.model_service import ListModelVersionsResponse
+from .types.model_service import MergeVersionAliasesRequest
from .types.model_service import UpdateModelRequest
from .types.model_service import UploadModelOperationMetadata
from .types.model_service import UploadModelRequest
from .types.model_service import UploadModelResponse
from .types.operation import DeleteOperationMetadata
from .types.operation import GenericOperationMetadata
+from .types.pipeline_failure_policy import PipelineFailurePolicy
+from .types.pipeline_job import PipelineJob
+from .types.pipeline_job import PipelineJobDetail
+from .types.pipeline_job import PipelineTaskDetail
+from .types.pipeline_job import PipelineTaskExecutorDetail
+from .types.pipeline_job import PipelineTemplateMetadata
+from .types.pipeline_service import CancelPipelineJobRequest
from .types.pipeline_service import CancelTrainingPipelineRequest
+from .types.pipeline_service import CreatePipelineJobRequest
from .types.pipeline_service import CreateTrainingPipelineRequest
+from .types.pipeline_service import DeletePipelineJobRequest
from .types.pipeline_service import DeleteTrainingPipelineRequest
+from .types.pipeline_service import GetPipelineJobRequest
from .types.pipeline_service import GetTrainingPipelineRequest
+from .types.pipeline_service import ListPipelineJobsRequest
+from .types.pipeline_service import ListPipelineJobsResponse
from .types.pipeline_service import ListTrainingPipelinesRequest
from .types.pipeline_service import ListTrainingPipelinesResponse
from .types.pipeline_state import PipelineState
+from .types.prediction_service import ExplainRequest
+from .types.prediction_service import ExplainResponse
from .types.prediction_service import PredictRequest
from .types.prediction_service import PredictResponse
+from .types.prediction_service import RawPredictRequest
from .types.specialist_pool import SpecialistPool
from .types.specialist_pool_service import CreateSpecialistPoolOperationMetadata
from .types.specialist_pool_service import CreateSpecialistPoolRequest
@@ -176,56 +403,203 @@
from .types.specialist_pool_service import UpdateSpecialistPoolOperationMetadata
from .types.specialist_pool_service import UpdateSpecialistPoolRequest
from .types.study import Measurement
+from .types.study import Study
from .types.study import StudySpec
from .types.study import Trial
+from .types.tensorboard import Tensorboard
+from .types.tensorboard_data import Scalar
+from .types.tensorboard_data import TensorboardBlob
+from .types.tensorboard_data import TensorboardBlobSequence
+from .types.tensorboard_data import TensorboardTensor
+from .types.tensorboard_data import TimeSeriesData
+from .types.tensorboard_data import TimeSeriesDataPoint
+from .types.tensorboard_experiment import TensorboardExperiment
+from .types.tensorboard_run import TensorboardRun
+from .types.tensorboard_service import BatchCreateTensorboardRunsRequest
+from .types.tensorboard_service import BatchCreateTensorboardRunsResponse
+from .types.tensorboard_service import BatchCreateTensorboardTimeSeriesRequest
+from .types.tensorboard_service import BatchCreateTensorboardTimeSeriesResponse
+from .types.tensorboard_service import BatchReadTensorboardTimeSeriesDataRequest
+from .types.tensorboard_service import BatchReadTensorboardTimeSeriesDataResponse
+from .types.tensorboard_service import CreateTensorboardExperimentRequest
+from .types.tensorboard_service import CreateTensorboardOperationMetadata
+from .types.tensorboard_service import CreateTensorboardRequest
+from .types.tensorboard_service import CreateTensorboardRunRequest
+from .types.tensorboard_service import CreateTensorboardTimeSeriesRequest
+from .types.tensorboard_service import DeleteTensorboardExperimentRequest
+from .types.tensorboard_service import DeleteTensorboardRequest
+from .types.tensorboard_service import DeleteTensorboardRunRequest
+from .types.tensorboard_service import DeleteTensorboardTimeSeriesRequest
+from .types.tensorboard_service import ExportTensorboardTimeSeriesDataRequest
+from .types.tensorboard_service import ExportTensorboardTimeSeriesDataResponse
+from .types.tensorboard_service import GetTensorboardExperimentRequest
+from .types.tensorboard_service import GetTensorboardRequest
+from .types.tensorboard_service import GetTensorboardRunRequest
+from .types.tensorboard_service import GetTensorboardTimeSeriesRequest
+from .types.tensorboard_service import ListTensorboardExperimentsRequest
+from .types.tensorboard_service import ListTensorboardExperimentsResponse
+from .types.tensorboard_service import ListTensorboardRunsRequest
+from .types.tensorboard_service import ListTensorboardRunsResponse
+from .types.tensorboard_service import ListTensorboardsRequest
+from .types.tensorboard_service import ListTensorboardsResponse
+from .types.tensorboard_service import ListTensorboardTimeSeriesRequest
+from .types.tensorboard_service import ListTensorboardTimeSeriesResponse
+from .types.tensorboard_service import ReadTensorboardBlobDataRequest
+from .types.tensorboard_service import ReadTensorboardBlobDataResponse
+from .types.tensorboard_service import ReadTensorboardTimeSeriesDataRequest
+from .types.tensorboard_service import ReadTensorboardTimeSeriesDataResponse
+from .types.tensorboard_service import UpdateTensorboardExperimentRequest
+from .types.tensorboard_service import UpdateTensorboardOperationMetadata
+from .types.tensorboard_service import UpdateTensorboardRequest
+from .types.tensorboard_service import UpdateTensorboardRunRequest
+from .types.tensorboard_service import UpdateTensorboardTimeSeriesRequest
+from .types.tensorboard_service import WriteTensorboardExperimentDataRequest
+from .types.tensorboard_service import WriteTensorboardExperimentDataResponse
+from .types.tensorboard_service import WriteTensorboardRunDataRequest
+from .types.tensorboard_service import WriteTensorboardRunDataResponse
+from .types.tensorboard_time_series import TensorboardTimeSeries
from .types.training_pipeline import FilterSplit
from .types.training_pipeline import FractionSplit
from .types.training_pipeline import InputDataConfig
from .types.training_pipeline import PredefinedSplit
+from .types.training_pipeline import StratifiedSplit
from .types.training_pipeline import TimestampSplit
from .types.training_pipeline import TrainingPipeline
+from .types.types import BoolArray
+from .types.types import DoubleArray
+from .types.types import Int64Array
+from .types.types import StringArray
+from .types.unmanaged_container_model import UnmanagedContainerModel
from .types.user_action_reference import UserActionReference
+from .types.value import Value
+from .types.vizier_service import AddTrialMeasurementRequest
+from .types.vizier_service import CheckTrialEarlyStoppingStateMetatdata
+from .types.vizier_service import CheckTrialEarlyStoppingStateRequest
+from .types.vizier_service import CheckTrialEarlyStoppingStateResponse
+from .types.vizier_service import CompleteTrialRequest
+from .types.vizier_service import CreateStudyRequest
+from .types.vizier_service import CreateTrialRequest
+from .types.vizier_service import DeleteStudyRequest
+from .types.vizier_service import DeleteTrialRequest
+from .types.vizier_service import GetStudyRequest
+from .types.vizier_service import GetTrialRequest
+from .types.vizier_service import ListOptimalTrialsRequest
+from .types.vizier_service import ListOptimalTrialsResponse
+from .types.vizier_service import ListStudiesRequest
+from .types.vizier_service import ListStudiesResponse
+from .types.vizier_service import ListTrialsRequest
+from .types.vizier_service import ListTrialsResponse
+from .types.vizier_service import LookupStudyRequest
+from .types.vizier_service import StopTrialRequest
+from .types.vizier_service import SuggestTrialsMetadata
+from .types.vizier_service import SuggestTrialsRequest
+from .types.vizier_service import SuggestTrialsResponse
__all__ = (
"DatasetServiceAsyncClient",
"EndpointServiceAsyncClient",
+ "FeaturestoreOnlineServingServiceAsyncClient",
+ "FeaturestoreServiceAsyncClient",
+ "IndexEndpointServiceAsyncClient",
+ "IndexServiceAsyncClient",
"JobServiceAsyncClient",
+ "MetadataServiceAsyncClient",
"MigrationServiceAsyncClient",
"ModelServiceAsyncClient",
"PipelineServiceAsyncClient",
"PredictionServiceAsyncClient",
"SpecialistPoolServiceAsyncClient",
+ "TensorboardServiceAsyncClient",
+ "VizierServiceAsyncClient",
"AcceleratorType",
"ActiveLearningConfig",
+ "AddContextArtifactsAndExecutionsRequest",
+ "AddContextArtifactsAndExecutionsResponse",
+ "AddContextChildrenRequest",
+ "AddContextChildrenResponse",
+ "AddExecutionEventsRequest",
+ "AddExecutionEventsResponse",
+ "AddTrialMeasurementRequest",
"Annotation",
"AnnotationSpec",
+ "Artifact",
+ "Attribution",
"AutomaticResources",
+ "AutoscalingMetricSpec",
+ "AvroSource",
+ "BatchCreateFeaturesOperationMetadata",
+ "BatchCreateFeaturesRequest",
+ "BatchCreateFeaturesResponse",
+ "BatchCreateTensorboardRunsRequest",
+ "BatchCreateTensorboardRunsResponse",
+ "BatchCreateTensorboardTimeSeriesRequest",
+ "BatchCreateTensorboardTimeSeriesResponse",
"BatchDedicatedResources",
"BatchMigrateResourcesOperationMetadata",
"BatchMigrateResourcesRequest",
"BatchMigrateResourcesResponse",
"BatchPredictionJob",
+ "BatchReadFeatureValuesOperationMetadata",
+ "BatchReadFeatureValuesRequest",
+ "BatchReadFeatureValuesResponse",
+ "BatchReadTensorboardTimeSeriesDataRequest",
+ "BatchReadTensorboardTimeSeriesDataResponse",
"BigQueryDestination",
"BigQuerySource",
+ "BlurBaselineConfig",
+ "BoolArray",
"CancelBatchPredictionJobRequest",
"CancelCustomJobRequest",
"CancelDataLabelingJobRequest",
"CancelHyperparameterTuningJobRequest",
+ "CancelPipelineJobRequest",
"CancelTrainingPipelineRequest",
+ "CheckTrialEarlyStoppingStateMetatdata",
+ "CheckTrialEarlyStoppingStateRequest",
+ "CheckTrialEarlyStoppingStateResponse",
+ "CompleteTrialRequest",
"CompletionStats",
"ContainerRegistryDestination",
"ContainerSpec",
+ "Context",
+ "CreateArtifactRequest",
"CreateBatchPredictionJobRequest",
+ "CreateContextRequest",
"CreateCustomJobRequest",
"CreateDataLabelingJobRequest",
"CreateDatasetOperationMetadata",
"CreateDatasetRequest",
"CreateEndpointOperationMetadata",
"CreateEndpointRequest",
+ "CreateEntityTypeOperationMetadata",
+ "CreateEntityTypeRequest",
+ "CreateExecutionRequest",
+ "CreateFeatureOperationMetadata",
+ "CreateFeatureRequest",
+ "CreateFeaturestoreOperationMetadata",
+ "CreateFeaturestoreRequest",
"CreateHyperparameterTuningJobRequest",
+ "CreateIndexEndpointOperationMetadata",
+ "CreateIndexEndpointRequest",
+ "CreateIndexOperationMetadata",
+ "CreateIndexRequest",
+ "CreateMetadataSchemaRequest",
+ "CreateMetadataStoreOperationMetadata",
+ "CreateMetadataStoreRequest",
+ "CreateModelDeploymentMonitoringJobRequest",
+ "CreatePipelineJobRequest",
"CreateSpecialistPoolOperationMetadata",
"CreateSpecialistPoolRequest",
+ "CreateStudyRequest",
+ "CreateTensorboardExperimentRequest",
+ "CreateTensorboardOperationMetadata",
+ "CreateTensorboardRequest",
+ "CreateTensorboardRunRequest",
+ "CreateTensorboardTimeSeriesRequest",
"CreateTrainingPipelineRequest",
+ "CreateTrialRequest",
+ "CsvDestination",
+ "CsvSource",
"CustomJob",
"CustomJobSpec",
"DataItem",
@@ -233,62 +607,153 @@
"Dataset",
"DatasetServiceClient",
"DedicatedResources",
+ "DeleteArtifactRequest",
"DeleteBatchPredictionJobRequest",
+ "DeleteContextRequest",
"DeleteCustomJobRequest",
"DeleteDataLabelingJobRequest",
"DeleteDatasetRequest",
"DeleteEndpointRequest",
+ "DeleteEntityTypeRequest",
+ "DeleteExecutionRequest",
+ "DeleteFeatureRequest",
+ "DeleteFeaturestoreRequest",
"DeleteHyperparameterTuningJobRequest",
+ "DeleteIndexEndpointRequest",
+ "DeleteIndexRequest",
+ "DeleteMetadataStoreOperationMetadata",
+ "DeleteMetadataStoreRequest",
+ "DeleteModelDeploymentMonitoringJobRequest",
"DeleteModelRequest",
+ "DeleteModelVersionRequest",
"DeleteOperationMetadata",
+ "DeletePipelineJobRequest",
"DeleteSpecialistPoolRequest",
+ "DeleteStudyRequest",
+ "DeleteTensorboardExperimentRequest",
+ "DeleteTensorboardRequest",
+ "DeleteTensorboardRunRequest",
+ "DeleteTensorboardTimeSeriesRequest",
"DeleteTrainingPipelineRequest",
+ "DeleteTrialRequest",
+ "DeployIndexOperationMetadata",
+ "DeployIndexRequest",
+ "DeployIndexResponse",
"DeployModelOperationMetadata",
"DeployModelRequest",
"DeployModelResponse",
+ "DeployedIndex",
+ "DeployedIndexAuthConfig",
+ "DeployedIndexRef",
"DeployedModel",
"DeployedModelRef",
+ "DestinationFeatureSetting",
"DiskSpec",
+ "DoubleArray",
"EncryptionSpec",
"Endpoint",
"EndpointServiceClient",
+ "EntityType",
"EnvVar",
+ "Event",
+ "ExamplesOverride",
+ "ExamplesRestrictionsNamespace",
+ "Execution",
+ "ExplainRequest",
+ "ExplainResponse",
+ "Explanation",
+ "ExplanationMetadata",
+ "ExplanationMetadataOverride",
+ "ExplanationParameters",
+ "ExplanationSpec",
+ "ExplanationSpecOverride",
"ExportDataConfig",
"ExportDataOperationMetadata",
"ExportDataRequest",
"ExportDataResponse",
+ "ExportFeatureValuesOperationMetadata",
+ "ExportFeatureValuesRequest",
+ "ExportFeatureValuesResponse",
"ExportModelOperationMetadata",
"ExportModelRequest",
"ExportModelResponse",
+ "ExportTensorboardTimeSeriesDataRequest",
+ "ExportTensorboardTimeSeriesDataResponse",
+ "Feature",
+ "FeatureNoiseSigma",
+ "FeatureSelector",
+ "FeatureStatsAnomaly",
+ "FeatureValue",
+ "FeatureValueDestination",
+ "FeatureValueList",
+ "Featurestore",
+ "FeaturestoreMonitoringConfig",
+ "FeaturestoreOnlineServingServiceClient",
+ "FeaturestoreServiceClient",
"FilterSplit",
"FractionSplit",
"GcsDestination",
"GcsSource",
"GenericOperationMetadata",
"GetAnnotationSpecRequest",
+ "GetArtifactRequest",
"GetBatchPredictionJobRequest",
+ "GetContextRequest",
"GetCustomJobRequest",
"GetDataLabelingJobRequest",
"GetDatasetRequest",
"GetEndpointRequest",
+ "GetEntityTypeRequest",
+ "GetExecutionRequest",
+ "GetFeatureRequest",
+ "GetFeaturestoreRequest",
"GetHyperparameterTuningJobRequest",
+ "GetIndexEndpointRequest",
+ "GetIndexRequest",
+ "GetMetadataSchemaRequest",
+ "GetMetadataStoreRequest",
+ "GetModelDeploymentMonitoringJobRequest",
"GetModelEvaluationRequest",
"GetModelEvaluationSliceRequest",
"GetModelRequest",
+ "GetPipelineJobRequest",
"GetSpecialistPoolRequest",
+ "GetStudyRequest",
+ "GetTensorboardExperimentRequest",
+ "GetTensorboardRequest",
+ "GetTensorboardRunRequest",
+ "GetTensorboardTimeSeriesRequest",
"GetTrainingPipelineRequest",
+ "GetTrialRequest",
"HyperparameterTuningJob",
+ "IdMatcher",
"ImportDataConfig",
"ImportDataOperationMetadata",
"ImportDataRequest",
"ImportDataResponse",
+ "ImportFeatureValuesOperationMetadata",
+ "ImportFeatureValuesRequest",
+ "ImportFeatureValuesResponse",
+ "ImportModelEvaluationRequest",
+ "Index",
+ "IndexEndpoint",
+ "IndexEndpointServiceClient",
+ "IndexPrivateEndpoints",
+ "IndexServiceClient",
"InputDataConfig",
+ "Int64Array",
+ "IntegratedGradientsAttribution",
"JobServiceClient",
"JobState",
+ "LineageSubgraph",
"ListAnnotationsRequest",
"ListAnnotationsResponse",
+ "ListArtifactsRequest",
+ "ListArtifactsResponse",
"ListBatchPredictionJobsRequest",
"ListBatchPredictionJobsResponse",
+ "ListContextsRequest",
+ "ListContextsResponse",
"ListCustomJobsRequest",
"ListCustomJobsResponse",
"ListDataItemsRequest",
@@ -299,62 +764,203 @@
"ListDatasetsResponse",
"ListEndpointsRequest",
"ListEndpointsResponse",
+ "ListEntityTypesRequest",
+ "ListEntityTypesResponse",
+ "ListExecutionsRequest",
+ "ListExecutionsResponse",
+ "ListFeaturesRequest",
+ "ListFeaturesResponse",
+ "ListFeaturestoresRequest",
+ "ListFeaturestoresResponse",
"ListHyperparameterTuningJobsRequest",
"ListHyperparameterTuningJobsResponse",
+ "ListIndexEndpointsRequest",
+ "ListIndexEndpointsResponse",
+ "ListIndexesRequest",
+ "ListIndexesResponse",
+ "ListMetadataSchemasRequest",
+ "ListMetadataSchemasResponse",
+ "ListMetadataStoresRequest",
+ "ListMetadataStoresResponse",
+ "ListModelDeploymentMonitoringJobsRequest",
+ "ListModelDeploymentMonitoringJobsResponse",
"ListModelEvaluationSlicesRequest",
"ListModelEvaluationSlicesResponse",
"ListModelEvaluationsRequest",
"ListModelEvaluationsResponse",
+ "ListModelVersionsRequest",
+ "ListModelVersionsResponse",
"ListModelsRequest",
"ListModelsResponse",
+ "ListOptimalTrialsRequest",
+ "ListOptimalTrialsResponse",
+ "ListPipelineJobsRequest",
+ "ListPipelineJobsResponse",
"ListSpecialistPoolsRequest",
"ListSpecialistPoolsResponse",
+ "ListStudiesRequest",
+ "ListStudiesResponse",
+ "ListTensorboardExperimentsRequest",
+ "ListTensorboardExperimentsResponse",
+ "ListTensorboardRunsRequest",
+ "ListTensorboardRunsResponse",
+ "ListTensorboardTimeSeriesRequest",
+ "ListTensorboardTimeSeriesResponse",
+ "ListTensorboardsRequest",
+ "ListTensorboardsResponse",
"ListTrainingPipelinesRequest",
"ListTrainingPipelinesResponse",
+ "ListTrialsRequest",
+ "ListTrialsResponse",
+ "LookupStudyRequest",
"MachineSpec",
"ManualBatchTuningParameters",
"Measurement",
+ "MergeVersionAliasesRequest",
+ "MetadataSchema",
+ "MetadataServiceClient",
+ "MetadataStore",
"MigratableResource",
"MigrateResourceRequest",
"MigrateResourceResponse",
"MigrationServiceClient",
"Model",
"ModelContainerSpec",
+ "ModelDeploymentMonitoringBigQueryTable",
+ "ModelDeploymentMonitoringJob",
+ "ModelDeploymentMonitoringObjectiveConfig",
+ "ModelDeploymentMonitoringObjectiveType",
+ "ModelDeploymentMonitoringScheduleConfig",
"ModelEvaluation",
"ModelEvaluationSlice",
+ "ModelExplanation",
+ "ModelMonitoringAlertConfig",
+ "ModelMonitoringObjectiveConfig",
+ "ModelMonitoringStatsAnomalies",
"ModelServiceClient",
+ "MutateDeployedIndexOperationMetadata",
+ "MutateDeployedIndexRequest",
+ "MutateDeployedIndexResponse",
+ "NearestNeighborSearchOperationMetadata",
+ "Neighbor",
+ "NfsMount",
+ "PauseModelDeploymentMonitoringJobRequest",
+ "PipelineFailurePolicy",
+ "PipelineJob",
+ "PipelineJobDetail",
"PipelineServiceClient",
"PipelineState",
+ "PipelineTaskDetail",
+ "PipelineTaskExecutorDetail",
+ "PipelineTemplateMetadata",
"Port",
"PredefinedSplit",
"PredictRequest",
+ "PredictRequestResponseLoggingConfig",
"PredictResponse",
"PredictSchemata",
"PredictionServiceClient",
+ "PrivateEndpoints",
+ "PurgeArtifactsMetadata",
+ "PurgeArtifactsRequest",
+ "PurgeArtifactsResponse",
+ "PurgeContextsMetadata",
+ "PurgeContextsRequest",
+ "PurgeContextsResponse",
+ "PurgeExecutionsMetadata",
+ "PurgeExecutionsRequest",
+ "PurgeExecutionsResponse",
"PythonPackageSpec",
+ "QueryArtifactLineageSubgraphRequest",
+ "QueryContextLineageSubgraphRequest",
+ "QueryExecutionInputsAndOutputsRequest",
+ "RawPredictRequest",
+ "ReadFeatureValuesRequest",
+ "ReadFeatureValuesResponse",
+ "ReadTensorboardBlobDataRequest",
+ "ReadTensorboardBlobDataResponse",
+ "ReadTensorboardTimeSeriesDataRequest",
+ "ReadTensorboardTimeSeriesDataResponse",
"ResourcesConsumed",
+ "ResumeModelDeploymentMonitoringJobRequest",
"SampleConfig",
+ "SampledShapleyAttribution",
+ "SamplingStrategy",
+ "Scalar",
"Scheduling",
+ "SearchFeaturesRequest",
+ "SearchFeaturesResponse",
"SearchMigratableResourcesRequest",
"SearchMigratableResourcesResponse",
+ "SearchModelDeploymentMonitoringStatsAnomaliesRequest",
+ "SearchModelDeploymentMonitoringStatsAnomaliesResponse",
+ "SmoothGradConfig",
"SpecialistPool",
"SpecialistPoolServiceClient",
+ "StopTrialRequest",
+ "StratifiedSplit",
+ "StreamingReadFeatureValuesRequest",
+ "StringArray",
+ "Study",
"StudySpec",
+ "SuggestTrialsMetadata",
+ "SuggestTrialsRequest",
+ "SuggestTrialsResponse",
+ "TFRecordDestination",
+ "Tensorboard",
+ "TensorboardBlob",
+ "TensorboardBlobSequence",
+ "TensorboardExperiment",
+ "TensorboardRun",
+ "TensorboardServiceClient",
+ "TensorboardTensor",
+ "TensorboardTimeSeries",
+ "ThresholdConfig",
+ "TimeSeriesData",
+ "TimeSeriesDataPoint",
"TimestampSplit",
"TrainingConfig",
"TrainingPipeline",
"Trial",
+ "UndeployIndexOperationMetadata",
+ "UndeployIndexRequest",
+ "UndeployIndexResponse",
"UndeployModelOperationMetadata",
"UndeployModelRequest",
"UndeployModelResponse",
+ "UnmanagedContainerModel",
+ "UpdateArtifactRequest",
+ "UpdateContextRequest",
"UpdateDatasetRequest",
"UpdateEndpointRequest",
+ "UpdateEntityTypeRequest",
+ "UpdateExecutionRequest",
+ "UpdateFeatureRequest",
+ "UpdateFeaturestoreOperationMetadata",
+ "UpdateFeaturestoreRequest",
+ "UpdateIndexEndpointRequest",
+ "UpdateIndexOperationMetadata",
+ "UpdateIndexRequest",
+ "UpdateModelDeploymentMonitoringJobOperationMetadata",
+ "UpdateModelDeploymentMonitoringJobRequest",
"UpdateModelRequest",
"UpdateSpecialistPoolOperationMetadata",
"UpdateSpecialistPoolRequest",
+ "UpdateTensorboardExperimentRequest",
+ "UpdateTensorboardOperationMetadata",
+ "UpdateTensorboardRequest",
+ "UpdateTensorboardRunRequest",
+ "UpdateTensorboardTimeSeriesRequest",
"UploadModelOperationMetadata",
"UploadModelRequest",
"UploadModelResponse",
"UserActionReference",
+ "Value",
+ "VizierServiceClient",
"WorkerPoolSpec",
+ "WriteTensorboardExperimentDataRequest",
+ "WriteTensorboardExperimentDataResponse",
+ "WriteTensorboardRunDataRequest",
+ "WriteTensorboardRunDataResponse",
+ "XraiAttribution",
)
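
The regenerated __init__.py above now re-exports clients and request/response types for the newly generated v1 services (Featurestore, Index/IndexEndpoint, Metadata, Tensorboard, Vizier) alongside the new PredictionService RPCs (Explain, RawPredict). As a minimal sketch of the expanded surface — the project, location, and endpoint IDs below are placeholders, not values taken from this change — a RawPredict call against a deployed endpoint could look like:

    from google.api import httpbody_pb2
    from google.cloud import aiplatform_v1

    # The client targets a regional API endpoint; RawPredict forwards an
    # arbitrary HTTP body straight to the deployed model's serving container.
    client = aiplatform_v1.PredictionServiceClient(
        client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
    )

    request = aiplatform_v1.RawPredictRequest(
        # Placeholder resource name; substitute a real project and endpoint ID.
        endpoint="projects/PROJECT_ID/locations/us-central1/endpoints/ENDPOINT_ID",
        http_body=httpbody_pb2.HttpBody(
            content_type="application/json",
            data=b'{"instances": [[1.0, 2.0, 3.0]]}',
        ),
    )
    response = client.raw_predict(request=request)
    print(response.data)  # raw bytes as returned by the model server

The same top-level import path applies to every service added here (for example, aiplatform_v1.VizierServiceClient or aiplatform_v1.TensorboardServiceClient); the gapic_metadata.json diff below records the mapping from each RPC to its client method for both the grpc and grpc-async transports.
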
diff --git a/google/cloud/aiplatform_v1/gapic_metadata.json b/google/cloud/aiplatform_v1/gapic_metadata.json
index 0abed0fd70..35cf881fe2 100644
--- a/google/cloud/aiplatform_v1/gapic_metadata.json
+++ b/google/cloud/aiplatform_v1/gapic_metadata.json
@@ -203,514 +203,1892 @@
}
}
},
- "JobService": {
+ "FeaturestoreOnlineServingService": {
"clients": {
"grpc": {
- "libraryClient": "JobServiceClient",
+ "libraryClient": "FeaturestoreOnlineServingServiceClient",
"rpcs": {
- "CancelBatchPredictionJob": {
+ "ReadFeatureValues": {
"methods": [
- "cancel_batch_prediction_job"
+ "read_feature_values"
]
},
- "CancelCustomJob": {
+ "StreamingReadFeatureValues": {
"methods": [
- "cancel_custom_job"
+ "streaming_read_feature_values"
]
- },
- "CancelDataLabelingJob": {
+ }
+ }
+ },
+ "grpc-async": {
+ "libraryClient": "FeaturestoreOnlineServingServiceAsyncClient",
+ "rpcs": {
+ "ReadFeatureValues": {
"methods": [
- "cancel_data_labeling_job"
+ "read_feature_values"
]
},
- "CancelHyperparameterTuningJob": {
+ "StreamingReadFeatureValues": {
"methods": [
- "cancel_hyperparameter_tuning_job"
+ "streaming_read_feature_values"
]
- },
- "CreateBatchPredictionJob": {
+ }
+ }
+ }
+ }
+ },
+ "FeaturestoreService": {
+ "clients": {
+ "grpc": {
+ "libraryClient": "FeaturestoreServiceClient",
+ "rpcs": {
+ "BatchCreateFeatures": {
"methods": [
- "create_batch_prediction_job"
+ "batch_create_features"
]
},
- "CreateCustomJob": {
+ "BatchReadFeatureValues": {
"methods": [
- "create_custom_job"
+ "batch_read_feature_values"
]
},
- "CreateDataLabelingJob": {
+ "CreateEntityType": {
"methods": [
- "create_data_labeling_job"
+ "create_entity_type"
]
},
- "CreateHyperparameterTuningJob": {
+ "CreateFeature": {
"methods": [
- "create_hyperparameter_tuning_job"
+ "create_feature"
]
},
- "DeleteBatchPredictionJob": {
+ "CreateFeaturestore": {
"methods": [
- "delete_batch_prediction_job"
+ "create_featurestore"
]
},
- "DeleteCustomJob": {
+ "DeleteEntityType": {
"methods": [
- "delete_custom_job"
+ "delete_entity_type"
]
},
- "DeleteDataLabelingJob": {
+ "DeleteFeature": {
"methods": [
- "delete_data_labeling_job"
+ "delete_feature"
]
},
- "DeleteHyperparameterTuningJob": {
+ "DeleteFeaturestore": {
"methods": [
- "delete_hyperparameter_tuning_job"
+ "delete_featurestore"
]
},
- "GetBatchPredictionJob": {
+ "ExportFeatureValues": {
"methods": [
- "get_batch_prediction_job"
+ "export_feature_values"
]
},
- "GetCustomJob": {
+ "GetEntityType": {
"methods": [
- "get_custom_job"
+ "get_entity_type"
]
},
- "GetDataLabelingJob": {
+ "GetFeature": {
"methods": [
- "get_data_labeling_job"
+ "get_feature"
]
},
- "GetHyperparameterTuningJob": {
+ "GetFeaturestore": {
"methods": [
- "get_hyperparameter_tuning_job"
+ "get_featurestore"
]
},
- "ListBatchPredictionJobs": {
+ "ImportFeatureValues": {
"methods": [
- "list_batch_prediction_jobs"
+ "import_feature_values"
]
},
- "ListCustomJobs": {
+ "ListEntityTypes": {
"methods": [
- "list_custom_jobs"
+ "list_entity_types"
]
},
- "ListDataLabelingJobs": {
+ "ListFeatures": {
"methods": [
- "list_data_labeling_jobs"
+ "list_features"
]
},
- "ListHyperparameterTuningJobs": {
+ "ListFeaturestores": {
"methods": [
- "list_hyperparameter_tuning_jobs"
+ "list_featurestores"
]
- }
- }
- },
- "grpc-async": {
- "libraryClient": "JobServiceAsyncClient",
- "rpcs": {
- "CancelBatchPredictionJob": {
+ },
+ "SearchFeatures": {
"methods": [
- "cancel_batch_prediction_job"
+ "search_features"
]
},
- "CancelCustomJob": {
+ "UpdateEntityType": {
"methods": [
- "cancel_custom_job"
+ "update_entity_type"
]
},
- "CancelDataLabelingJob": {
+ "UpdateFeature": {
"methods": [
- "cancel_data_labeling_job"
+ "update_feature"
]
},
- "CancelHyperparameterTuningJob": {
+ "UpdateFeaturestore": {
"methods": [
- "cancel_hyperparameter_tuning_job"
+ "update_featurestore"
]
- },
- "CreateBatchPredictionJob": {
+ }
+ }
+ },
+ "grpc-async": {
+ "libraryClient": "FeaturestoreServiceAsyncClient",
+ "rpcs": {
+ "BatchCreateFeatures": {
"methods": [
- "create_batch_prediction_job"
+ "batch_create_features"
]
},
- "CreateCustomJob": {
+ "BatchReadFeatureValues": {
"methods": [
- "create_custom_job"
+ "batch_read_feature_values"
]
},
- "CreateDataLabelingJob": {
+ "CreateEntityType": {
"methods": [
- "create_data_labeling_job"
+ "create_entity_type"
]
},
- "CreateHyperparameterTuningJob": {
+ "CreateFeature": {
"methods": [
- "create_hyperparameter_tuning_job"
+ "create_feature"
]
},
- "DeleteBatchPredictionJob": {
+ "CreateFeaturestore": {
"methods": [
- "delete_batch_prediction_job"
+ "create_featurestore"
]
},
- "DeleteCustomJob": {
+ "DeleteEntityType": {
"methods": [
- "delete_custom_job"
+ "delete_entity_type"
]
},
- "DeleteDataLabelingJob": {
+ "DeleteFeature": {
"methods": [
- "delete_data_labeling_job"
+ "delete_feature"
]
},
- "DeleteHyperparameterTuningJob": {
+ "DeleteFeaturestore": {
"methods": [
- "delete_hyperparameter_tuning_job"
+ "delete_featurestore"
]
},
- "GetBatchPredictionJob": {
+ "ExportFeatureValues": {
"methods": [
- "get_batch_prediction_job"
+ "export_feature_values"
]
},
- "GetCustomJob": {
+ "GetEntityType": {
"methods": [
- "get_custom_job"
+ "get_entity_type"
]
},
- "GetDataLabelingJob": {
+ "GetFeature": {
"methods": [
- "get_data_labeling_job"
+ "get_feature"
]
},
- "GetHyperparameterTuningJob": {
+ "GetFeaturestore": {
"methods": [
- "get_hyperparameter_tuning_job"
+ "get_featurestore"
]
},
- "ListBatchPredictionJobs": {
+ "ImportFeatureValues": {
"methods": [
- "list_batch_prediction_jobs"
+ "import_feature_values"
]
},
- "ListCustomJobs": {
+ "ListEntityTypes": {
"methods": [
- "list_custom_jobs"
+ "list_entity_types"
]
},
- "ListDataLabelingJobs": {
+ "ListFeatures": {
"methods": [
- "list_data_labeling_jobs"
+ "list_features"
]
},
- "ListHyperparameterTuningJobs": {
+ "ListFeaturestores": {
"methods": [
- "list_hyperparameter_tuning_jobs"
+ "list_featurestores"
]
- }
- }
- }
- }
- },
- "MigrationService": {
- "clients": {
- "grpc": {
- "libraryClient": "MigrationServiceClient",
- "rpcs": {
- "BatchMigrateResources": {
+ },
+ "SearchFeatures": {
"methods": [
- "batch_migrate_resources"
+ "search_features"
]
},
- "SearchMigratableResources": {
+ "UpdateEntityType": {
"methods": [
- "search_migratable_resources"
+ "update_entity_type"
]
- }
- }
- },
- "grpc-async": {
- "libraryClient": "MigrationServiceAsyncClient",
- "rpcs": {
- "BatchMigrateResources": {
+ },
+ "UpdateFeature": {
"methods": [
- "batch_migrate_resources"
+ "update_feature"
]
},
- "SearchMigratableResources": {
+ "UpdateFeaturestore": {
"methods": [
- "search_migratable_resources"
+ "update_featurestore"
]
}
}
}
}
},
- "ModelService": {
+ "IndexEndpointService": {
"clients": {
"grpc": {
- "libraryClient": "ModelServiceClient",
+ "libraryClient": "IndexEndpointServiceClient",
"rpcs": {
- "DeleteModel": {
- "methods": [
- "delete_model"
- ]
- },
- "ExportModel": {
+ "CreateIndexEndpoint": {
"methods": [
- "export_model"
- ]
- },
- "GetModel": {
- "methods": [
- "get_model"
+ "create_index_endpoint"
]
},
- "GetModelEvaluation": {
+ "DeleteIndexEndpoint": {
"methods": [
- "get_model_evaluation"
+ "delete_index_endpoint"
]
},
- "GetModelEvaluationSlice": {
+ "DeployIndex": {
"methods": [
- "get_model_evaluation_slice"
+ "deploy_index"
]
},
- "ListModelEvaluationSlices": {
+ "GetIndexEndpoint": {
"methods": [
- "list_model_evaluation_slices"
+ "get_index_endpoint"
]
},
- "ListModelEvaluations": {
+ "ListIndexEndpoints": {
"methods": [
- "list_model_evaluations"
+ "list_index_endpoints"
]
},
- "ListModels": {
+ "MutateDeployedIndex": {
"methods": [
- "list_models"
+ "mutate_deployed_index"
]
},
- "UpdateModel": {
+ "UndeployIndex": {
"methods": [
- "update_model"
+ "undeploy_index"
]
},
- "UploadModel": {
+ "UpdateIndexEndpoint": {
"methods": [
- "upload_model"
+ "update_index_endpoint"
]
}
}
},
"grpc-async": {
- "libraryClient": "ModelServiceAsyncClient",
+ "libraryClient": "IndexEndpointServiceAsyncClient",
"rpcs": {
- "DeleteModel": {
- "methods": [
- "delete_model"
- ]
- },
- "ExportModel": {
- "methods": [
- "export_model"
- ]
- },
- "GetModel": {
+ "CreateIndexEndpoint": {
"methods": [
- "get_model"
+ "create_index_endpoint"
]
},
- "GetModelEvaluation": {
+ "DeleteIndexEndpoint": {
"methods": [
- "get_model_evaluation"
+ "delete_index_endpoint"
]
},
- "GetModelEvaluationSlice": {
+ "DeployIndex": {
"methods": [
- "get_model_evaluation_slice"
+ "deploy_index"
]
},
- "ListModelEvaluationSlices": {
+ "GetIndexEndpoint": {
"methods": [
- "list_model_evaluation_slices"
+ "get_index_endpoint"
]
},
- "ListModelEvaluations": {
+ "ListIndexEndpoints": {
"methods": [
- "list_model_evaluations"
+ "list_index_endpoints"
]
},
- "ListModels": {
+ "MutateDeployedIndex": {
"methods": [
- "list_models"
+ "mutate_deployed_index"
]
},
- "UpdateModel": {
+ "UndeployIndex": {
"methods": [
- "update_model"
+ "undeploy_index"
]
},
- "UploadModel": {
+ "UpdateIndexEndpoint": {
"methods": [
- "upload_model"
+ "update_index_endpoint"
]
}
}
}
}
},
- "PipelineService": {
+ "IndexService": {
"clients": {
"grpc": {
- "libraryClient": "PipelineServiceClient",
+ "libraryClient": "IndexServiceClient",
"rpcs": {
- "CancelTrainingPipeline": {
+ "CreateIndex": {
"methods": [
- "cancel_training_pipeline"
+ "create_index"
]
},
- "CreateTrainingPipeline": {
+ "DeleteIndex": {
"methods": [
- "create_training_pipeline"
+ "delete_index"
]
},
- "DeleteTrainingPipeline": {
+ "GetIndex": {
"methods": [
- "delete_training_pipeline"
+ "get_index"
]
},
- "GetTrainingPipeline": {
+ "ListIndexes": {
"methods": [
- "get_training_pipeline"
+ "list_indexes"
]
},
- "ListTrainingPipelines": {
+ "UpdateIndex": {
"methods": [
- "list_training_pipelines"
+ "update_index"
]
}
}
},
"grpc-async": {
- "libraryClient": "PipelineServiceAsyncClient",
+ "libraryClient": "IndexServiceAsyncClient",
"rpcs": {
- "CancelTrainingPipeline": {
+ "CreateIndex": {
"methods": [
- "cancel_training_pipeline"
+ "create_index"
]
},
- "CreateTrainingPipeline": {
+ "DeleteIndex": {
"methods": [
- "create_training_pipeline"
+ "delete_index"
]
},
- "DeleteTrainingPipeline": {
+ "GetIndex": {
"methods": [
- "delete_training_pipeline"
+ "get_index"
]
},
- "GetTrainingPipeline": {
+ "ListIndexes": {
"methods": [
- "get_training_pipeline"
+ "list_indexes"
]
},
- "ListTrainingPipelines": {
+ "UpdateIndex": {
"methods": [
- "list_training_pipelines"
+ "update_index"
]
}
}
}
}
},
- "PredictionService": {
+ "JobService": {
"clients": {
"grpc": {
- "libraryClient": "PredictionServiceClient",
+ "libraryClient": "JobServiceClient",
"rpcs": {
- "Predict": {
+ "CancelBatchPredictionJob": {
"methods": [
- "predict"
+ "cancel_batch_prediction_job"
]
- }
- }
- },
- "grpc-async": {
- "libraryClient": "PredictionServiceAsyncClient",
- "rpcs": {
- "Predict": {
+ },
+ "CancelCustomJob": {
"methods": [
- "predict"
+ "cancel_custom_job"
]
- }
- }
- }
- }
- },
- "SpecialistPoolService": {
- "clients": {
- "grpc": {
- "libraryClient": "SpecialistPoolServiceClient",
- "rpcs": {
- "CreateSpecialistPool": {
+ },
+ "CancelDataLabelingJob": {
"methods": [
- "create_specialist_pool"
+ "cancel_data_labeling_job"
]
},
- "DeleteSpecialistPool": {
+ "CancelHyperparameterTuningJob": {
"methods": [
- "delete_specialist_pool"
+ "cancel_hyperparameter_tuning_job"
]
},
- "GetSpecialistPool": {
+ "CreateBatchPredictionJob": {
"methods": [
- "get_specialist_pool"
+ "create_batch_prediction_job"
+ ]
+ },
+ "CreateCustomJob": {
+ "methods": [
+ "create_custom_job"
+ ]
+ },
+ "CreateDataLabelingJob": {
+ "methods": [
+ "create_data_labeling_job"
+ ]
+ },
+ "CreateHyperparameterTuningJob": {
+ "methods": [
+ "create_hyperparameter_tuning_job"
+ ]
+ },
+ "CreateModelDeploymentMonitoringJob": {
+ "methods": [
+ "create_model_deployment_monitoring_job"
+ ]
+ },
+ "DeleteBatchPredictionJob": {
+ "methods": [
+ "delete_batch_prediction_job"
+ ]
+ },
+ "DeleteCustomJob": {
+ "methods": [
+ "delete_custom_job"
+ ]
+ },
+ "DeleteDataLabelingJob": {
+ "methods": [
+ "delete_data_labeling_job"
+ ]
+ },
+ "DeleteHyperparameterTuningJob": {
+ "methods": [
+ "delete_hyperparameter_tuning_job"
+ ]
+ },
+ "DeleteModelDeploymentMonitoringJob": {
+ "methods": [
+ "delete_model_deployment_monitoring_job"
+ ]
+ },
+ "GetBatchPredictionJob": {
+ "methods": [
+ "get_batch_prediction_job"
+ ]
+ },
+ "GetCustomJob": {
+ "methods": [
+ "get_custom_job"
+ ]
+ },
+ "GetDataLabelingJob": {
+ "methods": [
+ "get_data_labeling_job"
+ ]
+ },
+ "GetHyperparameterTuningJob": {
+ "methods": [
+ "get_hyperparameter_tuning_job"
+ ]
+ },
+ "GetModelDeploymentMonitoringJob": {
+ "methods": [
+ "get_model_deployment_monitoring_job"
+ ]
+ },
+ "ListBatchPredictionJobs": {
+ "methods": [
+ "list_batch_prediction_jobs"
+ ]
+ },
+ "ListCustomJobs": {
+ "methods": [
+ "list_custom_jobs"
+ ]
+ },
+ "ListDataLabelingJobs": {
+ "methods": [
+ "list_data_labeling_jobs"
+ ]
+ },
+ "ListHyperparameterTuningJobs": {
+ "methods": [
+ "list_hyperparameter_tuning_jobs"
+ ]
+ },
+ "ListModelDeploymentMonitoringJobs": {
+ "methods": [
+ "list_model_deployment_monitoring_jobs"
+ ]
+ },
+ "PauseModelDeploymentMonitoringJob": {
+ "methods": [
+ "pause_model_deployment_monitoring_job"
+ ]
+ },
+ "ResumeModelDeploymentMonitoringJob": {
+ "methods": [
+ "resume_model_deployment_monitoring_job"
+ ]
+ },
+ "SearchModelDeploymentMonitoringStatsAnomalies": {
+ "methods": [
+ "search_model_deployment_monitoring_stats_anomalies"
+ ]
+ },
+ "UpdateModelDeploymentMonitoringJob": {
+ "methods": [
+ "update_model_deployment_monitoring_job"
+ ]
+ }
+ }
+ },
+ "grpc-async": {
+ "libraryClient": "JobServiceAsyncClient",
+ "rpcs": {
+ "CancelBatchPredictionJob": {
+ "methods": [
+ "cancel_batch_prediction_job"
+ ]
+ },
+ "CancelCustomJob": {
+ "methods": [
+ "cancel_custom_job"
+ ]
+ },
+ "CancelDataLabelingJob": {
+ "methods": [
+ "cancel_data_labeling_job"
+ ]
+ },
+ "CancelHyperparameterTuningJob": {
+ "methods": [
+ "cancel_hyperparameter_tuning_job"
+ ]
+ },
+ "CreateBatchPredictionJob": {
+ "methods": [
+ "create_batch_prediction_job"
+ ]
+ },
+ "CreateCustomJob": {
+ "methods": [
+ "create_custom_job"
+ ]
+ },
+ "CreateDataLabelingJob": {
+ "methods": [
+ "create_data_labeling_job"
+ ]
+ },
+ "CreateHyperparameterTuningJob": {
+ "methods": [
+ "create_hyperparameter_tuning_job"
+ ]
+ },
+ "CreateModelDeploymentMonitoringJob": {
+ "methods": [
+ "create_model_deployment_monitoring_job"
+ ]
+ },
+ "DeleteBatchPredictionJob": {
+ "methods": [
+ "delete_batch_prediction_job"
+ ]
+ },
+ "DeleteCustomJob": {
+ "methods": [
+ "delete_custom_job"
+ ]
+ },
+ "DeleteDataLabelingJob": {
+ "methods": [
+ "delete_data_labeling_job"
+ ]
+ },
+ "DeleteHyperparameterTuningJob": {
+ "methods": [
+ "delete_hyperparameter_tuning_job"
+ ]
+ },
+ "DeleteModelDeploymentMonitoringJob": {
+ "methods": [
+ "delete_model_deployment_monitoring_job"
+ ]
+ },
+ "GetBatchPredictionJob": {
+ "methods": [
+ "get_batch_prediction_job"
+ ]
+ },
+ "GetCustomJob": {
+ "methods": [
+ "get_custom_job"
+ ]
+ },
+ "GetDataLabelingJob": {
+ "methods": [
+ "get_data_labeling_job"
+ ]
+ },
+ "GetHyperparameterTuningJob": {
+ "methods": [
+ "get_hyperparameter_tuning_job"
+ ]
+ },
+ "GetModelDeploymentMonitoringJob": {
+ "methods": [
+ "get_model_deployment_monitoring_job"
+ ]
+ },
+ "ListBatchPredictionJobs": {
+ "methods": [
+ "list_batch_prediction_jobs"
+ ]
+ },
+ "ListCustomJobs": {
+ "methods": [
+ "list_custom_jobs"
+ ]
+ },
+ "ListDataLabelingJobs": {
+ "methods": [
+ "list_data_labeling_jobs"
+ ]
+ },
+ "ListHyperparameterTuningJobs": {
+ "methods": [
+ "list_hyperparameter_tuning_jobs"
+ ]
+ },
+ "ListModelDeploymentMonitoringJobs": {
+ "methods": [
+ "list_model_deployment_monitoring_jobs"
+ ]
+ },
+ "PauseModelDeploymentMonitoringJob": {
+ "methods": [
+ "pause_model_deployment_monitoring_job"
+ ]
+ },
+ "ResumeModelDeploymentMonitoringJob": {
+ "methods": [
+ "resume_model_deployment_monitoring_job"
+ ]
+ },
+ "SearchModelDeploymentMonitoringStatsAnomalies": {
+ "methods": [
+ "search_model_deployment_monitoring_stats_anomalies"
+ ]
+ },
+ "UpdateModelDeploymentMonitoringJob": {
+ "methods": [
+ "update_model_deployment_monitoring_job"
+ ]
+ }
+ }
+ }
+ }
+ },
+ "MetadataService": {
+ "clients": {
+ "grpc": {
+ "libraryClient": "MetadataServiceClient",
+ "rpcs": {
+ "AddContextArtifactsAndExecutions": {
+ "methods": [
+ "add_context_artifacts_and_executions"
+ ]
+ },
+ "AddContextChildren": {
+ "methods": [
+ "add_context_children"
+ ]
+ },
+ "AddExecutionEvents": {
+ "methods": [
+ "add_execution_events"
+ ]
+ },
+ "CreateArtifact": {
+ "methods": [
+ "create_artifact"
+ ]
+ },
+ "CreateContext": {
+ "methods": [
+ "create_context"
+ ]
+ },
+ "CreateExecution": {
+ "methods": [
+ "create_execution"
+ ]
+ },
+ "CreateMetadataSchema": {
+ "methods": [
+ "create_metadata_schema"
+ ]
+ },
+ "CreateMetadataStore": {
+ "methods": [
+ "create_metadata_store"
+ ]
+ },
+ "DeleteArtifact": {
+ "methods": [
+ "delete_artifact"
+ ]
+ },
+ "DeleteContext": {
+ "methods": [
+ "delete_context"
+ ]
+ },
+ "DeleteExecution": {
+ "methods": [
+ "delete_execution"
+ ]
+ },
+ "DeleteMetadataStore": {
+ "methods": [
+ "delete_metadata_store"
+ ]
+ },
+ "GetArtifact": {
+ "methods": [
+ "get_artifact"
+ ]
+ },
+ "GetContext": {
+ "methods": [
+ "get_context"
+ ]
+ },
+ "GetExecution": {
+ "methods": [
+ "get_execution"
+ ]
+ },
+ "GetMetadataSchema": {
+ "methods": [
+ "get_metadata_schema"
+ ]
+ },
+ "GetMetadataStore": {
+ "methods": [
+ "get_metadata_store"
+ ]
+ },
+ "ListArtifacts": {
+ "methods": [
+ "list_artifacts"
+ ]
+ },
+ "ListContexts": {
+ "methods": [
+ "list_contexts"
+ ]
+ },
+ "ListExecutions": {
+ "methods": [
+ "list_executions"
+ ]
+ },
+ "ListMetadataSchemas": {
+ "methods": [
+ "list_metadata_schemas"
+ ]
+ },
+ "ListMetadataStores": {
+ "methods": [
+ "list_metadata_stores"
+ ]
+ },
+ "PurgeArtifacts": {
+ "methods": [
+ "purge_artifacts"
+ ]
+ },
+ "PurgeContexts": {
+ "methods": [
+ "purge_contexts"
+ ]
+ },
+ "PurgeExecutions": {
+ "methods": [
+ "purge_executions"
+ ]
+ },
+ "QueryArtifactLineageSubgraph": {
+ "methods": [
+ "query_artifact_lineage_subgraph"
+ ]
+ },
+ "QueryContextLineageSubgraph": {
+ "methods": [
+ "query_context_lineage_subgraph"
+ ]
+ },
+ "QueryExecutionInputsAndOutputs": {
+ "methods": [
+ "query_execution_inputs_and_outputs"
+ ]
+ },
+ "UpdateArtifact": {
+ "methods": [
+ "update_artifact"
+ ]
+ },
+ "UpdateContext": {
+ "methods": [
+ "update_context"
+ ]
+ },
+ "UpdateExecution": {
+ "methods": [
+ "update_execution"
+ ]
+ }
+ }
+ },
+ "grpc-async": {
+ "libraryClient": "MetadataServiceAsyncClient",
+ "rpcs": {
+ "AddContextArtifactsAndExecutions": {
+ "methods": [
+ "add_context_artifacts_and_executions"
+ ]
+ },
+ "AddContextChildren": {
+ "methods": [
+ "add_context_children"
+ ]
+ },
+ "AddExecutionEvents": {
+ "methods": [
+ "add_execution_events"
+ ]
+ },
+ "CreateArtifact": {
+ "methods": [
+ "create_artifact"
+ ]
+ },
+ "CreateContext": {
+ "methods": [
+ "create_context"
+ ]
+ },
+ "CreateExecution": {
+ "methods": [
+ "create_execution"
+ ]
+ },
+ "CreateMetadataSchema": {
+ "methods": [
+ "create_metadata_schema"
+ ]
+ },
+ "CreateMetadataStore": {
+ "methods": [
+ "create_metadata_store"
+ ]
+ },
+ "DeleteArtifact": {
+ "methods": [
+ "delete_artifact"
+ ]
+ },
+ "DeleteContext": {
+ "methods": [
+ "delete_context"
+ ]
+ },
+ "DeleteExecution": {
+ "methods": [
+ "delete_execution"
+ ]
+ },
+ "DeleteMetadataStore": {
+ "methods": [
+ "delete_metadata_store"
+ ]
+ },
+ "GetArtifact": {
+ "methods": [
+ "get_artifact"
+ ]
+ },
+ "GetContext": {
+ "methods": [
+ "get_context"
+ ]
+ },
+ "GetExecution": {
+ "methods": [
+ "get_execution"
+ ]
+ },
+ "GetMetadataSchema": {
+ "methods": [
+ "get_metadata_schema"
+ ]
+ },
+ "GetMetadataStore": {
+ "methods": [
+ "get_metadata_store"
+ ]
+ },
+ "ListArtifacts": {
+ "methods": [
+ "list_artifacts"
+ ]
+ },
+ "ListContexts": {
+ "methods": [
+ "list_contexts"
+ ]
+ },
+ "ListExecutions": {
+ "methods": [
+ "list_executions"
+ ]
+ },
+ "ListMetadataSchemas": {
+ "methods": [
+ "list_metadata_schemas"
+ ]
+ },
+ "ListMetadataStores": {
+ "methods": [
+ "list_metadata_stores"
+ ]
+ },
+ "PurgeArtifacts": {
+ "methods": [
+ "purge_artifacts"
+ ]
+ },
+ "PurgeContexts": {
+ "methods": [
+ "purge_contexts"
+ ]
+ },
+ "PurgeExecutions": {
+ "methods": [
+ "purge_executions"
+ ]
+ },
+ "QueryArtifactLineageSubgraph": {
+ "methods": [
+ "query_artifact_lineage_subgraph"
+ ]
+ },
+ "QueryContextLineageSubgraph": {
+ "methods": [
+ "query_context_lineage_subgraph"
+ ]
+ },
+ "QueryExecutionInputsAndOutputs": {
+ "methods": [
+ "query_execution_inputs_and_outputs"
+ ]
+ },
+ "UpdateArtifact": {
+ "methods": [
+ "update_artifact"
+ ]
+ },
+ "UpdateContext": {
+ "methods": [
+ "update_context"
+ ]
+ },
+ "UpdateExecution": {
+ "methods": [
+ "update_execution"
+ ]
+ }
+ }
+ }
+ }
+ },
+ "MigrationService": {
+ "clients": {
+ "grpc": {
+ "libraryClient": "MigrationServiceClient",
+ "rpcs": {
+ "BatchMigrateResources": {
+ "methods": [
+ "batch_migrate_resources"
+ ]
+ },
+ "SearchMigratableResources": {
+ "methods": [
+ "search_migratable_resources"
+ ]
+ }
+ }
+ },
+ "grpc-async": {
+ "libraryClient": "MigrationServiceAsyncClient",
+ "rpcs": {
+ "BatchMigrateResources": {
+ "methods": [
+ "batch_migrate_resources"
+ ]
+ },
+ "SearchMigratableResources": {
+ "methods": [
+ "search_migratable_resources"
+ ]
+ }
+ }
+ }
+ }
+ },
+ "ModelService": {
+ "clients": {
+ "grpc": {
+ "libraryClient": "ModelServiceClient",
+ "rpcs": {
+ "DeleteModel": {
+ "methods": [
+ "delete_model"
+ ]
+ },
+ "DeleteModelVersion": {
+ "methods": [
+ "delete_model_version"
+ ]
+ },
+ "ExportModel": {
+ "methods": [
+ "export_model"
+ ]
+ },
+ "GetModel": {
+ "methods": [
+ "get_model"
+ ]
+ },
+ "GetModelEvaluation": {
+ "methods": [
+ "get_model_evaluation"
+ ]
+ },
+ "GetModelEvaluationSlice": {
+ "methods": [
+ "get_model_evaluation_slice"
+ ]
+ },
+ "ImportModelEvaluation": {
+ "methods": [
+ "import_model_evaluation"
+ ]
+ },
+ "ListModelEvaluationSlices": {
+ "methods": [
+ "list_model_evaluation_slices"
+ ]
+ },
+ "ListModelEvaluations": {
+ "methods": [
+ "list_model_evaluations"
+ ]
+ },
+ "ListModelVersions": {
+ "methods": [
+ "list_model_versions"
+ ]
+ },
+ "ListModels": {
+ "methods": [
+ "list_models"
+ ]
+ },
+ "MergeVersionAliases": {
+ "methods": [
+ "merge_version_aliases"
+ ]
+ },
+ "UpdateModel": {
+ "methods": [
+ "update_model"
+ ]
+ },
+ "UploadModel": {
+ "methods": [
+ "upload_model"
+ ]
+ }
+ }
+ },
+ "grpc-async": {
+ "libraryClient": "ModelServiceAsyncClient",
+ "rpcs": {
+ "DeleteModel": {
+ "methods": [
+ "delete_model"
+ ]
+ },
+ "DeleteModelVersion": {
+ "methods": [
+ "delete_model_version"
+ ]
+ },
+ "ExportModel": {
+ "methods": [
+ "export_model"
+ ]
+ },
+ "GetModel": {
+ "methods": [
+ "get_model"
+ ]
+ },
+ "GetModelEvaluation": {
+ "methods": [
+ "get_model_evaluation"
+ ]
+ },
+ "GetModelEvaluationSlice": {
+ "methods": [
+ "get_model_evaluation_slice"
+ ]
+ },
+ "ImportModelEvaluation": {
+ "methods": [
+ "import_model_evaluation"
+ ]
+ },
+ "ListModelEvaluationSlices": {
+ "methods": [
+ "list_model_evaluation_slices"
+ ]
+ },
+ "ListModelEvaluations": {
+ "methods": [
+ "list_model_evaluations"
+ ]
+ },
+ "ListModelVersions": {
+ "methods": [
+ "list_model_versions"
+ ]
+ },
+ "ListModels": {
+ "methods": [
+ "list_models"
+ ]
+ },
+ "MergeVersionAliases": {
+ "methods": [
+ "merge_version_aliases"
+ ]
+ },
+ "UpdateModel": {
+ "methods": [
+ "update_model"
+ ]
+ },
+ "UploadModel": {
+ "methods": [
+ "upload_model"
+ ]
+ }
+ }
+ }
+ }
+ },
+ "PipelineService": {
+ "clients": {
+ "grpc": {
+ "libraryClient": "PipelineServiceClient",
+ "rpcs": {
+ "CancelPipelineJob": {
+ "methods": [
+ "cancel_pipeline_job"
+ ]
+ },
+ "CancelTrainingPipeline": {
+ "methods": [
+ "cancel_training_pipeline"
+ ]
+ },
+ "CreatePipelineJob": {
+ "methods": [
+ "create_pipeline_job"
+ ]
+ },
+ "CreateTrainingPipeline": {
+ "methods": [
+ "create_training_pipeline"
+ ]
+ },
+ "DeletePipelineJob": {
+ "methods": [
+ "delete_pipeline_job"
+ ]
+ },
+ "DeleteTrainingPipeline": {
+ "methods": [
+ "delete_training_pipeline"
+ ]
+ },
+ "GetPipelineJob": {
+ "methods": [
+ "get_pipeline_job"
+ ]
+ },
+ "GetTrainingPipeline": {
+ "methods": [
+ "get_training_pipeline"
+ ]
+ },
+ "ListPipelineJobs": {
+ "methods": [
+ "list_pipeline_jobs"
+ ]
+ },
+ "ListTrainingPipelines": {
+ "methods": [
+ "list_training_pipelines"
+ ]
+ }
+ }
+ },
+ "grpc-async": {
+ "libraryClient": "PipelineServiceAsyncClient",
+ "rpcs": {
+ "CancelPipelineJob": {
+ "methods": [
+ "cancel_pipeline_job"
+ ]
+ },
+ "CancelTrainingPipeline": {
+ "methods": [
+ "cancel_training_pipeline"
+ ]
+ },
+ "CreatePipelineJob": {
+ "methods": [
+ "create_pipeline_job"
+ ]
+ },
+ "CreateTrainingPipeline": {
+ "methods": [
+ "create_training_pipeline"
+ ]
+ },
+ "DeletePipelineJob": {
+ "methods": [
+ "delete_pipeline_job"
+ ]
+ },
+ "DeleteTrainingPipeline": {
+ "methods": [
+ "delete_training_pipeline"
+ ]
+ },
+ "GetPipelineJob": {
+ "methods": [
+ "get_pipeline_job"
+ ]
+ },
+ "GetTrainingPipeline": {
+ "methods": [
+ "get_training_pipeline"
+ ]
+ },
+ "ListPipelineJobs": {
+ "methods": [
+ "list_pipeline_jobs"
+ ]
+ },
+ "ListTrainingPipelines": {
+ "methods": [
+ "list_training_pipelines"
+ ]
+ }
+ }
+ }
+ }
+ },
+ "PredictionService": {
+ "clients": {
+ "grpc": {
+ "libraryClient": "PredictionServiceClient",
+ "rpcs": {
+ "Explain": {
+ "methods": [
+ "explain"
+ ]
+ },
+ "Predict": {
+ "methods": [
+ "predict"
+ ]
+ },
+ "RawPredict": {
+ "methods": [
+ "raw_predict"
+ ]
+ }
+ }
+ },
+ "grpc-async": {
+ "libraryClient": "PredictionServiceAsyncClient",
+ "rpcs": {
+ "Explain": {
+ "methods": [
+ "explain"
+ ]
+ },
+ "Predict": {
+ "methods": [
+ "predict"
+ ]
+ },
+ "RawPredict": {
+ "methods": [
+ "raw_predict"
+ ]
+ }
+ }
+ }
+ }
+ },
+ "SpecialistPoolService": {
+ "clients": {
+ "grpc": {
+ "libraryClient": "SpecialistPoolServiceClient",
+ "rpcs": {
+ "CreateSpecialistPool": {
+ "methods": [
+ "create_specialist_pool"
+ ]
+ },
+ "DeleteSpecialistPool": {
+ "methods": [
+ "delete_specialist_pool"
+ ]
+ },
+ "GetSpecialistPool": {
+ "methods": [
+ "get_specialist_pool"
+ ]
+ },
+ "ListSpecialistPools": {
+ "methods": [
+ "list_specialist_pools"
+ ]
+ },
+ "UpdateSpecialistPool": {
+ "methods": [
+ "update_specialist_pool"
+ ]
+ }
+ }
+ },
+ "grpc-async": {
+ "libraryClient": "SpecialistPoolServiceAsyncClient",
+ "rpcs": {
+ "CreateSpecialistPool": {
+ "methods": [
+ "create_specialist_pool"
+ ]
+ },
+ "DeleteSpecialistPool": {
+ "methods": [
+ "delete_specialist_pool"
+ ]
+ },
+ "GetSpecialistPool": {
+ "methods": [
+ "get_specialist_pool"
]
},
"ListSpecialistPools": {
"methods": [
- "list_specialist_pools"
+ "list_specialist_pools"
+ ]
+ },
+ "UpdateSpecialistPool": {
+ "methods": [
+ "update_specialist_pool"
+ ]
+ }
+ }
+ }
+ }
+ },
+ "TensorboardService": {
+ "clients": {
+ "grpc": {
+ "libraryClient": "TensorboardServiceClient",
+ "rpcs": {
+ "BatchCreateTensorboardRuns": {
+ "methods": [
+ "batch_create_tensorboard_runs"
+ ]
+ },
+ "BatchCreateTensorboardTimeSeries": {
+ "methods": [
+ "batch_create_tensorboard_time_series"
+ ]
+ },
+ "BatchReadTensorboardTimeSeriesData": {
+ "methods": [
+ "batch_read_tensorboard_time_series_data"
]
},
- "UpdateSpecialistPool": {
+ "CreateTensorboard": {
"methods": [
- "update_specialist_pool"
+ "create_tensorboard"
+ ]
+ },
+ "CreateTensorboardExperiment": {
+ "methods": [
+ "create_tensorboard_experiment"
+ ]
+ },
+ "CreateTensorboardRun": {
+ "methods": [
+ "create_tensorboard_run"
+ ]
+ },
+ "CreateTensorboardTimeSeries": {
+ "methods": [
+ "create_tensorboard_time_series"
+ ]
+ },
+ "DeleteTensorboard": {
+ "methods": [
+ "delete_tensorboard"
+ ]
+ },
+ "DeleteTensorboardExperiment": {
+ "methods": [
+ "delete_tensorboard_experiment"
+ ]
+ },
+ "DeleteTensorboardRun": {
+ "methods": [
+ "delete_tensorboard_run"
+ ]
+ },
+ "DeleteTensorboardTimeSeries": {
+ "methods": [
+ "delete_tensorboard_time_series"
+ ]
+ },
+ "ExportTensorboardTimeSeriesData": {
+ "methods": [
+ "export_tensorboard_time_series_data"
+ ]
+ },
+ "GetTensorboard": {
+ "methods": [
+ "get_tensorboard"
+ ]
+ },
+ "GetTensorboardExperiment": {
+ "methods": [
+ "get_tensorboard_experiment"
+ ]
+ },
+ "GetTensorboardRun": {
+ "methods": [
+ "get_tensorboard_run"
+ ]
+ },
+ "GetTensorboardTimeSeries": {
+ "methods": [
+ "get_tensorboard_time_series"
+ ]
+ },
+ "ListTensorboardExperiments": {
+ "methods": [
+ "list_tensorboard_experiments"
+ ]
+ },
+ "ListTensorboardRuns": {
+ "methods": [
+ "list_tensorboard_runs"
+ ]
+ },
+ "ListTensorboardTimeSeries": {
+ "methods": [
+ "list_tensorboard_time_series"
+ ]
+ },
+ "ListTensorboards": {
+ "methods": [
+ "list_tensorboards"
+ ]
+ },
+ "ReadTensorboardBlobData": {
+ "methods": [
+ "read_tensorboard_blob_data"
+ ]
+ },
+ "ReadTensorboardTimeSeriesData": {
+ "methods": [
+ "read_tensorboard_time_series_data"
+ ]
+ },
+ "UpdateTensorboard": {
+ "methods": [
+ "update_tensorboard"
+ ]
+ },
+ "UpdateTensorboardExperiment": {
+ "methods": [
+ "update_tensorboard_experiment"
+ ]
+ },
+ "UpdateTensorboardRun": {
+ "methods": [
+ "update_tensorboard_run"
+ ]
+ },
+ "UpdateTensorboardTimeSeries": {
+ "methods": [
+ "update_tensorboard_time_series"
+ ]
+ },
+ "WriteTensorboardExperimentData": {
+ "methods": [
+ "write_tensorboard_experiment_data"
+ ]
+ },
+ "WriteTensorboardRunData": {
+ "methods": [
+ "write_tensorboard_run_data"
]
}
}
},
"grpc-async": {
- "libraryClient": "SpecialistPoolServiceAsyncClient",
+ "libraryClient": "TensorboardServiceAsyncClient",
"rpcs": {
- "CreateSpecialistPool": {
+ "BatchCreateTensorboardRuns": {
"methods": [
- "create_specialist_pool"
+ "batch_create_tensorboard_runs"
]
},
- "DeleteSpecialistPool": {
+ "BatchCreateTensorboardTimeSeries": {
"methods": [
- "delete_specialist_pool"
+ "batch_create_tensorboard_time_series"
]
},
- "GetSpecialistPool": {
+ "BatchReadTensorboardTimeSeriesData": {
"methods": [
- "get_specialist_pool"
+ "batch_read_tensorboard_time_series_data"
]
},
- "ListSpecialistPools": {
+ "CreateTensorboard": {
"methods": [
- "list_specialist_pools"
+ "create_tensorboard"
]
},
- "UpdateSpecialistPool": {
+ "CreateTensorboardExperiment": {
"methods": [
- "update_specialist_pool"
+ "create_tensorboard_experiment"
+ ]
+ },
+ "CreateTensorboardRun": {
+ "methods": [
+ "create_tensorboard_run"
+ ]
+ },
+ "CreateTensorboardTimeSeries": {
+ "methods": [
+ "create_tensorboard_time_series"
+ ]
+ },
+ "DeleteTensorboard": {
+ "methods": [
+ "delete_tensorboard"
+ ]
+ },
+ "DeleteTensorboardExperiment": {
+ "methods": [
+ "delete_tensorboard_experiment"
+ ]
+ },
+ "DeleteTensorboardRun": {
+ "methods": [
+ "delete_tensorboard_run"
+ ]
+ },
+ "DeleteTensorboardTimeSeries": {
+ "methods": [
+ "delete_tensorboard_time_series"
+ ]
+ },
+ "ExportTensorboardTimeSeriesData": {
+ "methods": [
+ "export_tensorboard_time_series_data"
+ ]
+ },
+ "GetTensorboard": {
+ "methods": [
+ "get_tensorboard"
+ ]
+ },
+ "GetTensorboardExperiment": {
+ "methods": [
+ "get_tensorboard_experiment"
+ ]
+ },
+ "GetTensorboardRun": {
+ "methods": [
+ "get_tensorboard_run"
+ ]
+ },
+ "GetTensorboardTimeSeries": {
+ "methods": [
+ "get_tensorboard_time_series"
+ ]
+ },
+ "ListTensorboardExperiments": {
+ "methods": [
+ "list_tensorboard_experiments"
+ ]
+ },
+ "ListTensorboardRuns": {
+ "methods": [
+ "list_tensorboard_runs"
+ ]
+ },
+ "ListTensorboardTimeSeries": {
+ "methods": [
+ "list_tensorboard_time_series"
+ ]
+ },
+ "ListTensorboards": {
+ "methods": [
+ "list_tensorboards"
+ ]
+ },
+ "ReadTensorboardBlobData": {
+ "methods": [
+ "read_tensorboard_blob_data"
+ ]
+ },
+ "ReadTensorboardTimeSeriesData": {
+ "methods": [
+ "read_tensorboard_time_series_data"
+ ]
+ },
+ "UpdateTensorboard": {
+ "methods": [
+ "update_tensorboard"
+ ]
+ },
+ "UpdateTensorboardExperiment": {
+ "methods": [
+ "update_tensorboard_experiment"
+ ]
+ },
+ "UpdateTensorboardRun": {
+ "methods": [
+ "update_tensorboard_run"
+ ]
+ },
+ "UpdateTensorboardTimeSeries": {
+ "methods": [
+ "update_tensorboard_time_series"
+ ]
+ },
+ "WriteTensorboardExperimentData": {
+ "methods": [
+ "write_tensorboard_experiment_data"
+ ]
+ },
+ "WriteTensorboardRunData": {
+ "methods": [
+ "write_tensorboard_run_data"
+ ]
+ }
+ }
+ }
+ }
+ },
+ "VizierService": {
+ "clients": {
+ "grpc": {
+ "libraryClient": "VizierServiceClient",
+ "rpcs": {
+ "AddTrialMeasurement": {
+ "methods": [
+ "add_trial_measurement"
+ ]
+ },
+ "CheckTrialEarlyStoppingState": {
+ "methods": [
+ "check_trial_early_stopping_state"
+ ]
+ },
+ "CompleteTrial": {
+ "methods": [
+ "complete_trial"
+ ]
+ },
+ "CreateStudy": {
+ "methods": [
+ "create_study"
+ ]
+ },
+ "CreateTrial": {
+ "methods": [
+ "create_trial"
+ ]
+ },
+ "DeleteStudy": {
+ "methods": [
+ "delete_study"
+ ]
+ },
+ "DeleteTrial": {
+ "methods": [
+ "delete_trial"
+ ]
+ },
+ "GetStudy": {
+ "methods": [
+ "get_study"
+ ]
+ },
+ "GetTrial": {
+ "methods": [
+ "get_trial"
+ ]
+ },
+ "ListOptimalTrials": {
+ "methods": [
+ "list_optimal_trials"
+ ]
+ },
+ "ListStudies": {
+ "methods": [
+ "list_studies"
+ ]
+ },
+ "ListTrials": {
+ "methods": [
+ "list_trials"
+ ]
+ },
+ "LookupStudy": {
+ "methods": [
+ "lookup_study"
+ ]
+ },
+ "StopTrial": {
+ "methods": [
+ "stop_trial"
+ ]
+ },
+ "SuggestTrials": {
+ "methods": [
+ "suggest_trials"
+ ]
+ }
+ }
+ },
+ "grpc-async": {
+ "libraryClient": "VizierServiceAsyncClient",
+ "rpcs": {
+ "AddTrialMeasurement": {
+ "methods": [
+ "add_trial_measurement"
+ ]
+ },
+ "CheckTrialEarlyStoppingState": {
+ "methods": [
+ "check_trial_early_stopping_state"
+ ]
+ },
+ "CompleteTrial": {
+ "methods": [
+ "complete_trial"
+ ]
+ },
+ "CreateStudy": {
+ "methods": [
+ "create_study"
+ ]
+ },
+ "CreateTrial": {
+ "methods": [
+ "create_trial"
+ ]
+ },
+ "DeleteStudy": {
+ "methods": [
+ "delete_study"
+ ]
+ },
+ "DeleteTrial": {
+ "methods": [
+ "delete_trial"
+ ]
+ },
+ "GetStudy": {
+ "methods": [
+ "get_study"
+ ]
+ },
+ "GetTrial": {
+ "methods": [
+ "get_trial"
+ ]
+ },
+ "ListOptimalTrials": {
+ "methods": [
+ "list_optimal_trials"
+ ]
+ },
+ "ListStudies": {
+ "methods": [
+ "list_studies"
+ ]
+ },
+ "ListTrials": {
+ "methods": [
+ "list_trials"
+ ]
+ },
+ "LookupStudy": {
+ "methods": [
+ "lookup_study"
+ ]
+ },
+ "StopTrial": {
+ "methods": [
+ "stop_trial"
+ ]
+ },
+ "SuggestTrials": {
+ "methods": [
+ "suggest_trials"
]
}
}
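The `rpcs` maps above pair each proto RPC name with the generated client method that implements it. A minimal sketch (not part of this change) of resolving a method name from the metadata — the file path is the conventional location for the generated file and is an assumption here:

```python
import json

# Hedged sketch: look up the Python method that implements an RPC using
# gapic_metadata.json. Service and RPC names are taken from the entries above.
with open("google/cloud/aiplatform_v1/gapic_metadata.json") as f:
    metadata = json.load(f)

rpcs = metadata["services"]["VizierService"]["clients"]["grpc"]["rpcs"]
method = rpcs["SuggestTrials"]["methods"][0]
print(method)  # suggest_trials
```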
diff --git a/google/cloud/aiplatform_v1/services/__init__.py b/google/cloud/aiplatform_v1/services/__init__.py
index 4de65971c2..e8e1c3845d 100644
--- a/google/cloud/aiplatform_v1/services/__init__.py
+++ b/google/cloud/aiplatform_v1/services/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/google/cloud/aiplatform_v1/services/dataset_service/__init__.py b/google/cloud/aiplatform_v1/services/dataset_service/__init__.py
index 42adf5e5af..163172b9a0 100644
--- a/google/cloud/aiplatform_v1/services/dataset_service/__init__.py
+++ b/google/cloud/aiplatform_v1/services/dataset_service/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/google/cloud/aiplatform_v1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py
index b937183e37..51397c4189 100644
--- a/google/cloud/aiplatform_v1/services/dataset_service/async_client.py
+++ b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,16 +16,21 @@
from collections import OrderedDict
import functools
import re
-from typing import Dict, Sequence, Tuple, Type, Union
+from typing import Dict, Mapping, Optional, Sequence, Tuple, Type, Union
import pkg_resources
-import google.api_core.client_options as ClientOptions # type: ignore
-from google.api_core import exceptions as core_exceptions # type: ignore
-from google.api_core import gapic_v1 # type: ignore
-from google.api_core import retry as retries # type: ignore
+from google.api_core.client_options import ClientOptions
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core import retry as retries
from google.auth import credentials as ga_credentials # type: ignore
from google.oauth2 import service_account # type: ignore
+try:
+ OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
+except AttributeError: # pragma: NO COVER
+ OptionalRetry = Union[retries.Retry, object] # type: ignore
+
from google.api_core import operation as gac_operation # type: ignore
from google.api_core import operation_async # type: ignore
from google.cloud.aiplatform_v1.services.dataset_service import pagers
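The new try/except above feature-detects `gapic_v1.method._MethodDefault`, which only exists in newer google-api-core releases; on older releases the `OptionalRetry` alias degrades to plain `object` so the annotations still import. A hedged sketch of passing an explicit retry where `OptionalRetry` is accepted:

```python
from google.api_core import retry as retries

# Hedged sketch: any concrete Retry satisfies the OptionalRetry alias.
custom_retry = retries.Retry(initial=0.25, maximum=32.0, multiplier=2.0, deadline=120.0)
# response = await client.get_dataset(request={"name": "..."}, retry=custom_retry)
```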
@@ -37,6 +42,10 @@
from google.cloud.aiplatform_v1.types import dataset_service
from google.cloud.aiplatform_v1.types import encryption_spec
from google.cloud.aiplatform_v1.types import operation as gca_operation
+from google.cloud.location import locations_pb2 # type: ignore
+from google.iam.v1 import iam_policy_pb2 # type: ignore
+from google.iam.v1 import policy_pb2 # type: ignore
+from google.longrunning import operations_pb2
from google.protobuf import empty_pb2 # type: ignore
from google.protobuf import field_mask_pb2 # type: ignore
from google.protobuf import struct_pb2 # type: ignore
@@ -47,7 +56,9 @@
class DatasetServiceAsyncClient:
- """"""
+ """The service that handles the CRUD of Vertex AI Dataset and
+ its child resources.
+ """
_client: DatasetServiceClient
@@ -91,7 +102,8 @@ class DatasetServiceAsyncClient:
@classmethod
def from_service_account_info(cls, info: dict, *args, **kwargs):
- """Creates an instance of this client using the provided credentials info.
+ """Creates an instance of this client using the provided credentials
+ info.
Args:
info (dict): The service account private key info.
@@ -106,7 +118,7 @@ def from_service_account_info(cls, info: dict, *args, **kwargs):
@classmethod
def from_service_account_file(cls, filename: str, *args, **kwargs):
"""Creates an instance of this client using the provided credentials
- file.
+ file.
Args:
filename (str): The path to the service account private key json
@@ -121,9 +133,45 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
from_service_account_json = from_service_account_file
+ @classmethod
+ def get_mtls_endpoint_and_cert_source(
+ cls, client_options: Optional[ClientOptions] = None
+ ):
+ """Return the API endpoint and client cert source for mutual TLS.
+
+ The client cert source is determined in the following order:
+ (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
+ client cert source is None.
+ (2) if `client_options.client_cert_source` is provided, use the provided one; if the
+ default client cert source exists, use the default one; otherwise the client cert
+ source is None.
+
+ The API endpoint is determined in the following order:
+ (1) if `client_options.api_endpoint` is provided, use the provided one.
+ (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the
+ default mTLS endpoint; if the environment variable is "never", use the default API
+ endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
+ use the default API endpoint.
+
+ More details can be found at https://google.aip.dev/auth/4114.
+
+ Args:
+ client_options (google.api_core.client_options.ClientOptions): Custom options for the
+ client. Only the `api_endpoint` and `client_cert_source` properties may be used
+ in this method.
+
+ Returns:
+ Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
+ client cert source to use.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If any errors happen.
+ """
+ return DatasetServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore
+
@property
def transport(self) -> DatasetServiceTransport:
- """Return the transport used by the client instance.
+ """Returns the transport used by the client instance.
Returns:
DatasetServiceTransport: The transport used by the client instance.
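A short usage sketch of the new `get_mtls_endpoint_and_cert_source` classmethod; the returned endpoint depends on the environment variables described in the docstring, so the printed values here are assumptions:

```python
from google.api_core.client_options import ClientOptions
from google.cloud.aiplatform_v1 import DatasetServiceClient

options = ClientOptions()  # no api_endpoint or client_cert_source overrides
endpoint, cert_source = DatasetServiceClient.get_mtls_endpoint_and_cert_source(options)
print(endpoint)     # e.g. "aiplatform.googleapis.com" when mTLS is not in effect
print(cert_source)  # None unless GOOGLE_API_USE_CLIENT_CERTIFICATE is "true"
```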
@@ -142,7 +190,7 @@ def __init__(
client_options: ClientOptions = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
- """Instantiate the dataset service client.
+ """Instantiates the dataset service client.
Args:
credentials (Optional[google.auth.credentials.Credentials]): The
@@ -183,18 +231,47 @@ def __init__(
async def create_dataset(
self,
- request: dataset_service.CreateDatasetRequest = None,
+ request: Union[dataset_service.CreateDatasetRequest, dict] = None,
*,
parent: str = None,
dataset: gca_dataset.Dataset = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> operation_async.AsyncOperation:
r"""Creates a Dataset.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_create_dataset():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceAsyncClient()
+
+ # Initialize request argument(s)
+ dataset = aiplatform_v1.Dataset()
+ dataset.display_name = "display_name_value"
+ dataset.metadata_schema_uri = "metadata_schema_uri_value"
+ dataset.metadata.null_value = "NULL_VALUE"
+
+ request = aiplatform_v1.CreateDatasetRequest(
+ parent="parent_value",
+ dataset=dataset,
+ )
+
+ # Make the request
+ operation = client.create_dataset(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = await operation.result()
+
+ # Handle the response
+ print(response)
+
Args:
- request (:class:`google.cloud.aiplatform_v1.types.CreateDatasetRequest`):
+ request (Union[google.cloud.aiplatform_v1.types.CreateDatasetRequest, dict]):
The request object. Request message for
[DatasetService.CreateDataset][google.cloud.aiplatform.v1.DatasetService.CreateDataset].
parent (:class:`str`):
@@ -226,7 +303,7 @@ async def create_dataset(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([parent, dataset])
if request is not None and has_flattened_params:
@@ -248,7 +325,7 @@ async def create_dataset(
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.create_dataset,
- default_timeout=5.0,
+ default_timeout=None,
client_info=DEFAULT_CLIENT_INFO,
)
@@ -259,7 +336,12 @@ async def create_dataset(
)
# Send the request.
- response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Wrap the response in an operation future.
response = operation_async.from_gapic(
@@ -274,17 +356,36 @@ async def create_dataset(
async def get_dataset(
self,
- request: dataset_service.GetDatasetRequest = None,
+ request: Union[dataset_service.GetDatasetRequest, dict] = None,
*,
name: str = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> dataset.Dataset:
r"""Gets a Dataset.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_get_dataset():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.GetDatasetRequest(
+ name="name_value",
+ )
+
+ # Make the request
+ response = await client.get_dataset(request=request)
+
+ # Handle the response
+ print(response)
+
Args:
- request (:class:`google.cloud.aiplatform_v1.types.GetDatasetRequest`):
+ request (Union[google.cloud.aiplatform_v1.types.GetDatasetRequest, dict]):
The request object. Request message for
[DatasetService.GetDataset][google.cloud.aiplatform.v1.DatasetService.GetDataset].
name (:class:`str`):
@@ -307,7 +408,7 @@ async def get_dataset(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([name])
if request is not None and has_flattened_params:
@@ -327,7 +428,7 @@ async def get_dataset(
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.get_dataset,
- default_timeout=5.0,
+ default_timeout=None,
client_info=DEFAULT_CLIENT_INFO,
)
@@ -338,25 +439,54 @@ async def get_dataset(
)
# Send the request.
- response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Done; return the response.
return response
async def update_dataset(
self,
- request: dataset_service.UpdateDatasetRequest = None,
+ request: Union[dataset_service.UpdateDatasetRequest, dict] = None,
*,
dataset: gca_dataset.Dataset = None,
update_mask: field_mask_pb2.FieldMask = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> gca_dataset.Dataset:
r"""Updates a Dataset.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_update_dataset():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceAsyncClient()
+
+ # Initialize request argument(s)
+ dataset = aiplatform_v1.Dataset()
+ dataset.display_name = "display_name_value"
+ dataset.metadata_schema_uri = "metadata_schema_uri_value"
+ dataset.metadata.null_value = "NULL_VALUE"
+
+ request = aiplatform_v1.UpdateDatasetRequest(
+ dataset=dataset,
+ )
+
+ # Make the request
+ response = await client.update_dataset(request=request)
+
+ # Handle the response
+ print(response)
+
Args:
- request (:class:`google.cloud.aiplatform_v1.types.UpdateDatasetRequest`):
+ request (Union[google.cloud.aiplatform_v1.types.UpdateDatasetRequest, dict]):
The request object. Request message for
[DatasetService.UpdateDataset][google.cloud.aiplatform.v1.DatasetService.UpdateDataset].
dataset (:class:`google.cloud.aiplatform_v1.types.Dataset`):
@@ -369,7 +499,7 @@ async def update_dataset(
update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`):
Required. The update mask applies to the resource. For
the ``FieldMask`` definition, see
- `FieldMask `__.
+ [google.protobuf.FieldMask][google.protobuf.FieldMask].
Updatable fields:
- ``display_name``
@@ -392,7 +522,7 @@ async def update_dataset(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([dataset, update_mask])
if request is not None and has_flattened_params:
@@ -414,7 +544,7 @@ async def update_dataset(
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.update_dataset,
- default_timeout=5.0,
+ default_timeout=None,
client_info=DEFAULT_CLIENT_INFO,
)
@@ -427,24 +557,49 @@ async def update_dataset(
)
# Send the request.
- response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Done; return the response.
return response
async def list_datasets(
self,
- request: dataset_service.ListDatasetsRequest = None,
+ request: Union[dataset_service.ListDatasetsRequest, dict] = None,
*,
parent: str = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> pagers.ListDatasetsAsyncPager:
r"""Lists Datasets in a Location.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_list_datasets():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.ListDatasetsRequest(
+ parent="parent_value",
+ )
+
+ # Make the request
+ page_result = client.list_datasets(request=request)
+
+ # Handle the response
+ async for response in page_result:
+ print(response)
+
Args:
- request (:class:`google.cloud.aiplatform_v1.types.ListDatasetsRequest`):
+ request (Union[google.cloud.aiplatform_v1.types.ListDatasetsRequest, dict]):
The request object. Request message for
[DatasetService.ListDatasets][google.cloud.aiplatform.v1.DatasetService.ListDatasets].
parent (:class:`str`):
@@ -470,7 +625,7 @@ async def list_datasets(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([parent])
if request is not None and has_flattened_params:
@@ -490,7 +645,7 @@ async def list_datasets(
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.list_datasets,
- default_timeout=5.0,
+ default_timeout=None,
client_info=DEFAULT_CLIENT_INFO,
)
@@ -501,12 +656,20 @@ async def list_datasets(
)
# Send the request.
- response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# This method is paged; wrap the response in a pager, which provides
# an `__aiter__` convenience method.
response = pagers.ListDatasetsAsyncPager(
- method=rpc, request=request, response=response, metadata=metadata,
+ method=rpc,
+ request=request,
+ response=response,
+ metadata=metadata,
)
# Done; return the response.
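Besides the flat `async for` iteration shown in the generated samples, the async pager also exposes page-level iteration. A hedged sketch, assuming an already constructed async client:

```python
async def iterate_pages(client, parent: str):
    # Hedged sketch: per-page iteration via the pager's `pages` property.
    pager = await client.list_datasets(request={"parent": parent})
    async for page in pager.pages:
        for dataset in page.datasets:
            print(dataset.display_name)
```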
@@ -514,17 +677,40 @@ async def list_datasets(
async def delete_dataset(
self,
- request: dataset_service.DeleteDatasetRequest = None,
+ request: Union[dataset_service.DeleteDatasetRequest, dict] = None,
*,
name: str = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> operation_async.AsyncOperation:
r"""Deletes a Dataset.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_delete_dataset():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.DeleteDatasetRequest(
+ name="name_value",
+ )
+
+ # Make the request
+ operation = client.delete_dataset(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = await operation.result()
+
+ # Handle the response
+ print(response)
+
Args:
- request (:class:`google.cloud.aiplatform_v1.types.DeleteDatasetRequest`):
+ request (Union[google.cloud.aiplatform_v1.types.DeleteDatasetRequest, dict]):
The request object. Request message for
[DatasetService.DeleteDataset][google.cloud.aiplatform.v1.DatasetService.DeleteDataset].
name (:class:`str`):
@@ -561,7 +747,7 @@ async def delete_dataset(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([name])
if request is not None and has_flattened_params:
@@ -581,7 +767,7 @@ async def delete_dataset(
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.delete_dataset,
- default_timeout=5.0,
+ default_timeout=None,
client_info=DEFAULT_CLIENT_INFO,
)
@@ -592,7 +778,12 @@ async def delete_dataset(
)
# Send the request.
- response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Wrap the response in an operation future.
response = operation_async.from_gapic(
@@ -607,18 +798,46 @@ async def delete_dataset(
async def import_data(
self,
- request: dataset_service.ImportDataRequest = None,
+ request: Union[dataset_service.ImportDataRequest, dict] = None,
*,
name: str = None,
import_configs: Sequence[dataset.ImportDataConfig] = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> operation_async.AsyncOperation:
r"""Imports data into a Dataset.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_import_data():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceAsyncClient()
+
+ # Initialize request argument(s)
+ import_configs = aiplatform_v1.ImportDataConfig()
+ import_configs.gcs_source.uris = ['uris_value_1', 'uris_value_2']
+ import_configs.import_schema_uri = "import_schema_uri_value"
+
+ request = aiplatform_v1.ImportDataRequest(
+ name="name_value",
+ import_configs=import_configs,
+ )
+
+ # Make the request
+ operation = client.import_data(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = await operation.result()
+
+ # Handle the response
+ print(response)
+
Args:
- request (:class:`google.cloud.aiplatform_v1.types.ImportDataRequest`):
+ request (Union[google.cloud.aiplatform_v1.types.ImportDataRequest, dict]):
The request object. Request message for
[DatasetService.ImportData][google.cloud.aiplatform.v1.DatasetService.ImportData].
name (:class:`str`):
@@ -653,7 +872,7 @@ async def import_data(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([name, import_configs])
if request is not None and has_flattened_params:
@@ -675,7 +894,7 @@ async def import_data(
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.import_data,
- default_timeout=5.0,
+ default_timeout=None,
client_info=DEFAULT_CLIENT_INFO,
)
@@ -686,7 +905,12 @@ async def import_data(
)
# Send the request.
- response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Wrap the response in an operation future.
response = operation_async.from_gapic(
@@ -701,18 +925,45 @@ async def import_data(
async def export_data(
self,
- request: dataset_service.ExportDataRequest = None,
+ request: Union[dataset_service.ExportDataRequest, dict] = None,
*,
name: str = None,
export_config: dataset.ExportDataConfig = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> operation_async.AsyncOperation:
r"""Exports data from a Dataset.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_export_data():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceAsyncClient()
+
+ # Initialize request argument(s)
+ export_config = aiplatform_v1.ExportDataConfig()
+ export_config.gcs_destination.output_uri_prefix = "output_uri_prefix_value"
+
+ request = aiplatform_v1.ExportDataRequest(
+ name="name_value",
+ export_config=export_config,
+ )
+
+ # Make the request
+ operation = client.export_data(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = await operation.result()
+
+ # Handle the response
+ print(response)
+
Args:
- request (:class:`google.cloud.aiplatform_v1.types.ExportDataRequest`):
+ request (Union[google.cloud.aiplatform_v1.types.ExportDataRequest, dict]):
The request object. Request message for
[DatasetService.ExportData][google.cloud.aiplatform.v1.DatasetService.ExportData].
name (:class:`str`):
@@ -746,7 +997,7 @@ async def export_data(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([name, export_config])
if request is not None and has_flattened_params:
@@ -768,7 +1019,7 @@ async def export_data(
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.export_data,
- default_timeout=5.0,
+ default_timeout=None,
client_info=DEFAULT_CLIENT_INFO,
)
@@ -779,7 +1030,12 @@ async def export_data(
)
# Send the request.
- response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Wrap the response in an operation future.
response = operation_async.from_gapic(
@@ -794,17 +1050,37 @@ async def export_data(
async def list_data_items(
self,
- request: dataset_service.ListDataItemsRequest = None,
+ request: Union[dataset_service.ListDataItemsRequest, dict] = None,
*,
parent: str = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> pagers.ListDataItemsAsyncPager:
r"""Lists DataItems in a Dataset.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_list_data_items():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.ListDataItemsRequest(
+ parent="parent_value",
+ )
+
+ # Make the request
+ page_result = client.list_data_items(request=request)
+
+ # Handle the response
+ async for response in page_result:
+ print(response)
+
Args:
- request (:class:`google.cloud.aiplatform_v1.types.ListDataItemsRequest`):
+ request (Union[google.cloud.aiplatform_v1.types.ListDataItemsRequest, dict]):
The request object. Request message for
[DatasetService.ListDataItems][google.cloud.aiplatform.v1.DatasetService.ListDataItems].
parent (:class:`str`):
@@ -831,7 +1107,7 @@ async def list_data_items(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([parent])
if request is not None and has_flattened_params:
@@ -851,7 +1127,7 @@ async def list_data_items(
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.list_data_items,
- default_timeout=5.0,
+ default_timeout=None,
client_info=DEFAULT_CLIENT_INFO,
)
@@ -862,12 +1138,20 @@ async def list_data_items(
)
# Send the request.
- response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# This method is paged; wrap the response in a pager, which provides
# an `__aiter__` convenience method.
response = pagers.ListDataItemsAsyncPager(
- method=rpc, request=request, response=response, metadata=metadata,
+ method=rpc,
+ request=request,
+ response=response,
+ metadata=metadata,
)
# Done; return the response.
@@ -875,23 +1159,41 @@ async def list_data_items(
async def get_annotation_spec(
self,
- request: dataset_service.GetAnnotationSpecRequest = None,
+ request: Union[dataset_service.GetAnnotationSpecRequest, dict] = None,
*,
name: str = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> annotation_spec.AnnotationSpec:
r"""Gets an AnnotationSpec.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_get_annotation_spec():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.GetAnnotationSpecRequest(
+ name="name_value",
+ )
+
+ # Make the request
+ response = await client.get_annotation_spec(request=request)
+
+ # Handle the response
+ print(response)
+
Args:
- request (:class:`google.cloud.aiplatform_v1.types.GetAnnotationSpecRequest`):
+ request (Union[google.cloud.aiplatform_v1.types.GetAnnotationSpecRequest, dict]):
The request object. Request message for
[DatasetService.GetAnnotationSpec][google.cloud.aiplatform.v1.DatasetService.GetAnnotationSpec].
name (:class:`str`):
Required. The name of the AnnotationSpec resource.
Format:
-
``projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}``
This corresponds to the ``name`` field
@@ -910,7 +1212,7 @@ async def get_annotation_spec(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([name])
if request is not None and has_flattened_params:
@@ -930,7 +1232,7 @@ async def get_annotation_spec(
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.get_annotation_spec,
- default_timeout=5.0,
+ default_timeout=None,
client_info=DEFAULT_CLIENT_INFO,
)
@@ -941,30 +1243,54 @@ async def get_annotation_spec(
)
# Send the request.
- response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Done; return the response.
return response
async def list_annotations(
self,
- request: dataset_service.ListAnnotationsRequest = None,
+ request: Union[dataset_service.ListAnnotationsRequest, dict] = None,
*,
parent: str = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> pagers.ListAnnotationsAsyncPager:
r"""Lists Annotations belongs to a dataitem
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_list_annotations():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.ListAnnotationsRequest(
+ parent="parent_value",
+ )
+
+ # Make the request
+ page_result = client.list_annotations(request=request)
+
+ # Handle the response
+ async for response in page_result:
+ print(response)
+
Args:
- request (:class:`google.cloud.aiplatform_v1.types.ListAnnotationsRequest`):
+ request (Union[google.cloud.aiplatform_v1.types.ListAnnotationsRequest, dict]):
The request object. Request message for
[DatasetService.ListAnnotations][google.cloud.aiplatform.v1.DatasetService.ListAnnotations].
parent (:class:`str`):
Required. The resource name of the DataItem to list
Annotations from. Format:
-
``projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}``
This corresponds to the ``parent`` field
@@ -986,7 +1312,7 @@ async def list_annotations(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([parent])
if request is not None and has_flattened_params:
@@ -1006,7 +1332,7 @@ async def list_annotations(
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.list_annotations,
- default_timeout=5.0,
+ default_timeout=None,
client_info=DEFAULT_CLIENT_INFO,
)
@@ -1017,17 +1343,702 @@ async def list_annotations(
)
# Send the request.
- response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# This method is paged; wrap the response in a pager, which provides
# an `__aiter__` convenience method.
response = pagers.ListAnnotationsAsyncPager(
- method=rpc, request=request, response=response, metadata=metadata,
+ method=rpc,
+ request=request,
+ response=response,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def list_operations(
+ self,
+ request: operations_pb2.ListOperationsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.ListOperationsResponse:
+ r"""Lists operations that match the specified filter in the request.
+
+ Args:
+ request (:class:`~.operations_pb2.ListOperationsRequest`):
+ The request object. Request message for
+ `ListOperations` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.ListOperationsResponse:
+ Response message for ``ListOperations`` method.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.ListOperationsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.list_operations,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
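Because `operations_pb2.ListOperationsRequest` is a raw protobuf message rather than a proto-plus wrapper, a dict request is expanded into keyword arguments, as the comment in the method above notes. A minimal sketch with an assumed resource name:

```python
from google.longrunning import operations_pb2

# Hedged sketch: dict requests are coerced via keyword expansion.
request = {"name": "projects/my-project/locations/us-central1"}
pb_request = operations_pb2.ListOperationsRequest(**request)
# response = await client.list_operations(request=request)
```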
+ async def get_operation(
+ self,
+ request: operations_pb2.GetOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.Operation:
+ r"""Gets the latest state of a long-running operation.
+
+ Args:
+ request (:class:`~.operations_pb2.GetOperationRequest`):
+ The request object. Request message for
+ `GetOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.Operation:
+ An ``Operation`` object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.GetOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.get_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def delete_operation(
+ self,
+ request: operations_pb2.DeleteOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> None:
+ r"""Deletes a long-running operation.
+
+ This method indicates that the client is no longer interested
+ in the operation result. It does not cancel the operation.
+ If the server doesn't support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
+ Args:
+ request (:class:`~.operations_pb2.DeleteOperationRequest`):
+ The request object. Request message for
+ `DeleteOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ None
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.DeleteOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.delete_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ async def cancel_operation(
+ self,
+ request: operations_pb2.CancelOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> None:
+ r"""Starts asynchronous cancellation on a long-running operation.
+
+ The server makes a best effort to cancel the operation, but success
+ is not guaranteed. If the server doesn't support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
+ Args:
+ request (:class:`~.operations_pb2.CancelOperationRequest`):
+ The request object. Request message for
+ `CancelOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ None
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.CancelOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.cancel_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ async def wait_operation(
+ self,
+ request: operations_pb2.WaitOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.Operation:
+ r"""Waits until the specified long-running operation is done or reaches at most
+ a specified timeout, returning the latest state.
+
+ If the operation is already done, the latest state is immediately returned.
+ If the timeout specified is greater than the default HTTP/RPC timeout, the HTTP/RPC
+ timeout is used. If the server does not support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
+ Args:
+ request (:class:`~.operations_pb2.WaitOperationRequest`):
+ The request object. Request message for
+ `WaitOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.Operation:
+ An ``Operation`` object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.WaitOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.wait_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
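A hedged sketch of calling the new `wait_operation` mixin with an explicit timeout; the operation name is an assumption:

```python
from google.protobuf import duration_pb2

request = {
    "name": "projects/my-project/locations/us-central1/operations/op-123",
    "timeout": duration_pb2.Duration(seconds=30),
}
# operation = await client.wait_operation(request=request)
```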
+ async def set_iam_policy(
+ self,
+ request: iam_policy_pb2.SetIamPolicyRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> policy_pb2.Policy:
+ r"""Sets the IAM access control policy on the specified function.
+
+ Replaces any existing policy.
+
+ Args:
+ request (:class:`~.iam_policy_pb2.SetIamPolicyRequest`):
+ The request object. Request message for `SetIamPolicy`
+ method.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.policy_pb2.Policy:
+ Defines an Identity and Access Management (IAM) policy.
+ It is used to specify access control policies for Cloud
+ Platform resources.
+ A ``Policy`` is a collection of ``bindings``. A
+ ``binding`` binds one or more ``members`` to a single
+ ``role``. Members can be user accounts, service
+ accounts, Google groups, and domains (such as G Suite).
+ A ``role`` is a named list of permissions (defined by
+ IAM or configured by users). A ``binding`` can
+ optionally specify a ``condition``, which is a logic
+ expression that further constrains the role binding
+ based on attributes about the request and/or target
+ resource.
+ **JSON Example**
+ ::
+ {
+ "bindings": [
+ {
+ "role": "roles/resourcemanager.organizationAdmin",
+ "members": [
+ "user:mike@example.com",
+ "group:admins@example.com",
+ "domain:google.com",
+ "serviceAccount:my-project-id@appspot.gserviceaccount.com"
+ ]
+ },
+ {
+ "role": "roles/resourcemanager.organizationViewer",
+ "members": ["user:eve@example.com"],
+ "condition": {
+ "title": "expirable access",
+ "description": "Does not grant access after Sep 2020",
+ "expression": "request.time <
+ timestamp('2020-10-01T00:00:00.000Z')",
+ }
+ }
+ ]
+ }
+ **YAML Example**
+ ::
+ bindings:
+ - members:
+ - user:mike@example.com
+ - group:admins@example.com
+ - domain:google.com
+ - serviceAccount:my-project-id@appspot.gserviceaccount.com
+ role: roles/resourcemanager.organizationAdmin
+ - members:
+ - user:eve@example.com
+ role: roles/resourcemanager.organizationViewer
+ condition:
+ title: expirable access
+ description: Does not grant access after Sep 2020
+ expression: request.time < timestamp('2020-10-01T00:00:00.000Z')
+ For a description of IAM and its features, see the `IAM
+ developer's
+ guide `__.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.SetIamPolicyRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.set_iam_policy,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def get_iam_policy(
+ self,
+ request: iam_policy_pb2.GetIamPolicyRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> policy_pb2.Policy:
+ r"""Gets the IAM access control policy for a function.
+
+ Returns an empty policy if the function exists and does not have a
+ policy set.
+
+ Args:
+ request (:class:`~.iam_policy_pb2.GetIamPolicyRequest`):
+ The request object. Request message for `GetIamPolicy`
+ method.
+ retry (google.api_core.retry.Retry): Designation of what errors, if
+ any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.policy_pb2.Policy:
+ Defines an Identity and Access Management (IAM) policy.
+ It is used to specify access control policies for Cloud
+ Platform resources.
+ A ``Policy`` is a collection of ``bindings``. A
+ ``binding`` binds one or more ``members`` to a single
+ ``role``. Members can be user accounts, service
+ accounts, Google groups, and domains (such as G Suite).
+ A ``role`` is a named list of permissions (defined by
+ IAM or configured by users). A ``binding`` can
+ optionally specify a ``condition``, which is a logic
+ expression that further constrains the role binding
+ based on attributes about the request and/or target
+ resource.
+ **JSON Example**
+ ::
+ {
+ "bindings": [
+ {
+ "role": "roles/resourcemanager.organizationAdmin",
+ "members": [
+ "user:mike@example.com",
+ "group:admins@example.com",
+ "domain:google.com",
+ "serviceAccount:my-project-id@appspot.gserviceaccount.com"
+ ]
+ },
+ {
+ "role": "roles/resourcemanager.organizationViewer",
+ "members": ["user:eve@example.com"],
+ "condition": {
+ "title": "expirable access",
+ "description": "Does not grant access after Sep 2020",
+ "expression": "request.time <
+ timestamp('2020-10-01T00:00:00.000Z')",
+ }
+ }
+ ]
+ }
+ **YAML Example**
+ ::
+ bindings:
+ - members:
+ - user:mike@example.com
+ - group:admins@example.com
+ - domain:google.com
+ - serviceAccount:my-project-id@appspot.gserviceaccount.com
+ role: roles/resourcemanager.organizationAdmin
+ - members:
+ - user:eve@example.com
+ role: roles/resourcemanager.organizationViewer
+ condition:
+ title: expirable access
+ description: Does not grant access after Sep 2020
+ expression: request.time < timestamp('2020-10-01T00:00:00.000Z')
+ For a description of IAM and its features, see the `IAM
+ developer's
+ guide `__.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.GetIamPolicyRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.get_iam_policy,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def test_iam_permissions(
+ self,
+ request: iam_policy_pb2.TestIamPermissionsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> iam_policy_pb2.TestIamPermissionsResponse:
+ r"""Tests the specified IAM permissions against the IAM access control
+ policy for a function.
+
+ If the function does not exist, this will return an empty set
+ of permissions, not a NOT_FOUND error.
+
+ Args:
+ request (:class:`~.iam_policy_pb2.TestIamPermissionsRequest`):
+ The request object. Request message for
+ `TestIamPermissions` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.iam_policy_pb2.TestIamPermissionsResponse:
+ Response message for ``TestIamPermissions`` method.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.TestIamPermissionsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.test_iam_permissions,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
)
# Done; return the response.
return response
+ async def get_location(
+ self,
+ request: locations_pb2.GetLocationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> locations_pb2.Location:
+ r"""Gets information about a location.
+
+ Args:
+ request (:class:`~.location_pb2.GetLocationRequest`):
+ The request object. Request message for
+ `GetLocation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.location_pb2.Location:
+ Location object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = locations_pb2.GetLocationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.get_location,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def list_locations(
+ self,
+ request: locations_pb2.ListLocationsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> locations_pb2.ListLocationsResponse:
+ r"""Lists information about the supported locations for this service.
+
+ Args:
+ request (:class:`~.location_pb2.ListLocationsRequest`):
+ The request object. Request message for
+ `ListLocations` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.location_pb2.ListLocationsResponse:
+ Response message for ``ListLocations`` method.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = locations_pb2.ListLocationsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.list_locations,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ await self.transport.close()
+
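The new `__aenter__`/`__aexit__` pair lets the async client be used as an async context manager, closing the transport's channel on exit. A minimal usage sketch, with an assumed dataset name:

```python
import asyncio
from google.cloud import aiplatform_v1

async def main():
    # Hedged sketch: the transport is closed automatically on exit.
    async with aiplatform_v1.DatasetServiceAsyncClient() as client:
        dataset = await client.get_dataset(
            request={"name": "projects/p/locations/us-central1/datasets/d"}
        )
        print(dataset.display_name)

# asyncio.run(main())
```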
try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
diff --git a/google/cloud/aiplatform_v1/services/dataset_service/client.py b/google/cloud/aiplatform_v1/services/dataset_service/client.py
index 201d814c99..08ae4b98cf 100644
--- a/google/cloud/aiplatform_v1/services/dataset_service/client.py
+++ b/google/cloud/aiplatform_v1/services/dataset_service/client.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,22 +14,26 @@
# limitations under the License.
#
from collections import OrderedDict
-from distutils import util
import os
import re
-from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union
+from typing import Dict, Mapping, Optional, Sequence, Tuple, Type, Union
import pkg_resources
-from google.api_core import client_options as client_options_lib # type: ignore
-from google.api_core import exceptions as core_exceptions # type: ignore
-from google.api_core import gapic_v1 # type: ignore
-from google.api_core import retry as retries # type: ignore
+from google.api_core import client_options as client_options_lib
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core import retry as retries
from google.auth import credentials as ga_credentials # type: ignore
from google.auth.transport import mtls # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
from google.auth.exceptions import MutualTLSChannelError # type: ignore
from google.oauth2 import service_account # type: ignore
+try:
+ OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
+except AttributeError: # pragma: NO COVER
+ OptionalRetry = Union[retries.Retry, object] # type: ignore
+
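The try/except above keeps the annotation importable on older google-api-core releases that predate `_MethodDefault`; either way, callers may pass a concrete `Retry` or omit the argument to keep the `DEFAULT` sentinel. A hedged caller-side sketch (the resource name and backoff values are illustrative):

    from google.api_core import retry as retries

    from google.cloud import aiplatform_v1

    client = aiplatform_v1.DatasetServiceClient()

    # A concrete Retry satisfies OptionalRetry; leaving `retry` unset keeps
    # gapic_v1.method.DEFAULT and the method's built-in policy.
    custom_retry = retries.Retry(initial=0.25, maximum=8.0, deadline=60.0)
    dataset = client.get_dataset(
        name="projects/my-project/locations/us-central1/datasets/123",
        retry=custom_retry,
        timeout=30.0,
    )
    print(dataset.display_name)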
from google.api_core import operation as gac_operation # type: ignore
from google.api_core import operation_async # type: ignore
from google.cloud.aiplatform_v1.services.dataset_service import pagers
@@ -41,6 +45,10 @@
from google.cloud.aiplatform_v1.types import dataset_service
from google.cloud.aiplatform_v1.types import encryption_spec
from google.cloud.aiplatform_v1.types import operation as gca_operation
+from google.cloud.location import locations_pb2 # type: ignore
+from google.iam.v1 import iam_policy_pb2 # type: ignore
+from google.iam.v1 import policy_pb2 # type: ignore
+from google.longrunning import operations_pb2
from google.protobuf import empty_pb2 # type: ignore
from google.protobuf import field_mask_pb2 # type: ignore
from google.protobuf import struct_pb2 # type: ignore
@@ -64,8 +72,11 @@ class DatasetServiceClientMeta(type):
_transport_registry["grpc"] = DatasetServiceGrpcTransport
_transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport
- def get_transport_class(cls, label: str = None,) -> Type[DatasetServiceTransport]:
- """Return an appropriate transport class.
+ def get_transport_class(
+ cls,
+ label: str = None,
+ ) -> Type[DatasetServiceTransport]:
+ """Returns an appropriate transport class.
Args:
label: The name of the desired transport. If none is
@@ -84,11 +95,14 @@ def get_transport_class(cls, label: str = None,) -> Type[DatasetServiceTransport
class DatasetServiceClient(metaclass=DatasetServiceClientMeta):
- """"""
+ """The service that handles the CRUD of Vertex AI Dataset and
+ its child resources.
+ """
@staticmethod
def _get_default_mtls_endpoint(api_endpoint):
- """Convert api endpoint to mTLS endpoint.
+ """Converts api endpoint to mTLS endpoint.
+
Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
"*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
Args:
@@ -122,7 +136,8 @@ def _get_default_mtls_endpoint(api_endpoint):
@classmethod
def from_service_account_info(cls, info: dict, *args, **kwargs):
- """Creates an instance of this client using the provided credentials info.
+ """Creates an instance of this client using the provided credentials
+ info.
Args:
info (dict): The service account private key info.
@@ -139,7 +154,7 @@ def from_service_account_info(cls, info: dict, *args, **kwargs):
@classmethod
def from_service_account_file(cls, filename: str, *args, **kwargs):
"""Creates an instance of this client using the provided credentials
- file.
+ file.
Args:
filename (str): The path to the service account private key json
@@ -158,18 +173,23 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
@property
def transport(self) -> DatasetServiceTransport:
- """Return the transport used by the client instance.
+ """Returns the transport used by the client instance.
Returns:
- DatasetServiceTransport: The transport used by the client instance.
+ DatasetServiceTransport: The transport used by the client
+ instance.
"""
return self._transport
@staticmethod
def annotation_path(
- project: str, location: str, dataset: str, data_item: str, annotation: str,
+ project: str,
+ location: str,
+ dataset: str,
+ data_item: str,
+ annotation: str,
) -> str:
- """Return a fully-qualified annotation string."""
+ """Returns a fully-qualified annotation string."""
return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format(
project=project,
location=location,
@@ -180,7 +200,7 @@ def annotation_path(
@staticmethod
def parse_annotation_path(path: str) -> Dict[str, str]:
- """Parse a annotation path into its component segments."""
+ """Parses a annotation path into its component segments."""
m = re.match(
r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)/annotations/(?P.+?)$",
path,
@@ -189,9 +209,12 @@ def parse_annotation_path(path: str) -> Dict[str, str]:
@staticmethod
def annotation_spec_path(
- project: str, location: str, dataset: str, annotation_spec: str,
+ project: str,
+ location: str,
+ dataset: str,
+ annotation_spec: str,
) -> str:
- """Return a fully-qualified annotation_spec string."""
+ """Returns a fully-qualified annotation_spec string."""
return "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format(
project=project,
location=location,
@@ -201,7 +224,7 @@ def annotation_spec_path(
@staticmethod
def parse_annotation_spec_path(path: str) -> Dict[str, str]:
- """Parse a annotation_spec path into its component segments."""
+ """Parses a annotation_spec path into its component segments."""
m = re.match(
r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/annotationSpecs/(?P.+?)$",
path,
@@ -210,16 +233,22 @@ def parse_annotation_spec_path(path: str) -> Dict[str, str]:
@staticmethod
def data_item_path(
- project: str, location: str, dataset: str, data_item: str,
+ project: str,
+ location: str,
+ dataset: str,
+ data_item: str,
) -> str:
- """Return a fully-qualified data_item string."""
+ """Returns a fully-qualified data_item string."""
return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format(
- project=project, location=location, dataset=dataset, data_item=data_item,
+ project=project,
+ location=location,
+ dataset=dataset,
+ data_item=data_item,
)
@staticmethod
def parse_data_item_path(path: str) -> Dict[str, str]:
- """Parse a data_item path into its component segments."""
+ """Parses a data_item path into its component segments."""
m = re.match(
r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)$",
path,
@@ -227,15 +256,21 @@ def parse_data_item_path(path: str) -> Dict[str, str]:
return m.groupdict() if m else {}
@staticmethod
- def dataset_path(project: str, location: str, dataset: str,) -> str:
- """Return a fully-qualified dataset string."""
+ def dataset_path(
+ project: str,
+ location: str,
+ dataset: str,
+ ) -> str:
+ """Returns a fully-qualified dataset string."""
return "projects/{project}/locations/{location}/datasets/{dataset}".format(
- project=project, location=location, dataset=dataset,
+ project=project,
+ location=location,
+ dataset=dataset,
)
@staticmethod
def parse_dataset_path(path: str) -> Dict[str, str]:
- """Parse a dataset path into its component segments."""
+ """Parses a dataset path into its component segments."""
m = re.match(
r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$",
path,
@@ -243,8 +278,10 @@ def parse_dataset_path(path: str) -> Dict[str, str]:
return m.groupdict() if m else {}
@staticmethod
- def common_billing_account_path(billing_account: str,) -> str:
- """Return a fully-qualified billing_account string."""
+ def common_billing_account_path(
+ billing_account: str,
+ ) -> str:
+ """Returns a fully-qualified billing_account string."""
return "billingAccounts/{billing_account}".format(
billing_account=billing_account,
)
@@ -256,9 +293,13 @@ def parse_common_billing_account_path(path: str) -> Dict[str, str]:
return m.groupdict() if m else {}
@staticmethod
- def common_folder_path(folder: str,) -> str:
- """Return a fully-qualified folder string."""
- return "folders/{folder}".format(folder=folder,)
+ def common_folder_path(
+ folder: str,
+ ) -> str:
+ """Returns a fully-qualified folder string."""
+ return "folders/{folder}".format(
+ folder=folder,
+ )
@staticmethod
def parse_common_folder_path(path: str) -> Dict[str, str]:
@@ -267,9 +308,13 @@ def parse_common_folder_path(path: str) -> Dict[str, str]:
return m.groupdict() if m else {}
@staticmethod
- def common_organization_path(organization: str,) -> str:
- """Return a fully-qualified organization string."""
- return "organizations/{organization}".format(organization=organization,)
+ def common_organization_path(
+ organization: str,
+ ) -> str:
+ """Returns a fully-qualified organization string."""
+ return "organizations/{organization}".format(
+ organization=organization,
+ )
@staticmethod
def parse_common_organization_path(path: str) -> Dict[str, str]:
@@ -278,9 +323,13 @@ def parse_common_organization_path(path: str) -> Dict[str, str]:
return m.groupdict() if m else {}
@staticmethod
- def common_project_path(project: str,) -> str:
- """Return a fully-qualified project string."""
- return "projects/{project}".format(project=project,)
+ def common_project_path(
+ project: str,
+ ) -> str:
+ """Returns a fully-qualified project string."""
+ return "projects/{project}".format(
+ project=project,
+ )
@staticmethod
def parse_common_project_path(path: str) -> Dict[str, str]:
@@ -289,10 +338,14 @@ def parse_common_project_path(path: str) -> Dict[str, str]:
return m.groupdict() if m else {}
@staticmethod
- def common_location_path(project: str, location: str,) -> str:
- """Return a fully-qualified location string."""
+ def common_location_path(
+ project: str,
+ location: str,
+ ) -> str:
+ """Returns a fully-qualified location string."""
return "projects/{project}/locations/{location}".format(
- project=project, location=location,
+ project=project,
+ location=location,
)
@staticmethod
@@ -301,6 +354,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]:
m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path)
return m.groupdict() if m else {}
+ @classmethod
+ def get_mtls_endpoint_and_cert_source(
+ cls, client_options: Optional[client_options_lib.ClientOptions] = None
+ ):
+ """Return the API endpoint and client cert source for mutual TLS.
+
+ The client cert source is determined in the following order:
+ (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
+ client cert source is None.
+ (2) if `client_options.client_cert_source` is provided, use the provided one; if the
+ default client cert source exists, use the default one; otherwise the client cert
+ source is None.
+
+ The API endpoint is determined in the following order:
+ (1) if `client_options.api_endpoint` is provided, use the provided one.
+ (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the
+ default mTLS endpoint; if the environment variable is "never", use the default API
+ endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
+ use the default API endpoint.
+
+ More details can be found at https://google.aip.dev/auth/4114.
+
+ Args:
+ client_options (google.api_core.client_options.ClientOptions): Custom options for the
+ client. Only the `api_endpoint` and `client_cert_source` properties may be used
+ in this method.
+
+ Returns:
+ Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
+ client cert source to use.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If any errors happen.
+ """
+ if client_options is None:
+ client_options = client_options_lib.ClientOptions()
+ use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
+ use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
+ if use_client_cert not in ("true", "false"):
+ raise ValueError(
+ "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
+ )
+ if use_mtls_endpoint not in ("auto", "never", "always"):
+ raise MutualTLSChannelError(
+ "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
+ )
+
+ # Figure out the client cert source to use.
+ client_cert_source = None
+ if use_client_cert == "true":
+ if client_options.client_cert_source:
+ client_cert_source = client_options.client_cert_source
+ elif mtls.has_default_client_cert_source():
+ client_cert_source = mtls.default_client_cert_source()
+
+ # Figure out which api endpoint to use.
+ if client_options.api_endpoint is not None:
+ api_endpoint = client_options.api_endpoint
+ elif use_mtls_endpoint == "always" or (
+ use_mtls_endpoint == "auto" and client_cert_source
+ ):
+ api_endpoint = cls.DEFAULT_MTLS_ENDPOINT
+ else:
+ api_endpoint = cls.DEFAULT_ENDPOINT
+
+ return api_endpoint, client_cert_source
+
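Because the resolution logic is now a classmethod, it can be probed without constructing a client. A sketch assuming no client certificate is configured (the printed endpoint is the expected default, not something this diff asserts):

    import os

    from google.api_core import client_options as client_options_lib
    from google.cloud import aiplatform_v1

    # With GOOGLE_API_USE_CLIENT_CERTIFICATE unset or "false", the cert
    # source resolves to None, so "auto" falls through to the regular endpoint.
    os.environ["GOOGLE_API_USE_MTLS_ENDPOINT"] = "auto"
    endpoint, cert_source = (
        aiplatform_v1.DatasetServiceClient.get_mtls_endpoint_and_cert_source(
            client_options_lib.ClientOptions()
        )
    )
    assert cert_source is None
    print(endpoint)  # expected: aiplatform.googleapis.com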
def __init__(
self,
*,
@@ -309,7 +429,7 @@ def __init__(
client_options: Optional[client_options_lib.ClientOptions] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
- """Instantiate the dataset service client.
+ """Instantiates the dataset service client.
Args:
credentials (Optional[google.auth.credentials.Credentials]): The
@@ -351,58 +471,42 @@ def __init__(
if client_options is None:
client_options = client_options_lib.ClientOptions()
- # Create SSL credentials for mutual TLS if needed.
- use_client_cert = bool(
- util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))
+ api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(
+ client_options
)
- client_cert_source_func = None
- is_mtls = False
- if use_client_cert:
- if client_options.client_cert_source:
- is_mtls = True
- client_cert_source_func = client_options.client_cert_source
- else:
- is_mtls = mtls.has_default_client_cert_source()
- client_cert_source_func = (
- mtls.default_client_cert_source() if is_mtls else None
- )
-
- # Figure out which api endpoint to use.
- if client_options.api_endpoint is not None:
- api_endpoint = client_options.api_endpoint
- else:
- use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
- if use_mtls_env == "never":
- api_endpoint = self.DEFAULT_ENDPOINT
- elif use_mtls_env == "always":
- api_endpoint = self.DEFAULT_MTLS_ENDPOINT
- elif use_mtls_env == "auto":
- api_endpoint = (
- self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT
- )
- else:
- raise MutualTLSChannelError(
- "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always"
- )
+ api_key_value = getattr(client_options, "api_key", None)
+ if api_key_value and credentials:
+ raise ValueError(
+ "client_options.api_key and credentials are mutually exclusive"
+ )
# Save or instantiate the transport.
# Ordinarily, we provide the transport, but allowing a custom transport
# instance provides an extensibility point for unusual situations.
if isinstance(transport, DatasetServiceTransport):
# transport is a DatasetServiceTransport instance.
- if credentials or client_options.credentials_file:
+ if credentials or client_options.credentials_file or api_key_value:
raise ValueError(
"When providing a transport instance, "
"provide its credentials directly."
)
if client_options.scopes:
raise ValueError(
- "When providing a transport instance, "
- "provide its scopes directly."
+ "When providing a transport instance, provide its scopes "
+ "directly."
)
self._transport = transport
else:
+ import google.auth._default # type: ignore
+
+ if api_key_value and hasattr(
+ google.auth._default, "get_api_key_credentials"
+ ):
+ credentials = google.auth._default.get_api_key_credentials(
+ api_key_value
+ )
+
Transport = type(self).get_transport_class(transport)
self._transport = Transport(
credentials=credentials,
@@ -412,22 +516,52 @@ def __init__(
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
+ always_use_jwt_access=True,
)
def create_dataset(
self,
- request: dataset_service.CreateDatasetRequest = None,
+ request: Union[dataset_service.CreateDatasetRequest, dict] = None,
*,
parent: str = None,
dataset: gca_dataset.Dataset = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> gac_operation.Operation:
r"""Creates a Dataset.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ def sample_create_dataset():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceClient()
+
+ # Initialize request argument(s)
+ dataset = aiplatform_v1.Dataset()
+ dataset.display_name = "display_name_value"
+ dataset.metadata_schema_uri = "metadata_schema_uri_value"
+ dataset.metadata.null_value = "NULL_VALUE"
+
+ request = aiplatform_v1.CreateDatasetRequest(
+ parent="parent_value",
+ dataset=dataset,
+ )
+
+ # Make the request
+ operation = client.create_dataset(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = operation.result()
+
+ # Handle the response
+ print(response)
+
Args:
- request (google.cloud.aiplatform_v1.types.CreateDatasetRequest):
+ request (Union[google.cloud.aiplatform_v1.types.CreateDatasetRequest, dict]):
The request object. Request message for
[DatasetService.CreateDataset][google.cloud.aiplatform.v1.DatasetService.CreateDataset].
parent (str):
@@ -459,7 +593,7 @@ def create_dataset(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([parent, dataset])
if request is not None and has_flattened_params:
@@ -492,7 +626,12 @@ def create_dataset(
)
# Send the request.
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Wrap the response in an operation future.
response = gac_operation.from_gapic(
@@ -507,17 +646,36 @@ def create_dataset(
def get_dataset(
self,
- request: dataset_service.GetDatasetRequest = None,
+ request: Union[dataset_service.GetDatasetRequest, dict] = None,
*,
name: str = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> dataset.Dataset:
r"""Gets a Dataset.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ def sample_get_dataset():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.GetDatasetRequest(
+ name="name_value",
+ )
+
+ # Make the request
+ response = client.get_dataset(request=request)
+
+ # Handle the response
+ print(response)
+
Args:
- request (google.cloud.aiplatform_v1.types.GetDatasetRequest):
+ request (Union[google.cloud.aiplatform_v1.types.GetDatasetRequest, dict]):
The request object. Request message for
[DatasetService.GetDataset][google.cloud.aiplatform.v1.DatasetService.GetDataset].
name (str):
@@ -540,7 +698,7 @@ def get_dataset(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([name])
if request is not None and has_flattened_params:
@@ -571,25 +729,54 @@ def get_dataset(
)
# Send the request.
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Done; return the response.
return response
def update_dataset(
self,
- request: dataset_service.UpdateDatasetRequest = None,
+ request: Union[dataset_service.UpdateDatasetRequest, dict] = None,
*,
dataset: gca_dataset.Dataset = None,
update_mask: field_mask_pb2.FieldMask = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> gca_dataset.Dataset:
r"""Updates a Dataset.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ def sample_update_dataset():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceClient()
+
+ # Initialize request argument(s)
+ dataset = aiplatform_v1.Dataset()
+ dataset.display_name = "display_name_value"
+ dataset.metadata_schema_uri = "metadata_schema_uri_value"
+ dataset.metadata.null_value = "NULL_VALUE"
+
+ request = aiplatform_v1.UpdateDatasetRequest(
+ dataset=dataset,
+ )
+
+ # Make the request
+ response = client.update_dataset(request=request)
+
+ # Handle the response
+ print(response)
+
Args:
- request (google.cloud.aiplatform_v1.types.UpdateDatasetRequest):
+ request (Union[google.cloud.aiplatform_v1.types.UpdateDatasetRequest, dict]):
The request object. Request message for
[DatasetService.UpdateDataset][google.cloud.aiplatform.v1.DatasetService.UpdateDataset].
dataset (google.cloud.aiplatform_v1.types.Dataset):
@@ -602,7 +789,7 @@ def update_dataset(
update_mask (google.protobuf.field_mask_pb2.FieldMask):
Required. The update mask applies to the resource. For
the ``FieldMask`` definition, see
- `FieldMask <https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask>`__.
+ [google.protobuf.FieldMask][google.protobuf.FieldMask].
Updatable fields:
- ``display_name``
@@ -625,7 +812,7 @@ def update_dataset(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([dataset, update_mask])
if request is not None and has_flattened_params:
@@ -660,24 +847,49 @@ def update_dataset(
)
# Send the request.
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Done; return the response.
return response
def list_datasets(
self,
- request: dataset_service.ListDatasetsRequest = None,
+ request: Union[dataset_service.ListDatasetsRequest, dict] = None,
*,
parent: str = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> pagers.ListDatasetsPager:
r"""Lists Datasets in a Location.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ def sample_list_datasets():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.ListDatasetsRequest(
+ parent="parent_value",
+ )
+
+ # Make the request
+ page_result = client.list_datasets(request=request)
+
+ # Handle the response
+ for response in page_result:
+ print(response)
+
Args:
- request (google.cloud.aiplatform_v1.types.ListDatasetsRequest):
+ request (Union[google.cloud.aiplatform_v1.types.ListDatasetsRequest, dict]):
The request object. Request message for
[DatasetService.ListDatasets][google.cloud.aiplatform.v1.DatasetService.ListDatasets].
parent (str):
@@ -703,7 +915,7 @@ def list_datasets(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([parent])
if request is not None and has_flattened_params:
@@ -734,12 +946,20 @@ def list_datasets(
)
# Send the request.
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# This method is paged; wrap the response in a pager, which provides
# an `__iter__` convenience method.
response = pagers.ListDatasetsPager(
- method=rpc, request=request, response=response, metadata=metadata,
+ method=rpc,
+ request=request,
+ response=response,
+ metadata=metadata,
)
# Done; return the response.
@@ -747,17 +967,40 @@ def list_datasets(
def delete_dataset(
self,
- request: dataset_service.DeleteDatasetRequest = None,
+ request: Union[dataset_service.DeleteDatasetRequest, dict] = None,
*,
name: str = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> gac_operation.Operation:
r"""Deletes a Dataset.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ def sample_delete_dataset():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.DeleteDatasetRequest(
+ name="name_value",
+ )
+
+ # Make the request
+ operation = client.delete_dataset(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = operation.result()
+
+ # Handle the response
+ print(response)
+
Args:
- request (google.cloud.aiplatform_v1.types.DeleteDatasetRequest):
+ request (Union[google.cloud.aiplatform_v1.types.DeleteDatasetRequest, dict]):
The request object. Request message for
[DatasetService.DeleteDataset][google.cloud.aiplatform.v1.DatasetService.DeleteDataset].
name (str):
@@ -794,7 +1037,7 @@ def delete_dataset(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([name])
if request is not None and has_flattened_params:
@@ -825,7 +1068,12 @@ def delete_dataset(
)
# Send the request.
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Wrap the response in an operation future.
response = gac_operation.from_gapic(
@@ -840,18 +1088,46 @@ def delete_dataset(
def import_data(
self,
- request: dataset_service.ImportDataRequest = None,
+ request: Union[dataset_service.ImportDataRequest, dict] = None,
*,
name: str = None,
import_configs: Sequence[dataset.ImportDataConfig] = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> gac_operation.Operation:
r"""Imports data into a Dataset.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ def sample_import_data():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceClient()
+
+ # Initialize request argument(s)
+ import_configs = aiplatform_v1.ImportDataConfig()
+ import_configs.gcs_source.uris = ['uris_value_1', 'uris_value_2']
+ import_configs.import_schema_uri = "import_schema_uri_value"
+
+ request = aiplatform_v1.ImportDataRequest(
+ name="name_value",
+ import_configs=import_configs,
+ )
+
+ # Make the request
+ operation = client.import_data(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = operation.result()
+
+ # Handle the response
+ print(response)
+
Args:
- request (google.cloud.aiplatform_v1.types.ImportDataRequest):
+ request (Union[google.cloud.aiplatform_v1.types.ImportDataRequest, dict]):
The request object. Request message for
[DatasetService.ImportData][google.cloud.aiplatform.v1.DatasetService.ImportData].
name (str):
@@ -886,7 +1162,7 @@ def import_data(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([name, import_configs])
if request is not None and has_flattened_params:
@@ -919,7 +1195,12 @@ def import_data(
)
# Send the request.
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Wrap the response in an operation future.
response = gac_operation.from_gapic(
@@ -934,18 +1215,45 @@ def import_data(
def export_data(
self,
- request: dataset_service.ExportDataRequest = None,
+ request: Union[dataset_service.ExportDataRequest, dict] = None,
*,
name: str = None,
export_config: dataset.ExportDataConfig = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> gac_operation.Operation:
r"""Exports data from a Dataset.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ def sample_export_data():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceClient()
+
+ # Initialize request argument(s)
+ export_config = aiplatform_v1.ExportDataConfig()
+ export_config.gcs_destination.output_uri_prefix = "output_uri_prefix_value"
+
+ request = aiplatform_v1.ExportDataRequest(
+ name="name_value",
+ export_config=export_config,
+ )
+
+ # Make the request
+ operation = client.export_data(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = operation.result()
+
+ # Handle the response
+ print(response)
+
Args:
- request (google.cloud.aiplatform_v1.types.ExportDataRequest):
+ request (Union[google.cloud.aiplatform_v1.types.ExportDataRequest, dict]):
The request object. Request message for
[DatasetService.ExportData][google.cloud.aiplatform.v1.DatasetService.ExportData].
name (str):
@@ -979,7 +1287,7 @@ def export_data(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([name, export_config])
if request is not None and has_flattened_params:
@@ -1012,7 +1320,12 @@ def export_data(
)
# Send the request.
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Wrap the response in an operation future.
response = gac_operation.from_gapic(
@@ -1027,17 +1340,37 @@ def export_data(
def list_data_items(
self,
- request: dataset_service.ListDataItemsRequest = None,
+ request: Union[dataset_service.ListDataItemsRequest, dict] = None,
*,
parent: str = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> pagers.ListDataItemsPager:
r"""Lists DataItems in a Dataset.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ def sample_list_data_items():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.ListDataItemsRequest(
+ parent="parent_value",
+ )
+
+ # Make the request
+ page_result = client.list_data_items(request=request)
+
+ # Handle the response
+ for response in page_result:
+ print(response)
+
Args:
- request (google.cloud.aiplatform_v1.types.ListDataItemsRequest):
+ request (Union[google.cloud.aiplatform_v1.types.ListDataItemsRequest, dict]):
The request object. Request message for
[DatasetService.ListDataItems][google.cloud.aiplatform.v1.DatasetService.ListDataItems].
parent (str):
@@ -1064,7 +1397,7 @@ def list_data_items(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([parent])
if request is not None and has_flattened_params:
@@ -1095,12 +1428,20 @@ def list_data_items(
)
# Send the request.
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# This method is paged; wrap the response in a pager, which provides
# an `__iter__` convenience method.
response = pagers.ListDataItemsPager(
- method=rpc, request=request, response=response, metadata=metadata,
+ method=rpc,
+ request=request,
+ response=response,
+ metadata=metadata,
)
# Done; return the response.
@@ -1108,23 +1449,41 @@ def list_data_items(
def get_annotation_spec(
self,
- request: dataset_service.GetAnnotationSpecRequest = None,
+ request: Union[dataset_service.GetAnnotationSpecRequest, dict] = None,
*,
name: str = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> annotation_spec.AnnotationSpec:
r"""Gets an AnnotationSpec.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ def sample_get_annotation_spec():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.GetAnnotationSpecRequest(
+ name="name_value",
+ )
+
+ # Make the request
+ response = client.get_annotation_spec(request=request)
+
+ # Handle the response
+ print(response)
+
Args:
- request (google.cloud.aiplatform_v1.types.GetAnnotationSpecRequest):
+ request (Union[google.cloud.aiplatform_v1.types.GetAnnotationSpecRequest, dict]):
The request object. Request message for
[DatasetService.GetAnnotationSpec][google.cloud.aiplatform.v1.DatasetService.GetAnnotationSpec].
name (str):
Required. The name of the AnnotationSpec resource.
Format:
-
``projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}``
This corresponds to the ``name`` field
@@ -1143,7 +1502,7 @@ def get_annotation_spec(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([name])
if request is not None and has_flattened_params:
@@ -1174,30 +1533,54 @@ def get_annotation_spec(
)
# Send the request.
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Done; return the response.
return response
def list_annotations(
self,
- request: dataset_service.ListAnnotationsRequest = None,
+ request: Union[dataset_service.ListAnnotationsRequest, dict] = None,
*,
parent: str = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> pagers.ListAnnotationsPager:
r"""Lists Annotations belongs to a dataitem
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ def sample_list_annotations():
+ # Create a client
+ client = aiplatform_v1.DatasetServiceClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.ListAnnotationsRequest(
+ parent="parent_value",
+ )
+
+ # Make the request
+ page_result = client.list_annotations(request=request)
+
+ # Handle the response
+ for response in page_result:
+ print(response)
+
Args:
- request (google.cloud.aiplatform_v1.types.ListAnnotationsRequest):
+ request (Union[google.cloud.aiplatform_v1.types.ListAnnotationsRequest, dict]):
The request object. Request message for
[DatasetService.ListAnnotations][google.cloud.aiplatform.v1.DatasetService.ListAnnotations].
parent (str):
Required. The resource name of the DataItem to list
Annotations from. Format:
-
``projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}``
This corresponds to the ``parent`` field
@@ -1219,7 +1602,7 @@ def list_annotations(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([parent])
if request is not None and has_flattened_params:
@@ -1250,12 +1633,704 @@ def list_annotations(
)
# Send the request.
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# This method is paged; wrap the response in a pager, which provides
# an `__iter__` convenience method.
response = pagers.ListAnnotationsPager(
- method=rpc, request=request, response=response, metadata=metadata,
+ method=rpc,
+ request=request,
+ response=response,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ """Releases underlying transport's resources.
+
+ .. warning::
+ ONLY use as a context manager if the transport is NOT shared
+ with other clients! Exiting the with block will CLOSE the transport
+ and may cause errors in other clients!
+ """
+ self.transport.close()
+
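Mirroring the async variant, the sync client now supports `with`, subject to the transport-sharing warning above. A minimal sketch (the parent path is a placeholder):

    from google.cloud import aiplatform_v1

    # __exit__ closes the transport, so this client must not be shared with
    # code that outlives the block.
    with aiplatform_v1.DatasetServiceClient() as client:
        for dataset in client.list_datasets(
            parent="projects/my-project/locations/us-central1"
        ):
            print(dataset.name)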
+ def list_operations(
+ self,
+ request: operations_pb2.ListOperationsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.ListOperationsResponse:
+ r"""Lists operations that match the specified filter in the request.
+
+ Args:
+ request (:class:`~.operations_pb2.ListOperationsRequest`):
+ The request object. Request message for
+ `ListOperations` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.ListOperationsResponse:
+ Response message for ``ListOperations`` method.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.ListOperationsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.list_operations,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def get_operation(
+ self,
+ request: operations_pb2.GetOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.Operation:
+ r"""Gets the latest state of a long-running operation.
+
+ Args:
+ request (:class:`~.operations_pb2.GetOperationRequest`):
+ The request object. Request message for
+ `GetOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.Operation:
+ An ``Operation`` object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.GetOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.get_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def delete_operation(
+ self,
+ request: operations_pb2.DeleteOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> None:
+ r"""Deletes a long-running operation.
+
+ This method indicates that the client is no longer interested
+ in the operation result. It does not cancel the operation.
+ If the server doesn't support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
+ Args:
+ request (:class:`~.operations_pb2.DeleteOperationRequest`):
+ The request object. Request message for
+ `DeleteOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ None
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.DeleteOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.delete_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ def cancel_operation(
+ self,
+ request: operations_pb2.CancelOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> None:
+ r"""Starts asynchronous cancellation on a long-running operation.
+
+ The server makes a best effort to cancel the operation, but success
+ is not guaranteed. If the server doesn't support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
+ Args:
+ request (:class:`~.operations_pb2.CancelOperationRequest`):
+ The request object. Request message for
+ `CancelOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ None
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.CancelOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.cancel_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ def wait_operation(
+ self,
+ request: operations_pb2.WaitOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.Operation:
+ r"""Waits until the specified long-running operation is done or reaches at most
+ a specified timeout, returning the latest state.
+
+ If the operation is already done, the latest state is immediately returned.
+ If the timeout specified is greater than the default HTTP/RPC timeout, the HTTP/RPC
+ timeout is used. If the server does not support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
+ Args:
+ request (:class:`~.operations_pb2.WaitOperationRequest`):
+ The request object. Request message for
+ `WaitOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.Operation:
+ An ``Operation`` object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.WaitOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.wait_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def set_iam_policy(
+ self,
+ request: iam_policy_pb2.SetIamPolicyRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> policy_pb2.Policy:
+ r"""Sets the IAM access control policy on the specified function.
+
+ Replaces any existing policy.
+
+ Args:
+ request (:class:`~.iam_policy_pb2.SetIamPolicyRequest`):
+ The request object. Request message for `SetIamPolicy`
+ method.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.policy_pb2.Policy:
+ Defines an Identity and Access Management (IAM) policy.
+ It is used to specify access control policies for Cloud
+ Platform resources.
+ A ``Policy`` is a collection of ``bindings``. A
+ ``binding`` binds one or more ``members`` to a single
+ ``role``. Members can be user accounts, service
+ accounts, Google groups, and domains (such as G Suite).
+ A ``role`` is a named list of permissions (defined by
+ IAM or configured by users). A ``binding`` can
+ optionally specify a ``condition``, which is a logic
+ expression that further constrains the role binding
+ based on attributes about the request and/or target
+ resource.
+ **JSON Example**
+ ::
+ {
+ "bindings": [
+ {
+ "role": "roles/resourcemanager.organizationAdmin",
+ "members": [
+ "user:mike@example.com",
+ "group:admins@example.com",
+ "domain:google.com",
+ "serviceAccount:my-project-id@appspot.gserviceaccount.com"
+ ]
+ },
+ {
+ "role": "roles/resourcemanager.organizationViewer",
+ "members": ["user:eve@example.com"],
+ "condition": {
+ "title": "expirable access",
+ "description": "Does not grant access after Sep 2020",
+ "expression": "request.time <
+ timestamp('2020-10-01T00:00:00.000Z')",
+ }
+ }
+ ]
+ }
+ **YAML Example**
+ ::
+ bindings:
+ - members:
+ - user:mike@example.com
+ - group:admins@example.com
+ - domain:google.com
+ - serviceAccount:my-project-id@appspot.gserviceaccount.com
+ role: roles/resourcemanager.organizationAdmin
+ - members:
+ - user:eve@example.com
+ role: roles/resourcemanager.organizationViewer
+ condition:
+ title: expirable access
+ description: Does not grant access after Sep 2020
+ expression: request.time < timestamp('2020-10-01T00:00:00.000Z')
+ For a description of IAM and its features, see the `IAM
+ developer's
+ guide <https://cloud.google.com/iam/docs>`__.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.SetIamPolicyRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.set_iam_policy,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
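The binding shape documented in the JSON and YAML examples above maps one-to-one onto `policy_pb2`. A hedged sketch of constructing and sending such a policy (the resource path and principal are placeholders; real Vertex AI IAM resources include, e.g., featurestores):

    from google.cloud import aiplatform_v1
    from google.iam.v1 import iam_policy_pb2, policy_pb2

    client = aiplatform_v1.DatasetServiceClient()

    # One binding ties a role to a member list, as in the examples above.
    policy = policy_pb2.Policy(
        bindings=[
            policy_pb2.Binding(
                role="roles/viewer",
                members=["user:eve@example.com"],
            )
        ]
    )
    response = client.set_iam_policy(
        iam_policy_pb2.SetIamPolicyRequest(
            resource="projects/my-project/locations/us-central1/featurestores/my-fs",
            policy=policy,
        )
    )
    print(response.bindings)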
+ def get_iam_policy(
+ self,
+ request: iam_policy_pb2.GetIamPolicyRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> policy_pb2.Policy:
+ r"""Gets the IAM access control policy for a function.
+
+ Returns an empty policy if the function exists and does not have a
+ policy set.
+
+ Args:
+ request (:class:`~.iam_policy_pb2.GetIamPolicyRequest`):
+ The request object. Request message for `GetIamPolicy`
+ method.
+ retry (google.api_core.retry.Retry): Designation of what errors, if
+ any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.policy_pb2.Policy:
+ Defines an Identity and Access Management (IAM) policy.
+ It is used to specify access control policies for Cloud
+ Platform resources.
+ A ``Policy`` is a collection of ``bindings``. A
+ ``binding`` binds one or more ``members`` to a single
+ ``role``. Members can be user accounts, service
+ accounts, Google groups, and domains (such as G Suite).
+ A ``role`` is a named list of permissions (defined by
+ IAM or configured by users). A ``binding`` can
+ optionally specify a ``condition``, which is a logic
+ expression that further constrains the role binding
+ based on attributes about the request and/or target
+ resource.
+ **JSON Example**
+ ::
+ {
+ "bindings": [
+ {
+ "role": "roles/resourcemanager.organizationAdmin",
+ "members": [
+ "user:mike@example.com",
+ "group:admins@example.com",
+ "domain:google.com",
+ "serviceAccount:my-project-id@appspot.gserviceaccount.com"
+ ]
+ },
+ {
+ "role": "roles/resourcemanager.organizationViewer",
+ "members": ["user:eve@example.com"],
+ "condition": {
+ "title": "expirable access",
+ "description": "Does not grant access after Sep 2020",
+ "expression": "request.time <
+ timestamp('2020-10-01T00:00:00.000Z')",
+ }
+ }
+ ]
+ }
+ **YAML Example**
+ ::
+ bindings:
+ - members:
+ - user:mike@example.com
+ - group:admins@example.com
+ - domain:google.com
+ - serviceAccount:my-project-id@appspot.gserviceaccount.com
+ role: roles/resourcemanager.organizationAdmin
+ - members:
+ - user:eve@example.com
+ role: roles/resourcemanager.organizationViewer
+ condition:
+ title: expirable access
+ description: Does not grant access after Sep 2020
+ expression: request.time < timestamp('2020-10-01T00:00:00.000Z')
+ For a description of IAM and its features, see the `IAM
+ developer's
+ guide <https://cloud.google.com/iam/docs>`__.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.GetIamPolicyRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.get_iam_policy,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def test_iam_permissions(
+ self,
+ request: iam_policy_pb2.TestIamPermissionsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> iam_policy_pb2.TestIamPermissionsResponse:
+ r"""Tests the specified IAM permissions against the IAM access control
+ policy for a function.
+
+ If the function does not exist, this will return an empty set
+ of permissions, not a NOT_FOUND error.
+
+ Args:
+ request (:class:`~.iam_policy_pb2.TestIamPermissionsRequest`):
+ The request object. Request message for
+ `TestIamPermissions` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.iam_policy_pb2.TestIamPermissionsResponse:
+ Response message for ``TestIamPermissions`` method.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.TestIamPermissionsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.test_iam_permissions,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def get_location(
+ self,
+ request: locations_pb2.GetLocationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> locations_pb2.Location:
+ r"""Gets information about a location.
+
+ Args:
+ request (:class:`~.location_pb2.GetLocationRequest`):
+ The request object. Request message for
+ `GetLocation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.location_pb2.Location:
+ Location object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = locations_pb2.GetLocationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.get_location,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def list_locations(
+ self,
+ request: locations_pb2.ListLocationsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> locations_pb2.ListLocationsResponse:
+ r"""Lists information about the supported locations for this service.
+
+ Args:
+ request (:class:`~.location_pb2.ListLocationsRequest`):
+ The request object. Request message for
+ `ListLocations` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.location_pb2.ListLocationsResponse:
+ Response message for ``ListLocations`` method.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = locations_pb2.ListLocationsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.list_locations,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
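A short usage sketch for the two Locations mixin methods added above; the project and location IDs are placeholders:

from google.cloud import aiplatform_v1
from google.cloud.location import locations_pb2

client = aiplatform_v1.DatasetServiceClient()

# Fetch one location; its "name" is also copied into the routing header.
location = client.get_location(
    locations_pb2.GetLocationRequest(name="projects/my-project/locations/us-central1")
)
print(location.display_name)

# List all locations the service supports for a project.
response = client.list_locations(
    locations_pb2.ListLocationsRequest(name="projects/my-project")
)
for loc in response.locations:
    print(loc.name)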
diff --git a/google/cloud/aiplatform_v1/services/dataset_service/pagers.py b/google/cloud/aiplatform_v1/services/dataset_service/pagers.py
index be142bd36e..577519f725 100644
--- a/google/cloud/aiplatform_v1/services/dataset_service/pagers.py
+++ b/google/cloud/aiplatform_v1/services/dataset_service/pagers.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,13 +15,13 @@
#
from typing import (
Any,
- AsyncIterable,
+ AsyncIterator,
Awaitable,
Callable,
- Iterable,
Sequence,
Tuple,
Optional,
+ Iterator,
)
from google.cloud.aiplatform_v1.types import annotation
@@ -77,14 +77,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)
@property
- def pages(self) -> Iterable[dataset_service.ListDatasetsResponse]:
+ def pages(self) -> Iterator[dataset_service.ListDatasetsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = self._method(self._request, metadata=self._metadata)
yield self._response
- def __iter__(self) -> Iterable[dataset.Dataset]:
+ def __iter__(self) -> Iterator[dataset.Dataset]:
for page in self.pages:
yield from page.datasets
@@ -118,7 +118,7 @@ def __init__(
*,
metadata: Sequence[Tuple[str, str]] = ()
):
- """Instantiate the pager.
+ """Instantiates the pager.
Args:
method (Callable): The method that was originally called, and
@@ -139,14 +139,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)
@property
- async def pages(self) -> AsyncIterable[dataset_service.ListDatasetsResponse]:
+ async def pages(self) -> AsyncIterator[dataset_service.ListDatasetsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = await self._method(self._request, metadata=self._metadata)
yield self._response
- def __aiter__(self) -> AsyncIterable[dataset.Dataset]:
+ def __aiter__(self) -> AsyncIterator[dataset.Dataset]:
async def async_generator():
async for page in self.pages:
for response in page.datasets:
@@ -205,14 +205,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)
@property
- def pages(self) -> Iterable[dataset_service.ListDataItemsResponse]:
+ def pages(self) -> Iterator[dataset_service.ListDataItemsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = self._method(self._request, metadata=self._metadata)
yield self._response
- def __iter__(self) -> Iterable[data_item.DataItem]:
+ def __iter__(self) -> Iterator[data_item.DataItem]:
for page in self.pages:
yield from page.data_items
@@ -246,7 +246,7 @@ def __init__(
*,
metadata: Sequence[Tuple[str, str]] = ()
):
- """Instantiate the pager.
+ """Instantiates the pager.
Args:
method (Callable): The method that was originally called, and
@@ -267,14 +267,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)
@property
- async def pages(self) -> AsyncIterable[dataset_service.ListDataItemsResponse]:
+ async def pages(self) -> AsyncIterator[dataset_service.ListDataItemsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = await self._method(self._request, metadata=self._metadata)
yield self._response
- def __aiter__(self) -> AsyncIterable[data_item.DataItem]:
+ def __aiter__(self) -> AsyncIterator[data_item.DataItem]:
async def async_generator():
async for page in self.pages:
for response in page.data_items:
@@ -333,14 +333,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)
@property
- def pages(self) -> Iterable[dataset_service.ListAnnotationsResponse]:
+ def pages(self) -> Iterator[dataset_service.ListAnnotationsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = self._method(self._request, metadata=self._metadata)
yield self._response
- def __iter__(self) -> Iterable[annotation.Annotation]:
+ def __iter__(self) -> Iterator[annotation.Annotation]:
for page in self.pages:
yield from page.annotations
@@ -374,7 +374,7 @@ def __init__(
*,
metadata: Sequence[Tuple[str, str]] = ()
):
- """Instantiate the pager.
+ """Instantiates the pager.
Args:
method (Callable): The method that was originally called, and
@@ -395,14 +395,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)
@property
- async def pages(self) -> AsyncIterable[dataset_service.ListAnnotationsResponse]:
+ async def pages(self) -> AsyncIterator[dataset_service.ListAnnotationsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = await self._method(self._request, metadata=self._metadata)
yield self._response
- def __aiter__(self) -> AsyncIterable[annotation.Annotation]:
+ def __aiter__(self) -> AsyncIterator[annotation.Annotation]:
async def async_generator():
async for page in self.pages:
for response in page.annotations:
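The Iterable-to-Iterator change above is annotation-only; pager behavior is unchanged. A sketch of the two iteration styles the pagers support, with placeholder resource names:

from google.cloud import aiplatform_v1

client = aiplatform_v1.DatasetServiceClient()
parent = "projects/my-project/locations/us-central1"  # placeholder

# Item-level iteration: __iter__ flattens pages into Dataset objects.
for dataset in client.list_datasets(parent=parent):
    print(dataset.display_name)

# Page-level iteration: .pages yields one ListDatasetsResponse per RPC.
for page in client.list_datasets(parent=parent).pages:
    print(len(page.datasets))

# The async pager mirrors this shape via __aiter__ and an async
# `pages` property, consumed with `async for`.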
diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/__init__.py
index 902a4fb01f..07bc11c0c8 100644
--- a/google/cloud/aiplatform_v1/services/dataset_service/transports/__init__.py
+++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py
index c049ed37ba..aa75e138ea 100644
--- a/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py
+++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,21 +15,25 @@
#
import abc
from typing import Awaitable, Callable, Dict, Optional, Sequence, Union
-import packaging.version
import pkg_resources
import google.auth # type: ignore
-import google.api_core # type: ignore
-from google.api_core import exceptions as core_exceptions # type: ignore
-from google.api_core import gapic_v1 # type: ignore
-from google.api_core import retry as retries # type: ignore
-from google.api_core import operations_v1 # type: ignore
+import google.api_core
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core import retry as retries
+from google.api_core import operations_v1
from google.auth import credentials as ga_credentials # type: ignore
+from google.oauth2 import service_account # type: ignore
from google.cloud.aiplatform_v1.types import annotation_spec
from google.cloud.aiplatform_v1.types import dataset
from google.cloud.aiplatform_v1.types import dataset as gca_dataset
from google.cloud.aiplatform_v1.types import dataset_service
+from google.cloud.location import locations_pb2 # type: ignore
+from google.iam.v1 import iam_policy_pb2 # type: ignore
+from google.iam.v1 import policy_pb2 # type: ignore
from google.longrunning import operations_pb2 # type: ignore
try:
@@ -41,17 +45,6 @@
except pkg_resources.DistributionNotFound:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
-try:
- # google.auth.__version__ was added in 1.26.0
- _GOOGLE_AUTH_VERSION = google.auth.__version__
-except AttributeError:
- try: # try pkg_resources if it is available
- _GOOGLE_AUTH_VERSION = pkg_resources.get_distribution("google-auth").version
- except pkg_resources.DistributionNotFound: # pragma: NO COVER
- _GOOGLE_AUTH_VERSION = None
-
-_API_CORE_VERSION = google.api_core.__version__
-
class DatasetServiceTransport(abc.ABC):
"""Abstract transport class for DatasetService."""
@@ -69,6 +62,7 @@ def __init__(
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ always_use_jwt_access: Optional[bool] = False,
**kwargs,
) -> None:
"""Instantiate the transport.
@@ -92,16 +86,19 @@ def __init__(
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+ be used for service account credentials.
"""
+
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
if ":" not in host:
host += ":443"
self._host = host
- scopes_kwargs = self._get_scopes_kwargs(self._host, scopes)
+ scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES}
# Save the scopes.
- self._scopes = scopes or self.AUTH_SCOPES
+ self._scopes = scopes
# If no credentials are provided, then determine the appropriate
# defaults.
@@ -114,97 +111,88 @@ def __init__(
credentials, _ = google.auth.load_credentials_from_file(
credentials_file, **scopes_kwargs, quota_project_id=quota_project_id
)
-
elif credentials is None:
credentials, _ = google.auth.default(
**scopes_kwargs, quota_project_id=quota_project_id
)
- # Save the credentials.
- self._credentials = credentials
-
- # TODO(busunkim): These two class methods are in the base transport
- # to avoid duplicating code across the transport classes. These functions
- # should be deleted once the minimum required versions of google-api-core
- # and google-auth are increased.
-
- # TODO: Remove this function once google-auth >= 1.25.0 is required
- @classmethod
- def _get_scopes_kwargs(
- cls, host: str, scopes: Optional[Sequence[str]]
- ) -> Dict[str, Optional[Sequence[str]]]:
- """Returns scopes kwargs to pass to google-auth methods depending on the google-auth version"""
-
- scopes_kwargs = {}
-
- if _GOOGLE_AUTH_VERSION and (
- packaging.version.parse(_GOOGLE_AUTH_VERSION)
- >= packaging.version.parse("1.25.0")
+ # If the credentials are service account credentials, then always try to use self signed JWT.
+ if (
+ always_use_jwt_access
+ and isinstance(credentials, service_account.Credentials)
+ and hasattr(service_account.Credentials, "with_always_use_jwt_access")
):
- scopes_kwargs = {"scopes": scopes, "default_scopes": cls.AUTH_SCOPES}
- else:
- scopes_kwargs = {"scopes": scopes or cls.AUTH_SCOPES}
-
- return scopes_kwargs
+ credentials = credentials.with_always_use_jwt_access(True)
- # TODO: Remove this function once google-api-core >= 1.26.0 is required
- @classmethod
- def _get_self_signed_jwt_kwargs(
- cls, host: str, scopes: Optional[Sequence[str]]
- ) -> Dict[str, Union[Optional[Sequence[str]], str]]:
- """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version"""
-
- self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {}
-
- if _API_CORE_VERSION and (
- packaging.version.parse(_API_CORE_VERSION)
- >= packaging.version.parse("1.26.0")
- ):
- self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES
- self_signed_jwt_kwargs["scopes"] = scopes
- self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST
- else:
- self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES
-
- return self_signed_jwt_kwargs
+ # Save the credentials.
+ self._credentials = credentials
def _prep_wrapped_messages(self, client_info):
# Precompute the wrapped methods.
self._wrapped_methods = {
self.create_dataset: gapic_v1.method.wrap_method(
- self.create_dataset, default_timeout=5.0, client_info=client_info,
+ self.create_dataset,
+ default_timeout=None,
+ client_info=client_info,
),
self.get_dataset: gapic_v1.method.wrap_method(
- self.get_dataset, default_timeout=5.0, client_info=client_info,
+ self.get_dataset,
+ default_timeout=None,
+ client_info=client_info,
),
self.update_dataset: gapic_v1.method.wrap_method(
- self.update_dataset, default_timeout=5.0, client_info=client_info,
+ self.update_dataset,
+ default_timeout=None,
+ client_info=client_info,
),
self.list_datasets: gapic_v1.method.wrap_method(
- self.list_datasets, default_timeout=5.0, client_info=client_info,
+ self.list_datasets,
+ default_timeout=None,
+ client_info=client_info,
),
self.delete_dataset: gapic_v1.method.wrap_method(
- self.delete_dataset, default_timeout=5.0, client_info=client_info,
+ self.delete_dataset,
+ default_timeout=None,
+ client_info=client_info,
),
self.import_data: gapic_v1.method.wrap_method(
- self.import_data, default_timeout=5.0, client_info=client_info,
+ self.import_data,
+ default_timeout=None,
+ client_info=client_info,
),
self.export_data: gapic_v1.method.wrap_method(
- self.export_data, default_timeout=5.0, client_info=client_info,
+ self.export_data,
+ default_timeout=None,
+ client_info=client_info,
),
self.list_data_items: gapic_v1.method.wrap_method(
- self.list_data_items, default_timeout=5.0, client_info=client_info,
+ self.list_data_items,
+ default_timeout=None,
+ client_info=client_info,
),
self.get_annotation_spec: gapic_v1.method.wrap_method(
- self.get_annotation_spec, default_timeout=5.0, client_info=client_info,
+ self.get_annotation_spec,
+ default_timeout=None,
+ client_info=client_info,
),
self.list_annotations: gapic_v1.method.wrap_method(
- self.list_annotations, default_timeout=5.0, client_info=client_info,
+ self.list_annotations,
+ default_timeout=None,
+ client_info=client_info,
),
}
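With default_timeout now None, these wrapped methods no longer carry a blanket five-second client-side deadline; callers opt in per call instead. A sketch, with an illustrative resource name and retry policy:

from google.api_core import retry as retries
from google.cloud import aiplatform_v1

client = aiplatform_v1.DatasetServiceClient()

# No built-in deadline, so pass an explicit timeout/retry if wanted.
dataset = client.get_dataset(
    name="projects/my-project/locations/us-central1/datasets/123",
    timeout=30.0,
    retry=retries.Retry(initial=0.1, maximum=10.0, deadline=60.0),
)
print(dataset.display_name)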
+ def close(self):
+ """Closes resources associated with the transport.
+
+ .. warning::
+ Only call this method if the transport is NOT shared
+ with other clients - this may cause errors in other clients!
+ """
+ raise NotImplementedError()
+
@property
- def operations_client(self) -> operations_v1.OperationsClient:
+ def operations_client(self):
"""Return the client designed to process long-running operations."""
raise NotImplementedError()
@@ -309,5 +297,102 @@ def list_annotations(
]:
raise NotImplementedError()
+ @property
+ def list_operations(
+ self,
+ ) -> Callable[
+ [operations_pb2.ListOperationsRequest],
+ Union[
+ operations_pb2.ListOperationsResponse,
+ Awaitable[operations_pb2.ListOperationsResponse],
+ ],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def get_operation(
+ self,
+ ) -> Callable[
+ [operations_pb2.GetOperationRequest],
+ Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def cancel_operation(
+ self,
+ ) -> Callable[[operations_pb2.CancelOperationRequest], None,]:
+ raise NotImplementedError()
+
+ @property
+ def delete_operation(
+ self,
+ ) -> Callable[[operations_pb2.DeleteOperationRequest], None,]:
+ raise NotImplementedError()
+
+ @property
+ def wait_operation(
+ self,
+ ) -> Callable[
+ [operations_pb2.WaitOperationRequest],
+ Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def set_iam_policy(
+ self,
+ ) -> Callable[
+ [iam_policy_pb2.SetIamPolicyRequest],
+ Union[policy_pb2.Policy, Awaitable[policy_pb2.Policy]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def get_iam_policy(
+ self,
+ ) -> Callable[
+ [iam_policy_pb2.GetIamPolicyRequest],
+ Union[policy_pb2.Policy, Awaitable[policy_pb2.Policy]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def test_iam_permissions(
+ self,
+ ) -> Callable[
+ [iam_policy_pb2.TestIamPermissionsRequest],
+ Union[
+ iam_policy_pb2.TestIamPermissionsResponse,
+ Awaitable[iam_policy_pb2.TestIamPermissionsResponse],
+ ],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def get_location(
+ self,
+ ) -> Callable[
+ [locations_pb2.GetLocationRequest],
+ Union[locations_pb2.Location, Awaitable[locations_pb2.Location]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def list_locations(
+ self,
+ ) -> Callable[
+ [locations_pb2.ListLocationsRequest],
+ Union[
+ locations_pb2.ListLocationsResponse,
+ Awaitable[locations_pb2.ListLocationsResponse],
+ ],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def kind(self) -> str:
+ raise NotImplementedError()
+
__all__ = ("DatasetServiceTransport",)
diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py
index d38f841665..137b823324 100644
--- a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py
+++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,9 +16,9 @@
import warnings
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
-from google.api_core import grpc_helpers # type: ignore
-from google.api_core import operations_v1 # type: ignore
-from google.api_core import gapic_v1 # type: ignore
+from google.api_core import grpc_helpers
+from google.api_core import operations_v1
+from google.api_core import gapic_v1
import google.auth # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
@@ -29,6 +29,10 @@
from google.cloud.aiplatform_v1.types import dataset
from google.cloud.aiplatform_v1.types import dataset as gca_dataset
from google.cloud.aiplatform_v1.types import dataset_service
+from google.cloud.location import locations_pb2 # type: ignore
+from google.iam.v1 import iam_policy_pb2 # type: ignore
+from google.iam.v1 import policy_pb2 # type: ignore
from google.longrunning import operations_pb2 # type: ignore
from .base import DatasetServiceTransport, DEFAULT_CLIENT_INFO
@@ -36,6 +40,9 @@
class DatasetServiceGrpcTransport(DatasetServiceTransport):
"""gRPC backend transport for DatasetService.
+ The service that handles the CRUD of Vertex AI Dataset and
+ its child resources.
+
This class defines the same methods as the primary client, so the
primary client can load the underlying transport implementation
and call it.
@@ -60,6 +67,7 @@ def __init__(
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ always_use_jwt_access: Optional[bool] = False,
) -> None:
"""Instantiate the transport.
@@ -82,16 +90,16 @@ def __init__(
api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
If provided, it overrides the ``host`` argument and tries to create
a mutual TLS channel with client SSL credentials from
- ``client_cert_source`` or applicatin default SSL credentials.
+ ``client_cert_source`` or application default SSL credentials.
client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
Deprecated. A callback to provide client SSL certificate bytes and
private key bytes, both in PEM format. It is ignored if
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
- for grpc channel. It is ignored if ``channel`` is provided.
+ for the grpc channel. It is ignored if ``channel`` is provided.
client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
A callback to provide client certificate bytes and private key bytes,
- both in PEM format. It is used to configure mutual TLS channel. It is
+ both in PEM format. It is used to configure a mutual TLS channel. It is
ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
@@ -100,6 +108,8 @@ def __init__(
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+ be used for service account credentials.
Raises:
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
@@ -110,7 +120,7 @@ def __init__(
self._grpc_channel = None
self._ssl_channel_credentials = ssl_channel_credentials
self._stubs: Dict[str, Callable] = {}
- self._operations_client = None
+ self._operations_client: Optional[operations_v1.OperationsClient] = None
if api_mtls_endpoint:
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
@@ -153,13 +163,17 @@ def __init__(
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
+ always_use_jwt_access=always_use_jwt_access,
)
if not self._grpc_channel:
self._grpc_channel = type(self).create_channel(
self._host,
+ # use the credentials which are saved
credentials=self._credentials,
- credentials_file=credentials_file,
+ # Set ``credentials_file`` to ``None`` here as
+ # the credentials that we saved earlier should be used.
+ credentials_file=None,
scopes=self._scopes,
ssl_credentials=self._ssl_channel_credentials,
quota_project_id=quota_project_id,
@@ -208,21 +222,20 @@ def create_channel(
and ``credentials_file`` are passed.
"""
- self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes)
-
return grpc_helpers.create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
quota_project_id=quota_project_id,
- **self_signed_jwt_kwargs,
+ default_scopes=cls.AUTH_SCOPES,
+ scopes=scopes,
+ default_host=cls.DEFAULT_HOST,
**kwargs,
)
@property
def grpc_channel(self) -> grpc.Channel:
- """Return the channel designed to connect to this service.
- """
+ """Return the channel designed to connect to this service."""
return self._grpc_channel
@property
@@ -232,7 +245,7 @@ def operations_client(self) -> operations_v1.OperationsClient:
This property caches on the instance; repeated calls return the same
client.
"""
- # Sanity check: Only create a new client if we do not already have one.
+ # Quick check: Only create a new client if we do not already have one.
if self._operations_client is None:
self._operations_client = operations_v1.OperationsClient(self.grpc_channel)
@@ -508,5 +521,215 @@ def list_annotations(
)
return self._stubs["list_annotations"]
+ def close(self):
+ self.grpc_channel.close()
+
+ @property
+ def delete_operation(
+ self,
+ ) -> Callable[[operations_pb2.DeleteOperationRequest], None]:
+ r"""Return a callable for the delete_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "delete_operation" not in self._stubs:
+ self._stubs["delete_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/DeleteOperation",
+ request_serializer=operations_pb2.DeleteOperationRequest.SerializeToString,
+ response_deserializer=None,
+ )
+ return self._stubs["delete_operation"]
+
+ @property
+ def cancel_operation(
+ self,
+ ) -> Callable[[operations_pb2.CancelOperationRequest], None]:
+ r"""Return a callable for the cancel_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "cancel_operation" not in self._stubs:
+ self._stubs["cancel_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/CancelOperation",
+ request_serializer=operations_pb2.CancelOperationRequest.SerializeToString,
+ response_deserializer=None,
+ )
+ return self._stubs["cancel_operation"]
+
+ @property
+ def wait_operation(
+ self,
+ ) -> Callable[[operations_pb2.WaitOperationRequest], None]:
+ r"""Return a callable for the wait_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "delete_operation" not in self._stubs:
+ self._stubs["wait_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/WaitOperation",
+ request_serializer=operations_pb2.WaitOperationRequest.SerializeToString,
+ response_deserializer=None,
+ )
+ return self._stubs["wait_operation"]
+
+ @property
+ def get_operation(
+ self,
+ ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]:
+ r"""Return a callable for the get_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_operation" not in self._stubs:
+ self._stubs["get_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/GetOperation",
+ request_serializer=operations_pb2.GetOperationRequest.SerializeToString,
+ response_deserializer=operations_pb2.Operation.FromString,
+ )
+ return self._stubs["get_operation"]
+
+ @property
+ def list_operations(
+ self,
+ ) -> Callable[
+ [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse
+ ]:
+ r"""Return a callable for the list_operations method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "list_operations" not in self._stubs:
+ self._stubs["list_operations"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/ListOperations",
+ request_serializer=operations_pb2.ListOperationsRequest.SerializeToString,
+ response_deserializer=operations_pb2.ListOperationsResponse.FromString,
+ )
+ return self._stubs["list_operations"]
+
+ @property
+ def list_locations(
+ self,
+ ) -> Callable[
+ [locations_pb2.ListLocationsRequest], locations_pb2.ListLocationsResponse
+ ]:
+ r"""Return a callable for the list locations method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "list_locations" not in self._stubs:
+ self._stubs["list_locations"] = self.grpc_channel.unary_unary(
+ "/google.cloud.location.Locations/ListLocations",
+ request_serializer=locations_pb2.ListLocationsRequest.SerializeToString,
+ response_deserializer=locations_pb2.ListLocationsResponse.FromString,
+ )
+ return self._stubs["list_locations"]
+
+ @property
+ def get_location(
+ self,
+ ) -> Callable[[locations_pb2.GetLocationRequest], locations_pb2.Location]:
+ r"""Return a callable for the list locations method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_location" not in self._stubs:
+ self._stubs["get_location"] = self.grpc_channel.unary_unary(
+ "/google.cloud.location.Locations/GetLocation",
+ request_serializer=locations_pb2.GetLocationRequest.SerializeToString,
+ response_deserializer=locations_pb2.Location.FromString,
+ )
+ return self._stubs["get_location"]
+
+ @property
+ def set_iam_policy(
+ self,
+ ) -> Callable[[iam_policy_pb2.SetIamPolicyRequest], policy_pb2.Policy]:
+ r"""Return a callable for the set iam policy method over gRPC.
+ Sets the IAM access control policy on the specified
+ function. Replaces any existing policy.
+ Returns:
+ Callable[[~.SetIamPolicyRequest],
+ ~.Policy]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "set_iam_policy" not in self._stubs:
+ self._stubs["set_iam_policy"] = self.grpc_channel.unary_unary(
+ "/google.iam.v1.IAMPolicy/SetIamPolicy",
+ request_serializer=iam_policy_pb2.SetIamPolicyRequest.SerializeToString,
+ response_deserializer=policy_pb2.Policy.FromString,
+ )
+ return self._stubs["set_iam_policy"]
+
+ @property
+ def get_iam_policy(
+ self,
+ ) -> Callable[[iam_policy_pb2.GetIamPolicyRequest], policy_pb2.Policy]:
+ r"""Return a callable for the get iam policy method over gRPC.
+ Gets the IAM access control policy for a function.
+ Returns an empty policy if the function exists and does
+ not have a policy set.
+ Returns:
+ Callable[[~.GetIamPolicyRequest],
+ ~.Policy]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_iam_policy" not in self._stubs:
+ self._stubs["get_iam_policy"] = self.grpc_channel.unary_unary(
+ "/google.iam.v1.IAMPolicy/GetIamPolicy",
+ request_serializer=iam_policy_pb2.GetIamPolicyRequest.SerializeToString,
+ response_deserializer=policy_pb2.Policy.FromString,
+ )
+ return self._stubs["get_iam_policy"]
+
+ @property
+ def test_iam_permissions(
+ self,
+ ) -> Callable[
+ [iam_policy_pb2.TestIamPermissionsRequest],
+ iam_policy_pb2.TestIamPermissionsResponse,
+ ]:
+ r"""Return a callable for the test iam permissions method over gRPC.
+ Tests the specified permissions against the IAM access control
+ policy for a function. If the function does not exist, this will
+ return an empty set of permissions, not a NOT_FOUND error.
+ Returns:
+ Callable[[~.TestIamPermissionsRequest],
+ ~.TestIamPermissionsResponse]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "test_iam_permissions" not in self._stubs:
+ self._stubs["test_iam_permissions"] = self.grpc_channel.unary_unary(
+ "/google.iam.v1.IAMPolicy/TestIamPermissions",
+ request_serializer=iam_policy_pb2.TestIamPermissionsRequest.SerializeToString,
+ response_deserializer=iam_policy_pb2.TestIamPermissionsResponse.FromString,
+ )
+ return self._stubs["test_iam_permissions"]
+
+ @property
+ def kind(self) -> str:
+ return "grpc"
+
__all__ = ("DatasetServiceGrpcTransport",)
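Every operations/locations/IAM property above follows the same memoized stub pattern. A stripped-down sketch of that shape (the class and names here are illustrative, not part of the library):

from google.longrunning import operations_pb2

class _StubCacheSketch:
    """Illustrative shape of the lazy stub pattern; not library code."""

    def __init__(self, channel):
        self.grpc_channel = channel
        self._stubs = {}

    @property
    def get_operation(self):
        # Build the stub on first access; later accesses reuse it, so
        # each RPC path is registered on the channel exactly once.
        if "get_operation" not in self._stubs:
            self._stubs["get_operation"] = self.grpc_channel.unary_unary(
                "/google.longrunning.Operations/GetOperation",
                request_serializer=operations_pb2.GetOperationRequest.SerializeToString,
                response_deserializer=operations_pb2.Operation.FromString,
            )
        return self._stubs["get_operation"]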
diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py
index dad35d6eca..f3ace0efc9 100644
--- a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py
+++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,12 +16,11 @@
import warnings
from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union
-from google.api_core import gapic_v1 # type: ignore
-from google.api_core import grpc_helpers_async # type: ignore
-from google.api_core import operations_v1 # type: ignore
+from google.api_core import gapic_v1
+from google.api_core import grpc_helpers_async
+from google.api_core import operations_v1
from google.auth import credentials as ga_credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
-import packaging.version
import grpc # type: ignore
from grpc.experimental import aio # type: ignore
@@ -30,6 +29,10 @@
from google.cloud.aiplatform_v1.types import dataset
from google.cloud.aiplatform_v1.types import dataset as gca_dataset
from google.cloud.aiplatform_v1.types import dataset_service
+from google.cloud.location import locations_pb2 # type: ignore
+from google.iam.v1 import iam_policy_pb2 # type: ignore
+from google.iam.v1 import policy_pb2 # type: ignore
from google.longrunning import operations_pb2 # type: ignore
from .base import DatasetServiceTransport, DEFAULT_CLIENT_INFO
from .grpc import DatasetServiceGrpcTransport
@@ -38,6 +41,9 @@
class DatasetServiceGrpcAsyncIOTransport(DatasetServiceTransport):
"""gRPC AsyncIO backend transport for DatasetService.
+ The service that handles the CRUD of Vertex AI Dataset and
+ its child resources.
+
This class defines the same methods as the primary client, so the
primary client can load the underlying transport implementation
and call it.
@@ -81,14 +87,14 @@ def create_channel(
aio.Channel: A gRPC AsyncIO channel object.
"""
- self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes)
-
return grpc_helpers_async.create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
quota_project_id=quota_project_id,
- **self_signed_jwt_kwargs,
+ default_scopes=cls.AUTH_SCOPES,
+ scopes=scopes,
+ default_host=cls.DEFAULT_HOST,
**kwargs,
)
@@ -106,6 +112,7 @@ def __init__(
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id=None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ always_use_jwt_access: Optional[bool] = False,
) -> None:
"""Instantiate the transport.
@@ -129,16 +136,16 @@ def __init__(
api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
If provided, it overrides the ``host`` argument and tries to create
a mutual TLS channel with client SSL credentials from
- ``client_cert_source`` or applicatin default SSL credentials.
+ ``client_cert_source`` or application default SSL credentials.
client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
Deprecated. A callback to provide client SSL certificate bytes and
private key bytes, both in PEM format. It is ignored if
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
- for grpc channel. It is ignored if ``channel`` is provided.
+ for the grpc channel. It is ignored if ``channel`` is provided.
client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
A callback to provide client certificate bytes and private key bytes,
- both in PEM format. It is used to configure mutual TLS channel. It is
+ both in PEM format. It is used to configure a mutual TLS channel. It is
ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
@@ -147,6 +154,8 @@ def __init__(
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+ be used for service account credentials.
Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
@@ -157,7 +166,7 @@ def __init__(
self._grpc_channel = None
self._ssl_channel_credentials = ssl_channel_credentials
self._stubs: Dict[str, Callable] = {}
- self._operations_client = None
+ self._operations_client: Optional[operations_v1.OperationsAsyncClient] = None
if api_mtls_endpoint:
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
@@ -199,13 +208,17 @@ def __init__(
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
+ always_use_jwt_access=always_use_jwt_access,
)
if not self._grpc_channel:
self._grpc_channel = type(self).create_channel(
self._host,
+ # use the credentials which are saved
credentials=self._credentials,
- credentials_file=credentials_file,
+ # Set ``credentials_file`` to ``None`` here as
+ # the credentials that we saved earlier should be used.
+ credentials_file=None,
scopes=self._scopes,
ssl_credentials=self._ssl_channel_credentials,
quota_project_id=quota_project_id,
@@ -235,7 +248,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient:
This property caches on the instance; repeated calls return the same
client.
"""
- # Sanity check: Only create a new client if we do not already have one.
+ # Quick check: Only create a new client if we do not already have one.
if self._operations_client is None:
self._operations_client = operations_v1.OperationsAsyncClient(
self.grpc_channel
@@ -526,5 +539,211 @@ def list_annotations(
)
return self._stubs["list_annotations"]
+ def close(self):
+ return self.grpc_channel.close()
+
+ @property
+ def delete_operation(
+ self,
+ ) -> Callable[[operations_pb2.DeleteOperationRequest], None]:
+ r"""Return a callable for the delete_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "delete_operation" not in self._stubs:
+ self._stubs["delete_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/DeleteOperation",
+ request_serializer=operations_pb2.DeleteOperationRequest.SerializeToString,
+ response_deserializer=None,
+ )
+ return self._stubs["delete_operation"]
+
+ @property
+ def cancel_operation(
+ self,
+ ) -> Callable[[operations_pb2.CancelOperationRequest], None]:
+ r"""Return a callable for the cancel_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "cancel_operation" not in self._stubs:
+ self._stubs["cancel_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/CancelOperation",
+ request_serializer=operations_pb2.CancelOperationRequest.SerializeToString,
+ response_deserializer=None,
+ )
+ return self._stubs["cancel_operation"]
+
+ @property
+ def wait_operation(
+ self,
+ ) -> Callable[[operations_pb2.WaitOperationRequest], None]:
+ r"""Return a callable for the wait_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "delete_operation" not in self._stubs:
+ self._stubs["wait_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/WaitOperation",
+ request_serializer=operations_pb2.WaitOperationRequest.SerializeToString,
+ response_deserializer=None,
+ )
+ return self._stubs["wait_operation"]
+
+ @property
+ def get_operation(
+ self,
+ ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]:
+ r"""Return a callable for the get_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_operation" not in self._stubs:
+ self._stubs["get_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/GetOperation",
+ request_serializer=operations_pb2.GetOperationRequest.SerializeToString,
+ response_deserializer=operations_pb2.Operation.FromString,
+ )
+ return self._stubs["get_operation"]
+
+ @property
+ def list_operations(
+ self,
+ ) -> Callable[
+ [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse
+ ]:
+ r"""Return a callable for the list_operations method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "list_operations" not in self._stubs:
+ self._stubs["list_operations"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/ListOperations",
+ request_serializer=operations_pb2.ListOperationsRequest.SerializeToString,
+ response_deserializer=operations_pb2.ListOperationsResponse.FromString,
+ )
+ return self._stubs["list_operations"]
+
+ @property
+ def list_locations(
+ self,
+ ) -> Callable[
+ [locations_pb2.ListLocationsRequest], locations_pb2.ListLocationsResponse
+ ]:
+ r"""Return a callable for the list locations method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "list_locations" not in self._stubs:
+ self._stubs["list_locations"] = self.grpc_channel.unary_unary(
+ "/google.cloud.location.Locations/ListLocations",
+ request_serializer=locations_pb2.ListLocationsRequest.SerializeToString,
+ response_deserializer=locations_pb2.ListLocationsResponse.FromString,
+ )
+ return self._stubs["list_locations"]
+
+ @property
+ def get_location(
+ self,
+ ) -> Callable[[locations_pb2.GetLocationRequest], locations_pb2.Location]:
+ r"""Return a callable for the list locations method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_location" not in self._stubs:
+ self._stubs["get_location"] = self.grpc_channel.unary_unary(
+ "/google.cloud.location.Locations/GetLocation",
+ request_serializer=locations_pb2.GetLocationRequest.SerializeToString,
+ response_deserializer=locations_pb2.Location.FromString,
+ )
+ return self._stubs["get_location"]
+
+ @property
+ def set_iam_policy(
+ self,
+ ) -> Callable[[iam_policy_pb2.SetIamPolicyRequest], policy_pb2.Policy]:
+ r"""Return a callable for the set iam policy method over gRPC.
+ Sets the IAM access control policy on the specified
+ function. Replaces any existing policy.
+ Returns:
+ Callable[[~.SetIamPolicyRequest],
+ ~.Policy]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "set_iam_policy" not in self._stubs:
+ self._stubs["set_iam_policy"] = self.grpc_channel.unary_unary(
+ "/google.iam.v1.IAMPolicy/SetIamPolicy",
+ request_serializer=iam_policy_pb2.SetIamPolicyRequest.SerializeToString,
+ response_deserializer=policy_pb2.Policy.FromString,
+ )
+ return self._stubs["set_iam_policy"]
+
+ @property
+ def get_iam_policy(
+ self,
+ ) -> Callable[[iam_policy_pb2.GetIamPolicyRequest], policy_pb2.Policy]:
+ r"""Return a callable for the get iam policy method over gRPC.
+ Gets the IAM access control policy for a function.
+ Returns an empty policy if the function exists and does
+ not have a policy set.
+ Returns:
+ Callable[[~.GetIamPolicyRequest],
+ ~.Policy]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_iam_policy" not in self._stubs:
+ self._stubs["get_iam_policy"] = self.grpc_channel.unary_unary(
+ "/google.iam.v1.IAMPolicy/GetIamPolicy",
+ request_serializer=iam_policy_pb2.GetIamPolicyRequest.SerializeToString,
+ response_deserializer=policy_pb2.Policy.FromString,
+ )
+ return self._stubs["get_iam_policy"]
+
+ @property
+ def test_iam_permissions(
+ self,
+ ) -> Callable[
+ [iam_policy_pb2.TestIamPermissionsRequest],
+ iam_policy_pb2.TestIamPermissionsResponse,
+ ]:
+ r"""Return a callable for the test iam permissions method over gRPC.
+ Tests the specified permissions against the IAM access control
+ policy for a function. If the function does not exist, this will
+ return an empty set of permissions, not a NOT_FOUND error.
+ Returns:
+ Callable[[~.TestIamPermissionsRequest],
+ ~.TestIamPermissionsResponse]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "test_iam_permissions" not in self._stubs:
+ self._stubs["test_iam_permissions"] = self.grpc_channel.unary_unary(
+ "/google.iam.v1.IAMPolicy/TestIamPermissions",
+ request_serializer=iam_policy_pb2.TestIamPermissionsRequest.SerializeToString,
+ response_deserializer=iam_policy_pb2.TestIamPermissionsResponse.FromString,
+ )
+ return self._stubs["test_iam_permissions"]
+
__all__ = ("DatasetServiceGrpcAsyncIOTransport",)
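The sync transport's close() closes the channel directly, while the asyncio variant returns an awaitable. A sketch of both call sites, assuming the close() methods behave as added in this diff:

import asyncio
from google.cloud import aiplatform_v1

def sync_usage():
    client = aiplatform_v1.DatasetServiceClient()
    try:
        pass  # ... issue RPCs ...
    finally:
        # Closes the channel directly. Per the warning in base.py, only
        # do this if the transport is not shared with other clients.
        client.transport.close()

async def async_usage():
    client = aiplatform_v1.DatasetServiceAsyncClient()
    try:
        pass  # ... issue RPCs ...
    finally:
        # The asyncio close() returns an awaitable, so await it.
        await client.transport.close()

sync_usage()
asyncio.run(async_usage())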
diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/__init__.py b/google/cloud/aiplatform_v1/services/endpoint_service/__init__.py
index 96fb4ad6d6..3c37159f9d 100644
--- a/google/cloud/aiplatform_v1/services/endpoint_service/__init__.py
+++ b/google/cloud/aiplatform_v1/services/endpoint_service/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py
index 544a7788df..87a4c32dee 100644
--- a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py
+++ b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,16 +16,21 @@
from collections import OrderedDict
import functools
import re
-from typing import Dict, Sequence, Tuple, Type, Union
+from typing import Dict, Mapping, Optional, Sequence, Tuple, Type, Union
import pkg_resources
-import google.api_core.client_options as ClientOptions # type: ignore
-from google.api_core import exceptions as core_exceptions # type: ignore
-from google.api_core import gapic_v1 # type: ignore
-from google.api_core import retry as retries # type: ignore
+from google.api_core.client_options import ClientOptions
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core import retry as retries
from google.auth import credentials as ga_credentials # type: ignore
from google.oauth2 import service_account # type: ignore
+try:
+ OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
+except AttributeError: # pragma: NO COVER
+ OptionalRetry = Union[retries.Retry, object] # type: ignore
+
from google.api_core import operation as gac_operation # type: ignore
from google.api_core import operation_async # type: ignore
from google.cloud.aiplatform_v1.services.endpoint_service import pagers
@@ -34,6 +39,10 @@
from google.cloud.aiplatform_v1.types import endpoint as gca_endpoint
from google.cloud.aiplatform_v1.types import endpoint_service
from google.cloud.aiplatform_v1.types import operation as gca_operation
+from google.cloud.location import locations_pb2 # type: ignore
+from google.iam.v1 import iam_policy_pb2 # type: ignore
+from google.iam.v1 import policy_pb2 # type: ignore
+from google.longrunning import operations_pb2
from google.protobuf import empty_pb2 # type: ignore
from google.protobuf import field_mask_pb2 # type: ignore
from google.protobuf import timestamp_pb2 # type: ignore
@@ -43,7 +52,7 @@
class EndpointServiceAsyncClient:
- """"""
+ """A service for managing Vertex AI's Endpoints."""
_client: EndpointServiceClient
@@ -54,6 +63,14 @@ class EndpointServiceAsyncClient:
parse_endpoint_path = staticmethod(EndpointServiceClient.parse_endpoint_path)
model_path = staticmethod(EndpointServiceClient.model_path)
parse_model_path = staticmethod(EndpointServiceClient.parse_model_path)
+ model_deployment_monitoring_job_path = staticmethod(
+ EndpointServiceClient.model_deployment_monitoring_job_path
+ )
+ parse_model_deployment_monitoring_job_path = staticmethod(
+ EndpointServiceClient.parse_model_deployment_monitoring_job_path
+ )
+ network_path = staticmethod(EndpointServiceClient.network_path)
+ parse_network_path = staticmethod(EndpointServiceClient.parse_network_path)
common_billing_account_path = staticmethod(
EndpointServiceClient.common_billing_account_path
)
@@ -81,7 +98,8 @@ class EndpointServiceAsyncClient:
@classmethod
def from_service_account_info(cls, info: dict, *args, **kwargs):
- """Creates an instance of this client using the provided credentials info.
+ """Creates an instance of this client using the provided credentials
+ info.
Args:
info (dict): The service account private key info.
@@ -96,7 +114,7 @@ def from_service_account_info(cls, info: dict, *args, **kwargs):
@classmethod
def from_service_account_file(cls, filename: str, *args, **kwargs):
"""Creates an instance of this client using the provided credentials
- file.
+ file.
Args:
filename (str): The path to the service account private key json
@@ -111,9 +129,45 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
from_service_account_json = from_service_account_file
+ @classmethod
+ def get_mtls_endpoint_and_cert_source(
+ cls, client_options: Optional[ClientOptions] = None
+ ):
+ """Return the API endpoint and client cert source for mutual TLS.
+
+ The client cert source is determined in the following order:
+ (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
+ client cert source is None.
+ (2) if `client_options.client_cert_source` is provided, use the provided one; if the
+ default client cert source exists, use the default one; otherwise the client cert
+ source is None.
+
+ The API endpoint is determined in the following order:
+ (1) if `client_options.api_endpoint` is provided, use the provided one.
+ (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the
+ default mTLS endpoint; if the environment variable is "never", use the default API
+ endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
+ use the default API endpoint.
+
+ More details can be found at https://google.aip.dev/auth/4114.
+
+ Args:
+ client_options (google.api_core.client_options.ClientOptions): Custom options for the
+ client. Only the `api_endpoint` and `client_cert_source` properties may be used
+ in this method.
+
+ Returns:
+ Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
+ client cert source to use.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If any errors happen.
+ """
+ return EndpointServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore
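A sketch of calling the helper above; GOOGLE_API_USE_CLIENT_CERTIFICATE is the only environment toggle shown, and whether a cert source is returned depends on whether a default client certificate exists on the machine:

import os
from google.api_core.client_options import ClientOptions
from google.cloud import aiplatform_v1

# Opt in to client certificates; without this the cert source is None.
os.environ["GOOGLE_API_USE_CLIENT_CERTIFICATE"] = "true"

options = ClientOptions()  # no explicit api_endpoint or cert source
endpoint, cert_source = (
    aiplatform_v1.EndpointServiceAsyncClient.get_mtls_endpoint_and_cert_source(
        options
    )
)
print(endpoint)                 # mTLS endpoint when a default cert exists
print(cert_source is not None)  # whether a client cert source was found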
+
@property
def transport(self) -> EndpointServiceTransport:
- """Return the transport used by the client instance.
+ """Returns the transport used by the client instance.
Returns:
EndpointServiceTransport: The transport used by the client instance.
@@ -132,7 +186,7 @@ def __init__(
client_options: ClientOptions = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
- """Instantiate the endpoint service client.
+ """Instantiates the endpoint service client.
Args:
credentials (Optional[google.auth.credentials.Credentials]): The
@@ -173,18 +227,46 @@ def __init__(
async def create_endpoint(
self,
- request: endpoint_service.CreateEndpointRequest = None,
+ request: Union[endpoint_service.CreateEndpointRequest, dict] = None,
*,
parent: str = None,
endpoint: gca_endpoint.Endpoint = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ endpoint_id: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> operation_async.AsyncOperation:
r"""Creates an Endpoint.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_create_endpoint():
+ # Create a client
+ client = aiplatform_v1.EndpointServiceAsyncClient()
+
+ # Initialize request argument(s)
+ endpoint = aiplatform_v1.Endpoint()
+ endpoint.display_name = "display_name_value"
+
+ request = aiplatform_v1.CreateEndpointRequest(
+ parent="parent_value",
+ endpoint=endpoint,
+ )
+
+ # Make the request
+ operation = client.create_endpoint(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = await operation.result()
+
+ # Handle the response
+ print(response)
+
Args:
- request (:class:`google.cloud.aiplatform_v1.types.CreateEndpointRequest`):
+ request (Union[google.cloud.aiplatform_v1.types.CreateEndpointRequest, dict]):
The request object. Request message for
[EndpointService.CreateEndpoint][google.cloud.aiplatform.v1.EndpointService.CreateEndpoint].
parent (:class:`str`):
@@ -200,6 +282,21 @@ async def create_endpoint(
This corresponds to the ``endpoint`` field
on the ``request`` instance; if ``request`` is provided, this
should not be set.
+ endpoint_id (:class:`str`):
+ Immutable. The ID to use for endpoint, which will become
+ the final component of the endpoint resource name. If
+ not provided, Vertex AI will generate a value for this
+ ID.
+
+ This value should be 1-10 characters, and valid
+ characters are /[0-9]/. When using HTTP/JSON, this field
+ is populated based on a query string argument, such as
+ ``?endpoint_id=12345``. This is the fallback for fields
+ that are not included in either the URI or the body.
+
+ This corresponds to the ``endpoint_id`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
@@ -215,9 +312,9 @@ async def create_endpoint(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
- has_flattened_params = any([parent, endpoint])
+ has_flattened_params = any([parent, endpoint, endpoint_id])
if request is not None and has_flattened_params:
raise ValueError(
"If the `request` argument is set, then none of "
@@ -232,12 +329,14 @@ async def create_endpoint(
request.parent = parent
if endpoint is not None:
request.endpoint = endpoint
+ if endpoint_id is not None:
+ request.endpoint_id = endpoint_id
# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.create_endpoint,
- default_timeout=5.0,
+ default_timeout=None,
client_info=DEFAULT_CLIENT_INFO,
)
@@ -248,7 +347,12 @@ async def create_endpoint(
)
# Send the request.
- response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Wrap the response in an operation future.
response = operation_async.from_gapic(
@@ -263,17 +367,36 @@ async def create_endpoint(
async def get_endpoint(
self,
- request: endpoint_service.GetEndpointRequest = None,
+ request: Union[endpoint_service.GetEndpointRequest, dict] = None,
*,
name: str = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> endpoint.Endpoint:
r"""Gets an Endpoint.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_get_endpoint():
+ # Create a client
+ client = aiplatform_v1.EndpointServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.GetEndpointRequest(
+ name="name_value",
+ )
+
+ # Make the request
+ response = await client.get_endpoint(request=request)
+
+ # Handle the response
+ print(response)
+
Args:
- request (:class:`google.cloud.aiplatform_v1.types.GetEndpointRequest`):
+ request (Union[google.cloud.aiplatform_v1.types.GetEndpointRequest, dict]):
The request object. Request message for
[EndpointService.GetEndpoint][google.cloud.aiplatform.v1.EndpointService.GetEndpoint]
name (:class:`str`):
@@ -297,7 +420,7 @@ async def get_endpoint(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([name])
if request is not None and has_flattened_params:
@@ -317,7 +440,7 @@ async def get_endpoint(
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.get_endpoint,
- default_timeout=5.0,
+ default_timeout=None,
client_info=DEFAULT_CLIENT_INFO,
)
@@ -328,24 +451,49 @@ async def get_endpoint(
)
# Send the request.
- response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Done; return the response.
return response
async def list_endpoints(
self,
- request: endpoint_service.ListEndpointsRequest = None,
+ request: Union[endpoint_service.ListEndpointsRequest, dict] = None,
*,
parent: str = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> pagers.ListEndpointsAsyncPager:
r"""Lists Endpoints in a Location.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_list_endpoints():
+ # Create a client
+ client = aiplatform_v1.EndpointServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.ListEndpointsRequest(
+ parent="parent_value",
+ )
+
+ # Make the request
+ page_result = await client.list_endpoints(request=request)
+
+ # Handle the response
+ async for response in page_result:
+ print(response)
+
Args:
- request (:class:`google.cloud.aiplatform_v1.types.ListEndpointsRequest`):
+ request (Union[google.cloud.aiplatform_v1.types.ListEndpointsRequest, dict]):
The request object. Request message for
[EndpointService.ListEndpoints][google.cloud.aiplatform.v1.EndpointService.ListEndpoints].
parent (:class:`str`):
@@ -372,7 +520,7 @@ async def list_endpoints(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([parent])
if request is not None and has_flattened_params:
@@ -392,7 +540,7 @@ async def list_endpoints(
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.list_endpoints,
- default_timeout=5.0,
+ default_timeout=None,
client_info=DEFAULT_CLIENT_INFO,
)
@@ -403,12 +551,20 @@ async def list_endpoints(
)
# Send the request.
- response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# This method is paged; wrap the response in a pager, which provides
# an `__aiter__` convenience method.
response = pagers.ListEndpointsAsyncPager(
- method=rpc, request=request, response=response, metadata=metadata,
+ method=rpc,
+ request=request,
+ response=response,
+ metadata=metadata,
)
# Done; return the response.
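Besides the flattened item iteration shown in the sample above, the async pager also exposes whole pages; a short sketch, assuming an existing client and parent resource name (both hypothetical here):

.. code-block:: python

    async def list_by_page(client, parent):
        pager = await client.list_endpoints(parent=parent)

        # `pages` yields raw ListEndpointsResponse messages rather than
        # flattened Endpoint items.
        async for page in pager.pages:
            for endpoint in page.endpoints:
                print(endpoint.name)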
@@ -416,18 +572,40 @@ async def list_endpoints(
async def update_endpoint(
self,
- request: endpoint_service.UpdateEndpointRequest = None,
+ request: Union[endpoint_service.UpdateEndpointRequest, dict] = None,
*,
endpoint: gca_endpoint.Endpoint = None,
update_mask: field_mask_pb2.FieldMask = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> gca_endpoint.Endpoint:
r"""Updates an Endpoint.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_update_endpoint():
+ # Create a client
+ client = aiplatform_v1.EndpointServiceAsyncClient()
+
+ # Initialize request argument(s)
+ endpoint = aiplatform_v1.Endpoint()
+ endpoint.display_name = "display_name_value"
+
+ request = aiplatform_v1.UpdateEndpointRequest(
+ endpoint=endpoint,
+ )
+
+ # Make the request
+ response = await client.update_endpoint(request=request)
+
+ # Handle the response
+ print(response)
+
Args:
- request (:class:`google.cloud.aiplatform_v1.types.UpdateEndpointRequest`):
+ request (Union[google.cloud.aiplatform_v1.types.UpdateEndpointRequest, dict]):
The request object. Request message for
[EndpointService.UpdateEndpoint][google.cloud.aiplatform.v1.EndpointService.UpdateEndpoint].
endpoint (:class:`google.cloud.aiplatform_v1.types.Endpoint`):
@@ -439,7 +617,7 @@ async def update_endpoint(
should not be set.
update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`):
Required. The update mask applies to the resource. See
- `FieldMask <https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask>`__.
+ [google.protobuf.FieldMask][google.protobuf.FieldMask].
This corresponds to the ``update_mask`` field
on the ``request`` instance; if ``request`` is provided, this
@@ -458,7 +636,7 @@ async def update_endpoint(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([endpoint, update_mask])
if request is not None and has_flattened_params:
@@ -480,7 +658,7 @@ async def update_endpoint(
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.update_endpoint,
- default_timeout=5.0,
+ default_timeout=None,
client_info=DEFAULT_CLIENT_INFO,
)
@@ -493,24 +671,52 @@ async def update_endpoint(
)
# Send the request.
- response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Done; return the response.
return response
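Since ``update_mask`` is required when calling with flattened arguments, here is a minimal sketch of a partial update (the endpoint name and new display name are placeholders):

.. code-block:: python

    from google.cloud import aiplatform_v1
    from google.protobuf import field_mask_pb2

    async def rename_endpoint(client, endpoint_name):
        endpoint = aiplatform_v1.Endpoint(
            name=endpoint_name,
            display_name="renamed-endpoint",
        )
        # Only fields listed in the mask are written; everything else on
        # the Endpoint message is left untouched by the server.
        return await client.update_endpoint(
            endpoint=endpoint,
            update_mask=field_mask_pb2.FieldMask(paths=["display_name"]),
        )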
async def delete_endpoint(
self,
- request: endpoint_service.DeleteEndpointRequest = None,
+ request: Union[endpoint_service.DeleteEndpointRequest, dict] = None,
*,
name: str = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> operation_async.AsyncOperation:
r"""Deletes an Endpoint.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_delete_endpoint():
+ # Create a client
+ client = aiplatform_v1.EndpointServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.DeleteEndpointRequest(
+ name="name_value",
+ )
+
+ # Make the request
+ operation = await client.delete_endpoint(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = await operation.result()
+
+ # Handle the response
+ print(response)
+
Args:
- request (:class:`google.cloud.aiplatform_v1.types.DeleteEndpointRequest`):
+ request (Union[google.cloud.aiplatform_v1.types.DeleteEndpointRequest, dict]):
The request object. Request message for
[EndpointService.DeleteEndpoint][google.cloud.aiplatform.v1.EndpointService.DeleteEndpoint].
name (:class:`str`):
@@ -547,7 +753,7 @@ async def delete_endpoint(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([name])
if request is not None and has_flattened_params:
@@ -567,7 +773,7 @@ async def delete_endpoint(
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.delete_endpoint,
- default_timeout=5.0,
+ default_timeout=None,
client_info=DEFAULT_CLIENT_INFO,
)
@@ -578,7 +784,12 @@ async def delete_endpoint(
)
# Send the request.
- response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Wrap the response in an operation future.
response = operation_async.from_gapic(
@@ -593,22 +804,48 @@ async def delete_endpoint(
async def deploy_model(
self,
- request: endpoint_service.DeployModelRequest = None,
+ request: Union[endpoint_service.DeployModelRequest, dict] = None,
*,
endpoint: str = None,
deployed_model: gca_endpoint.DeployedModel = None,
- traffic_split: Sequence[
- endpoint_service.DeployModelRequest.TrafficSplitEntry
- ] = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ traffic_split: Mapping[str, int] = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> operation_async.AsyncOperation:
r"""Deploys a Model into this Endpoint, creating a
DeployedModel within it.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_deploy_model():
+ # Create a client
+ client = aiplatform_v1.EndpointServiceAsyncClient()
+
+ # Initialize request argument(s)
+ deployed_model = aiplatform_v1.DeployedModel()
+ deployed_model.dedicated_resources.min_replica_count = 1803
+ deployed_model.model = "model_value"
+
+ request = aiplatform_v1.DeployModelRequest(
+ endpoint="endpoint_value",
+ deployed_model=deployed_model,
+ )
+
+ # Make the request
+ operation = await client.deploy_model(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = await operation.result()
+
+ # Handle the response
+ print(response)
+
Args:
- request (:class:`google.cloud.aiplatform_v1.types.DeployModelRequest`):
+ request (Union[google.cloud.aiplatform_v1.types.DeployModelRequest, dict]):
The request object. Request message for
[EndpointService.DeployModel][google.cloud.aiplatform.v1.EndpointService.DeployModel].
endpoint (:class:`str`):
@@ -630,7 +867,7 @@ async def deploy_model(
This corresponds to the ``deployed_model`` field
on the ``request`` instance; if ``request`` is provided, this
should not be set.
- traffic_split (:class:`Sequence[google.cloud.aiplatform_v1.types.DeployModelRequest.TrafficSplitEntry]`):
+ traffic_split (:class:`Mapping[str, int]`):
A map from a DeployedModel's ID to the percentage of
this Endpoint's traffic that should be forwarded to that
DeployedModel.
@@ -667,7 +904,7 @@ async def deploy_model(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([endpoint, deployed_model, traffic_split])
if request is not None and has_flattened_params:
@@ -692,7 +929,7 @@ async def deploy_model(
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.deploy_model,
- default_timeout=5.0,
+ default_timeout=None,
client_info=DEFAULT_CLIENT_INFO,
)
@@ -703,7 +940,12 @@ async def deploy_model(
)
# Send the request.
- response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Wrap the response in an operation future.
response = operation_async.from_gapic(
@@ -718,14 +960,12 @@ async def deploy_model(
async def undeploy_model(
self,
- request: endpoint_service.UndeployModelRequest = None,
+ request: Union[endpoint_service.UndeployModelRequest, dict] = None,
*,
endpoint: str = None,
deployed_model_id: str = None,
- traffic_split: Sequence[
- endpoint_service.UndeployModelRequest.TrafficSplitEntry
- ] = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ traffic_split: Mapping[str, int] = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> operation_async.AsyncOperation:
@@ -733,8 +973,32 @@ async def undeploy_model(
DeployedModel from it, and freeing all resources it's
using.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_undeploy_model():
+ # Create a client
+ client = aiplatform_v1.EndpointServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.UndeployModelRequest(
+ endpoint="endpoint_value",
+ deployed_model_id="deployed_model_id_value",
+ )
+
+ # Make the request
+ operation = await client.undeploy_model(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = await operation.result()
+
+ # Handle the response
+ print(response)
+
Args:
- request (:class:`google.cloud.aiplatform_v1.types.UndeployModelRequest`):
+ request (Union[google.cloud.aiplatform_v1.types.UndeployModelRequest, dict]):
The request object. Request message for
[EndpointService.UndeployModel][google.cloud.aiplatform.v1.EndpointService.UndeployModel].
endpoint (:class:`str`):
@@ -752,7 +1016,7 @@ async def undeploy_model(
This corresponds to the ``deployed_model_id`` field
on the ``request`` instance; if ``request`` is provided, this
should not be set.
- traffic_split (:class:`Sequence[google.cloud.aiplatform_v1.types.UndeployModelRequest.TrafficSplitEntry]`):
+ traffic_split (:class:`Mapping[str, int]`):
If this field is provided, then the Endpoint's
[traffic_split][google.cloud.aiplatform.v1.Endpoint.traffic_split]
will be overwritten with it. If last DeployedModel is
@@ -783,7 +1047,7 @@ async def undeploy_model(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([endpoint, deployed_model_id, traffic_split])
if request is not None and has_flattened_params:
@@ -808,7 +1072,7 @@ async def undeploy_model(
# and friendly error handling.
rpc = gapic_v1.method_async.wrap_method(
self._client._transport.undeploy_model,
- default_timeout=5.0,
+ default_timeout=None,
client_info=DEFAULT_CLIENT_INFO,
)
@@ -819,7 +1083,12 @@ async def undeploy_model(
)
# Send the request.
- response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Wrap the response in an operation future.
response = operation_async.from_gapic(
@@ -832,6 +1101,683 @@ async def undeploy_model(
# Done; return the response.
return response
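With ``traffic_split`` now typed as a plain ``Mapping[str, int]``, a dict can be passed directly; a sketch under the convention that the key ``"0"`` refers to the model being deployed in this request (resource names are placeholders):

.. code-block:: python

    from google.cloud import aiplatform_v1

    async def deploy_all_traffic(client, endpoint_name, model_name):
        deployed_model = aiplatform_v1.DeployedModel(
            model=model_name,
            dedicated_resources=aiplatform_v1.DedicatedResources(
                machine_spec=aiplatform_v1.MachineSpec(
                    machine_type="n1-standard-4"
                ),
                min_replica_count=1,
            ),
        )
        # Values are percentages and must sum to 100 across the map.
        operation = await client.deploy_model(
            endpoint=endpoint_name,
            deployed_model=deployed_model,
            traffic_split={"0": 100},
        )
        response = await operation.result()
        print(response.deployed_model.id)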
+ async def list_operations(
+ self,
+ request: operations_pb2.ListOperationsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.ListOperationsResponse:
+ r"""Lists operations that match the specified filter in the request.
+
+ Args:
+ request (:class:`~.operations_pb2.ListOperationsRequest`):
+ The request object. Request message for
+ `ListOperations` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.ListOperationsResponse:
+ Response message for ``ListOperations`` method.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.ListOperationsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.list_operations,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def get_operation(
+ self,
+ request: operations_pb2.GetOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.Operation:
+ r"""Gets the latest state of a long-running operation.
+
+ Args:
+ request (:class:`~.operations_pb2.GetOperationRequest`):
+ The request object. Request message for
+ `GetOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.Operation:
+ An ``Operation`` object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.GetOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.get_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def delete_operation(
+ self,
+ request: operations_pb2.DeleteOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> None:
+ r"""Deletes a long-running operation.
+
+ This method indicates that the client is no longer interested
+ in the operation result. It does not cancel the operation.
+ If the server doesn't support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
+ Args:
+ request (:class:`~.operations_pb2.DeleteOperationRequest`):
+ The request object. Request message for
+ `DeleteOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ None
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.DeleteOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.delete_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ async def cancel_operation(
+ self,
+ request: operations_pb2.CancelOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> None:
+ r"""Starts asynchronous cancellation on a long-running operation.
+
+ The server makes a best effort to cancel the operation, but success
+ is not guaranteed. If the server doesn't support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
+ Args:
+ request (:class:`~.operations_pb2.CancelOperationRequest`):
+ The request object. Request message for
+ `CancelOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ None
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.CancelOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.cancel_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ async def wait_operation(
+ self,
+ request: operations_pb2.WaitOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.Operation:
+ r"""Waits until the specified long-running operation is done or reaches at most
+ a specified timeout, returning the latest state.
+
+ If the operation is already done, the latest state is immediately returned.
+ If the timeout specified is greater than the default HTTP/RPC timeout, the HTTP/RPC
+ timeout is used. If the server does not support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
+ Args:
+ request (:class:`~.operations_pb2.WaitOperationRequest`):
+ The request object. Request message for
+ `WaitOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.Operation:
+ An ``Operation`` object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.WaitOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.wait_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
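As the keyword-expansion branch above shows, these operations mixins accept either the raw ``operations_pb2`` message or a plain dict; a hedged sketch that lists, inspects, and best-effort-cancels operations under a location (the parent name is a placeholder):

.. code-block:: python

    async def cancel_running_operations(client, location):
        # A dict is coerced to operations_pb2.ListOperationsRequest.
        response = await client.list_operations({"name": location})

        for op in response.operations:
            latest = await client.get_operation({"name": op.name})
            if not latest.done:
                # Cancellation is best-effort; the server may still
                # finish the operation.
                await client.cancel_operation({"name": op.name})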
+ async def set_iam_policy(
+ self,
+ request: iam_policy_pb2.SetIamPolicyRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> policy_pb2.Policy:
+ r"""Sets the IAM access control policy on the specified function.
+
+ Replaces any existing policy.
+
+ Args:
+ request (:class:`~.iam_policy_pb2.SetIamPolicyRequest`):
+ The request object. Request message for `SetIamPolicy`
+ method.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.policy_pb2.Policy:
+ Defines an Identity and Access Management (IAM) policy.
+ It is used to specify access control policies for Cloud
+ Platform resources.
+ A ``Policy`` is a collection of ``bindings``. A
+ ``binding`` binds one or more ``members`` to a single
+ ``role``. Members can be user accounts, service
+ accounts, Google groups, and domains (such as G Suite).
+ A ``role`` is a named list of permissions (defined by
+ IAM or configured by users). A ``binding`` can
+ optionally specify a ``condition``, which is a logic
+ expression that further constrains the role binding
+ based on attributes about the request and/or target
+ resource.
+ **JSON Example**
+ ::
+ {
+ "bindings": [
+ {
+ "role": "roles/resourcemanager.organizationAdmin",
+ "members": [
+ "user:mike@example.com",
+ "group:admins@example.com",
+ "domain:google.com",
+ "serviceAccount:my-project-id@appspot.gserviceaccount.com"
+ ]
+ },
+ {
+ "role": "roles/resourcemanager.organizationViewer",
+ "members": ["user:eve@example.com"],
+ "condition": {
+ "title": "expirable access",
+ "description": "Does not grant access after Sep 2020",
+ "expression": "request.time <
+ timestamp('2020-10-01T00:00:00.000Z')",
+ }
+ }
+ ]
+ }
+ **YAML Example**
+ ::
+ bindings:
+ - members:
+ - user:mike@example.com
+ - group:admins@example.com
+ - domain:google.com
+ - serviceAccount:my-project-id@appspot.gserviceaccount.com
+ role: roles/resourcemanager.organizationAdmin
+ - members:
+ - user:eve@example.com
+ role: roles/resourcemanager.organizationViewer
+ condition:
+ title: expirable access
+ description: Does not grant access after Sep 2020
+ expression: request.time < timestamp('2020-10-01T00:00:00.000Z')
+ For a description of IAM and its features, see the `IAM
+ developer's
+ guide <https://cloud.google.com/iam/docs>`__.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.SetIamPolicyRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.set_iam_policy,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
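A typical use of these IAM mixins is a read-modify-write cycle; a minimal sketch that ignores concurrency concerns (``resource`` is a full resource name and ``member`` something like ``user:alice@example.com``, both placeholders):

.. code-block:: python

    async def grant_viewer(client, resource, member):
        policy = await client.get_iam_policy({"resource": resource})
        policy.bindings.add(role="roles/viewer", members=[member])
        # Passing the Policy message inside the dict works because the
        # dict is expanded into SetIamPolicyRequest keyword arguments.
        return await client.set_iam_policy(
            {"resource": resource, "policy": policy}
        )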
+ async def get_iam_policy(
+ self,
+ request: iam_policy_pb2.GetIamPolicyRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> policy_pb2.Policy:
+ r"""Gets the IAM access control policy for a function.
+
+ Returns an empty policy if the function exists and does not have a
+ policy set.
+
+ Args:
+ request (:class:`~.iam_policy_pb2.GetIamPolicyRequest`):
+ The request object. Request message for `GetIamPolicy`
+ method.
+ retry (google.api_core.retry.Retry): Designation of what errors, if
+ any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.policy_pb2.Policy:
+ Defines an Identity and Access Management (IAM) policy.
+ It is used to specify access control policies for Cloud
+ Platform resources.
+ A ``Policy`` is a collection of ``bindings``. A
+ ``binding`` binds one or more ``members`` to a single
+ ``role``. Members can be user accounts, service
+ accounts, Google groups, and domains (such as G Suite).
+ A ``role`` is a named list of permissions (defined by
+ IAM or configured by users). A ``binding`` can
+ optionally specify a ``condition``, which is a logic
+ expression that further constrains the role binding
+ based on attributes about the request and/or target
+ resource.
+ **JSON Example**
+ ::
+ {
+ "bindings": [
+ {
+ "role": "roles/resourcemanager.organizationAdmin",
+ "members": [
+ "user:mike@example.com",
+ "group:admins@example.com",
+ "domain:google.com",
+ "serviceAccount:my-project-id@appspot.gserviceaccount.com"
+ ]
+ },
+ {
+ "role": "roles/resourcemanager.organizationViewer",
+ "members": ["user:eve@example.com"],
+ "condition": {
+ "title": "expirable access",
+ "description": "Does not grant access after Sep 2020",
+ "expression": "request.time <
+ timestamp('2020-10-01T00:00:00.000Z')",
+ }
+ }
+ ]
+ }
+ **YAML Example**
+ ::
+ bindings:
+ - members:
+ - user:mike@example.com
+ - group:admins@example.com
+ - domain:google.com
+ - serviceAccount:my-project-id@appspot.gserviceaccount.com
+ role: roles/resourcemanager.organizationAdmin
+ - members:
+ - user:eve@example.com
+ role: roles/resourcemanager.organizationViewer
+ condition:
+ title: expirable access
+ description: Does not grant access after Sep 2020
+ expression: request.time < timestamp('2020-10-01T00:00:00.000Z')
+ For a description of IAM and its features, see the `IAM
+ developer's
+ guide <https://cloud.google.com/iam/docs>`__.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.GetIamPolicyRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.get_iam_policy,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def test_iam_permissions(
+ self,
+ request: iam_policy_pb2.TestIamPermissionsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> iam_policy_pb2.TestIamPermissionsResponse:
+ r"""Tests the specified IAM permissions against the IAM access control
+ policy for a function.
+
+ If the function does not exist, this will return an empty set
+ of permissions, not a NOT_FOUND error.
+
+ Args:
+ request (:class:`~.iam_policy_pb2.TestIamPermissionsRequest`):
+ The request object. Request message for
+ `TestIamPermissions` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.iam_policy_pb2.TestIamPermissionsResponse:
+ Response message for ``TestIamPermissions`` method.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.TestIamPermissionsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.test_iam_permissions,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def get_location(
+ self,
+ request: locations_pb2.GetLocationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> locations_pb2.Location:
+ r"""Gets information about a location.
+
+ Args:
+ request (:class:`~.location_pb2.GetLocationRequest`):
+ The request object. Request message for
+ `GetLocation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.location_pb2.Location:
+ Location object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = locations_pb2.GetLocationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.get_location,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def list_locations(
+ self,
+ request: locations_pb2.ListLocationsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> locations_pb2.ListLocationsResponse:
+ r"""Lists information about the supported locations for this service.
+
+ Args:
+ request (:class:`~.location_pb2.ListLocationsRequest`):
+ The request object. Request message for
+ `ListLocations` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.location_pb2.ListLocationsResponse:
+ Response message for ``ListLocations`` method.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = locations_pb2.ListLocationsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.list_locations,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
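A quick sketch of the locations mixins (the project name is a placeholder):

.. code-block:: python

    async def show_locations(client, project):
        response = await client.list_locations({"name": project})
        for location in response.locations:
            print(location.location_id, location.display_name)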
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ await self.transport.close()
+
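The new ``__aenter__``/``__aexit__`` pair makes the async client usable as an async context manager that closes its transport on exit; a minimal sketch (the endpoint name is a placeholder):

.. code-block:: python

    import asyncio

    from google.cloud import aiplatform_v1

    async def main():
        # The transport (and its gRPC channel) is closed when the
        # block exits.
        async with aiplatform_v1.EndpointServiceAsyncClient() as client:
            endpoint = await client.get_endpoint(
                name="projects/my-project/locations/us-central1/endpoints/123"
            )
            print(endpoint.display_name)

    asyncio.run(main())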
try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1/services/endpoint_service/client.py
index 8bc3a8026f..b55399fdf8 100644
--- a/google/cloud/aiplatform_v1/services/endpoint_service/client.py
+++ b/google/cloud/aiplatform_v1/services/endpoint_service/client.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,22 +14,26 @@
# limitations under the License.
#
from collections import OrderedDict
-from distutils import util
import os
import re
-from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union
+from typing import Dict, Mapping, Optional, Sequence, Tuple, Type, Union
import pkg_resources
-from google.api_core import client_options as client_options_lib # type: ignore
-from google.api_core import exceptions as core_exceptions # type: ignore
-from google.api_core import gapic_v1 # type: ignore
-from google.api_core import retry as retries # type: ignore
+from google.api_core import client_options as client_options_lib
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core import retry as retries
from google.auth import credentials as ga_credentials # type: ignore
from google.auth.transport import mtls # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
from google.auth.exceptions import MutualTLSChannelError # type: ignore
from google.oauth2 import service_account # type: ignore
+try:
+ OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
+except AttributeError: # pragma: NO COVER
+ OptionalRetry = Union[retries.Retry, object] # type: ignore
+
from google.api_core import operation as gac_operation # type: ignore
from google.api_core import operation_async # type: ignore
from google.cloud.aiplatform_v1.services.endpoint_service import pagers
@@ -38,6 +42,10 @@
from google.cloud.aiplatform_v1.types import endpoint as gca_endpoint
from google.cloud.aiplatform_v1.types import endpoint_service
from google.cloud.aiplatform_v1.types import operation as gca_operation
+from google.cloud.location import locations_pb2 # type: ignore
+from google.iam.v1 import iam_policy_pb2 # type: ignore
+from google.iam.v1 import policy_pb2 # type: ignore
+from google.longrunning import operations_pb2
from google.protobuf import empty_pb2 # type: ignore
from google.protobuf import field_mask_pb2 # type: ignore
from google.protobuf import timestamp_pb2 # type: ignore
@@ -60,8 +68,11 @@ class EndpointServiceClientMeta(type):
_transport_registry["grpc"] = EndpointServiceGrpcTransport
_transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport
- def get_transport_class(cls, label: str = None,) -> Type[EndpointServiceTransport]:
- """Return an appropriate transport class.
+ def get_transport_class(
+ cls,
+ label: str = None,
+ ) -> Type[EndpointServiceTransport]:
+ """Returns an appropriate transport class.
Args:
label: The name of the desired transport. If none is
@@ -80,11 +91,12 @@ def get_transport_class(cls, label: str = None,) -> Type[EndpointServiceTranspor
class EndpointServiceClient(metaclass=EndpointServiceClientMeta):
- """"""
+ """A service for managing Vertex AI's Endpoints."""
@staticmethod
def _get_default_mtls_endpoint(api_endpoint):
- """Convert api endpoint to mTLS endpoint.
+ """Converts api endpoint to mTLS endpoint.
+
Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
"*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
Args:
@@ -118,7 +130,8 @@ def _get_default_mtls_endpoint(api_endpoint):
@classmethod
def from_service_account_info(cls, info: dict, *args, **kwargs):
- """Creates an instance of this client using the provided credentials info.
+ """Creates an instance of this client using the provided credentials
+ info.
Args:
info (dict): The service account private key info.
@@ -135,7 +148,7 @@ def from_service_account_info(cls, info: dict, *args, **kwargs):
@classmethod
def from_service_account_file(cls, filename: str, *args, **kwargs):
"""Creates an instance of this client using the provided credentials
- file.
+ file.
Args:
filename (str): The path to the service account private key json
@@ -154,23 +167,30 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
@property
def transport(self) -> EndpointServiceTransport:
- """Return the transport used by the client instance.
+ """Returns the transport used by the client instance.
Returns:
- EndpointServiceTransport: The transport used by the client instance.
+ EndpointServiceTransport: The transport used by the client
+ instance.
"""
return self._transport
@staticmethod
- def endpoint_path(project: str, location: str, endpoint: str,) -> str:
- """Return a fully-qualified endpoint string."""
+ def endpoint_path(
+ project: str,
+ location: str,
+ endpoint: str,
+ ) -> str:
+ """Returns a fully-qualified endpoint string."""
return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(
- project=project, location=location, endpoint=endpoint,
+ project=project,
+ location=location,
+ endpoint=endpoint,
)
@staticmethod
def parse_endpoint_path(path: str) -> Dict[str, str]:
- """Parse a endpoint path into its component segments."""
+ """Parses a endpoint path into its component segments."""
m = re.match(
r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$",
path,
@@ -178,15 +198,21 @@ def parse_endpoint_path(path: str) -> Dict[str, str]:
return m.groupdict() if m else {}
@staticmethod
- def model_path(project: str, location: str, model: str,) -> str:
- """Return a fully-qualified model string."""
+ def model_path(
+ project: str,
+ location: str,
+ model: str,
+ ) -> str:
+ """Returns a fully-qualified model string."""
return "projects/{project}/locations/{location}/models/{model}".format(
- project=project, location=location, model=model,
+ project=project,
+ location=location,
+ model=model,
)
@staticmethod
def parse_model_path(path: str) -> Dict[str, str]:
- """Parse a model path into its component segments."""
+ """Parses a model path into its component segments."""
m = re.match(
r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$",
path,
@@ -194,8 +220,51 @@ def parse_model_path(path: str) -> Dict[str, str]:
return m.groupdict() if m else {}
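These static helpers round-trip between resource-name strings and their component parts; a quick sketch:

.. code-block:: python

    from google.cloud import aiplatform_v1

    path = aiplatform_v1.EndpointServiceClient.endpoint_path(
        "my-project", "us-central1", "12345"
    )
    # -> "projects/my-project/locations/us-central1/endpoints/12345"

    parts = aiplatform_v1.EndpointServiceClient.parse_endpoint_path(path)
    # -> {"project": "my-project", "location": "us-central1", "endpoint": "12345"}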
@staticmethod
- def common_billing_account_path(billing_account: str,) -> str:
- """Return a fully-qualified billing_account string."""
+ def model_deployment_monitoring_job_path(
+ project: str,
+ location: str,
+ model_deployment_monitoring_job: str,
+ ) -> str:
+ """Returns a fully-qualified model_deployment_monitoring_job string."""
+ return "projects/{project}/locations/{location}/modelDeploymentMonitoringJobs/{model_deployment_monitoring_job}".format(
+ project=project,
+ location=location,
+ model_deployment_monitoring_job=model_deployment_monitoring_job,
+ )
+
+ @staticmethod
+ def parse_model_deployment_monitoring_job_path(path: str) -> Dict[str, str]:
+ """Parses a model_deployment_monitoring_job path into its component segments."""
+ m = re.match(
+ r"^projects/(?P.+?)/locations/(?P.+?)/modelDeploymentMonitoringJobs/(?P.+?)$",
+ path,
+ )
+ return m.groupdict() if m else {}
+
+ @staticmethod
+ def network_path(
+ project: str,
+ network: str,
+ ) -> str:
+ """Returns a fully-qualified network string."""
+ return "projects/{project}/global/networks/{network}".format(
+ project=project,
+ network=network,
+ )
+
+ @staticmethod
+ def parse_network_path(path: str) -> Dict[str, str]:
+ """Parses a network path into its component segments."""
+ m = re.match(
+ r"^projects/(?P.+?)/global/networks/(?P.+?)$", path
+ )
+ return m.groupdict() if m else {}
+
+ @staticmethod
+ def common_billing_account_path(
+ billing_account: str,
+ ) -> str:
+ """Returns a fully-qualified billing_account string."""
return "billingAccounts/{billing_account}".format(
billing_account=billing_account,
)
@@ -207,9 +276,13 @@ def parse_common_billing_account_path(path: str) -> Dict[str, str]:
return m.groupdict() if m else {}
@staticmethod
- def common_folder_path(folder: str,) -> str:
- """Return a fully-qualified folder string."""
- return "folders/{folder}".format(folder=folder,)
+ def common_folder_path(
+ folder: str,
+ ) -> str:
+ """Returns a fully-qualified folder string."""
+ return "folders/{folder}".format(
+ folder=folder,
+ )
@staticmethod
def parse_common_folder_path(path: str) -> Dict[str, str]:
@@ -218,9 +291,13 @@ def parse_common_folder_path(path: str) -> Dict[str, str]:
return m.groupdict() if m else {}
@staticmethod
- def common_organization_path(organization: str,) -> str:
- """Return a fully-qualified organization string."""
- return "organizations/{organization}".format(organization=organization,)
+ def common_organization_path(
+ organization: str,
+ ) -> str:
+ """Returns a fully-qualified organization string."""
+ return "organizations/{organization}".format(
+ organization=organization,
+ )
@staticmethod
def parse_common_organization_path(path: str) -> Dict[str, str]:
@@ -229,9 +306,13 @@ def parse_common_organization_path(path: str) -> Dict[str, str]:
return m.groupdict() if m else {}
@staticmethod
- def common_project_path(project: str,) -> str:
- """Return a fully-qualified project string."""
- return "projects/{project}".format(project=project,)
+ def common_project_path(
+ project: str,
+ ) -> str:
+ """Returns a fully-qualified project string."""
+ return "projects/{project}".format(
+ project=project,
+ )
@staticmethod
def parse_common_project_path(path: str) -> Dict[str, str]:
@@ -240,10 +321,14 @@ def parse_common_project_path(path: str) -> Dict[str, str]:
return m.groupdict() if m else {}
@staticmethod
- def common_location_path(project: str, location: str,) -> str:
- """Return a fully-qualified location string."""
+ def common_location_path(
+ project: str,
+ location: str,
+ ) -> str:
+ """Returns a fully-qualified location string."""
return "projects/{project}/locations/{location}".format(
- project=project, location=location,
+ project=project,
+ location=location,
)
@staticmethod
@@ -252,6 +337,73 @@ def parse_common_location_path(path: str) -> Dict[str, str]:
m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
return m.groupdict() if m else {}
+ @classmethod
+ def get_mtls_endpoint_and_cert_source(
+ cls, client_options: Optional[client_options_lib.ClientOptions] = None
+ ):
+ """Return the API endpoint and client cert source for mutual TLS.
+
+ The client cert source is determined in the following order:
+ (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
+ client cert source is None.
+ (2) if `client_options.client_cert_source` is provided, use the provided one; if the
+ default client cert source exists, use the default one; otherwise the client cert
+ source is None.
+
+ The API endpoint is determined in the following order:
+ (1) if `client_options.api_endpoint` is provided, use the provided one.
+ (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the
+ default mTLS endpoint; if the environment variable is "never", use the default API
+ endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
+ use the default API endpoint.
+
+ More details can be found at https://google.aip.dev/auth/4114.
+
+ Args:
+ client_options (google.api_core.client_options.ClientOptions): Custom options for the
+ client. Only the `api_endpoint` and `client_cert_source` properties may be used
+ in this method.
+
+ Returns:
+ Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
+ client cert source to use.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If any errors happen.
+ """
+ if client_options is None:
+ client_options = client_options_lib.ClientOptions()
+ use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
+ use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
+ if use_client_cert not in ("true", "false"):
+ raise ValueError(
+ "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
+ )
+ if use_mtls_endpoint not in ("auto", "never", "always"):
+ raise MutualTLSChannelError(
+ "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
+ )
+
+ # Figure out the client cert source to use.
+ client_cert_source = None
+ if use_client_cert == "true":
+ if client_options.client_cert_source:
+ client_cert_source = client_options.client_cert_source
+ elif mtls.has_default_client_cert_source():
+ client_cert_source = mtls.default_client_cert_source()
+
+ # Figure out which api endpoint to use.
+ if client_options.api_endpoint is not None:
+ api_endpoint = client_options.api_endpoint
+ elif use_mtls_endpoint == "always" or (
+ use_mtls_endpoint == "auto" and client_cert_source
+ ):
+ api_endpoint = cls.DEFAULT_MTLS_ENDPOINT
+ else:
+ api_endpoint = cls.DEFAULT_ENDPOINT
+
+ return api_endpoint, client_cert_source
+
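The classmethod can also be exercised directly; a sketch in which both environment variables are assumed unset, so the default endpoint and no cert source come back:

.. code-block:: python

    from google.api_core import client_options as client_options_lib
    from google.cloud import aiplatform_v1

    options = client_options_lib.ClientOptions()
    endpoint, cert_source = (
        aiplatform_v1.EndpointServiceClient.get_mtls_endpoint_and_cert_source(
            options
        )
    )
    print(endpoint)      # aiplatform.googleapis.com
    print(cert_source)   # None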
def __init__(
self,
*,
@@ -260,7 +412,7 @@ def __init__(
client_options: Optional[client_options_lib.ClientOptions] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
- """Instantiate the endpoint service client.
+ """Instantiates the endpoint service client.
Args:
credentials (Optional[google.auth.credentials.Credentials]): The
@@ -302,58 +454,42 @@ def __init__(
if client_options is None:
client_options = client_options_lib.ClientOptions()
- # Create SSL credentials for mutual TLS if needed.
- use_client_cert = bool(
- util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))
+ api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(
+ client_options
)
- client_cert_source_func = None
- is_mtls = False
- if use_client_cert:
- if client_options.client_cert_source:
- is_mtls = True
- client_cert_source_func = client_options.client_cert_source
- else:
- is_mtls = mtls.has_default_client_cert_source()
- client_cert_source_func = (
- mtls.default_client_cert_source() if is_mtls else None
- )
-
- # Figure out which api endpoint to use.
- if client_options.api_endpoint is not None:
- api_endpoint = client_options.api_endpoint
- else:
- use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
- if use_mtls_env == "never":
- api_endpoint = self.DEFAULT_ENDPOINT
- elif use_mtls_env == "always":
- api_endpoint = self.DEFAULT_MTLS_ENDPOINT
- elif use_mtls_env == "auto":
- api_endpoint = (
- self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT
- )
- else:
- raise MutualTLSChannelError(
- "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always"
- )
+ api_key_value = getattr(client_options, "api_key", None)
+ if api_key_value and credentials:
+ raise ValueError(
+ "client_options.api_key and credentials are mutually exclusive"
+ )
# Save or instantiate the transport.
# Ordinarily, we provide the transport, but allowing a custom transport
# instance provides an extensibility point for unusual situations.
if isinstance(transport, EndpointServiceTransport):
# transport is a EndpointServiceTransport instance.
- if credentials or client_options.credentials_file:
+ if credentials or client_options.credentials_file or api_key_value:
raise ValueError(
"When providing a transport instance, "
"provide its credentials directly."
)
if client_options.scopes:
raise ValueError(
- "When providing a transport instance, "
- "provide its scopes directly."
+ "When providing a transport instance, provide its scopes "
+ "directly."
)
self._transport = transport
else:
+ import google.auth._default # type: ignore
+
+ if api_key_value and hasattr(
+ google.auth._default, "get_api_key_credentials"
+ ):
+ credentials = google.auth._default.get_api_key_credentials(
+ api_key_value
+ )
+
Transport = type(self).get_transport_class(transport)
self._transport = Transport(
credentials=credentials,
@@ -363,22 +499,51 @@ def __init__(
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
+ always_use_jwt_access=True,
)
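The ``api_key`` branch above only activates when the installed ``google.auth`` exposes ``get_api_key_credentials``; a hedged sketch (the key is a placeholder, and API-key auth must be enabled for the service and supported by the installed library versions):

.. code-block:: python

    from google.api_core import client_options as client_options_lib
    from google.cloud import aiplatform_v1

    # Mutually exclusive with an explicit `credentials` argument, as
    # enforced in __init__ above.
    options = client_options_lib.ClientOptions(api_key="my-api-key")
    client = aiplatform_v1.EndpointServiceClient(client_options=options)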
def create_endpoint(
self,
- request: endpoint_service.CreateEndpointRequest = None,
+ request: Union[endpoint_service.CreateEndpointRequest, dict] = None,
*,
parent: str = None,
endpoint: gca_endpoint.Endpoint = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ endpoint_id: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> gac_operation.Operation:
r"""Creates an Endpoint.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ def sample_create_endpoint():
+ # Create a client
+ client = aiplatform_v1.EndpointServiceClient()
+
+ # Initialize request argument(s)
+ endpoint = aiplatform_v1.Endpoint()
+ endpoint.display_name = "display_name_value"
+
+ request = aiplatform_v1.CreateEndpointRequest(
+ parent="parent_value",
+ endpoint=endpoint,
+ )
+
+ # Make the request
+ operation = client.create_endpoint(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = operation.result()
+
+ # Handle the response
+ print(response)
+
Args:
- request (google.cloud.aiplatform_v1.types.CreateEndpointRequest):
+ request (Union[google.cloud.aiplatform_v1.types.CreateEndpointRequest, dict]):
The request object. Request message for
[EndpointService.CreateEndpoint][google.cloud.aiplatform.v1.EndpointService.CreateEndpoint].
parent (str):
@@ -394,6 +559,21 @@ def create_endpoint(
This corresponds to the ``endpoint`` field
on the ``request`` instance; if ``request`` is provided, this
should not be set.
+ endpoint_id (str):
+ Immutable. The ID to use for the endpoint, which will become
+ the final component of the endpoint resource name. If
+ not provided, Vertex AI will generate a value for this
+ ID.
+
+ This value should be 1-10 characters, and valid
+ characters are /[0-9]/. When using HTTP/JSON, this field
+ is populated based on a query string argument, such as
+ ``?endpoint_id=12345``. This is the fallback for fields
+ that are not included in either the URI or the body.
+
+ This corresponds to the ``endpoint_id`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
@@ -409,9 +589,9 @@ def create_endpoint(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
- has_flattened_params = any([parent, endpoint])
+ has_flattened_params = any([parent, endpoint, endpoint_id])
if request is not None and has_flattened_params:
raise ValueError(
"If the `request` argument is set, then none of "
@@ -430,6 +610,8 @@ def create_endpoint(
request.parent = parent
if endpoint is not None:
request.endpoint = endpoint
+ if endpoint_id is not None:
+ request.endpoint_id = endpoint_id
# Wrap the RPC method; this adds retry and timeout information,
# and friendly error handling.
@@ -442,7 +624,12 @@ def create_endpoint(
)
# Send the request.
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Wrap the response in an operation future.
response = gac_operation.from_gapic(
@@ -457,17 +644,36 @@ def create_endpoint(
def get_endpoint(
self,
- request: endpoint_service.GetEndpointRequest = None,
+ request: Union[endpoint_service.GetEndpointRequest, dict] = None,
*,
name: str = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> endpoint.Endpoint:
r"""Gets an Endpoint.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ def sample_get_endpoint():
+ # Create a client
+ client = aiplatform_v1.EndpointServiceClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.GetEndpointRequest(
+ name="name_value",
+ )
+
+ # Make the request
+ response = client.get_endpoint(request=request)
+
+ # Handle the response
+ print(response)
+
Args:
- request (google.cloud.aiplatform_v1.types.GetEndpointRequest):
+ request (Union[google.cloud.aiplatform_v1.types.GetEndpointRequest, dict]):
The request object. Request message for
[EndpointService.GetEndpoint][google.cloud.aiplatform.v1.EndpointService.GetEndpoint]
name (str):
@@ -491,7 +697,7 @@ def get_endpoint(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([name])
if request is not None and has_flattened_params:
@@ -522,24 +728,49 @@ def get_endpoint(
)
# Send the request.
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Done; return the response.
return response
def list_endpoints(
self,
- request: endpoint_service.ListEndpointsRequest = None,
+ request: Union[endpoint_service.ListEndpointsRequest, dict] = None,
*,
parent: str = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> pagers.ListEndpointsPager:
r"""Lists Endpoints in a Location.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ def sample_list_endpoints():
+ # Create a client
+ client = aiplatform_v1.EndpointServiceClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.ListEndpointsRequest(
+ parent="parent_value",
+ )
+
+ # Make the request
+ page_result = client.list_endpoints(request=request)
+
+ # Handle the response
+ for response in page_result:
+ print(response)
+
Args:
- request (google.cloud.aiplatform_v1.types.ListEndpointsRequest):
+ request (Union[google.cloud.aiplatform_v1.types.ListEndpointsRequest, dict]):
The request object. Request message for
[EndpointService.ListEndpoints][google.cloud.aiplatform.v1.EndpointService.ListEndpoints].
parent (str):
@@ -566,7 +797,7 @@ def list_endpoints(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([parent])
if request is not None and has_flattened_params:
@@ -597,12 +828,20 @@ def list_endpoints(
)
# Send the request.
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# This method is paged; wrap the response in a pager, which provides
# an `__iter__` convenience method.
response = pagers.ListEndpointsPager(
- method=rpc, request=request, response=response, metadata=metadata,
+ method=rpc,
+ request=request,
+ response=response,
+ metadata=metadata,
)
# Done; return the response.
@@ -610,18 +849,40 @@ def list_endpoints(
def update_endpoint(
self,
- request: endpoint_service.UpdateEndpointRequest = None,
+ request: Union[endpoint_service.UpdateEndpointRequest, dict] = None,
*,
endpoint: gca_endpoint.Endpoint = None,
update_mask: field_mask_pb2.FieldMask = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> gca_endpoint.Endpoint:
r"""Updates an Endpoint.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ def sample_update_endpoint():
+ # Create a client
+ client = aiplatform_v1.EndpointServiceClient()
+
+ # Initialize request argument(s)
+ endpoint = aiplatform_v1.Endpoint()
+ endpoint.display_name = "display_name_value"
+
+ request = aiplatform_v1.UpdateEndpointRequest(
+ endpoint=endpoint,
+ )
+
+ # Make the request
+ response = client.update_endpoint(request=request)
+
+ # Handle the response
+ print(response)
+
Args:
- request (google.cloud.aiplatform_v1.types.UpdateEndpointRequest):
+ request (Union[google.cloud.aiplatform_v1.types.UpdateEndpointRequest, dict]):
The request object. Request message for
[EndpointService.UpdateEndpoint][google.cloud.aiplatform.v1.EndpointService.UpdateEndpoint].
endpoint (google.cloud.aiplatform_v1.types.Endpoint):
@@ -633,7 +894,7 @@ def update_endpoint(
should not be set.
update_mask (google.protobuf.field_mask_pb2.FieldMask):
Required. The update mask applies to the resource. See
- `FieldMask <https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask>`__.
+ [google.protobuf.FieldMask][google.protobuf.FieldMask].
This corresponds to the ``update_mask`` field
on the ``request`` instance; if ``request`` is provided, this
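
A sketch of a flattened ``update_endpoint`` call with a field mask; the
resource name and display name are placeholders:

.. code-block:: python

    from google.cloud import aiplatform_v1
    from google.protobuf import field_mask_pb2

    client = aiplatform_v1.EndpointServiceClient()

    endpoint = aiplatform_v1.Endpoint(
        name="projects/my-project/locations/us-central1/endpoints/456",
        display_name="renamed-endpoint",
    )

    # Only the fields listed in the mask are updated.
    response = client.update_endpoint(
        endpoint=endpoint,
        update_mask=field_mask_pb2.FieldMask(paths=["display_name"]),
    )
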
@@ -652,7 +913,7 @@ def update_endpoint(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([endpoint, update_mask])
if request is not None and has_flattened_params:
@@ -687,24 +948,52 @@ def update_endpoint(
)
# Send the request.
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Done; return the response.
return response
def delete_endpoint(
self,
- request: endpoint_service.DeleteEndpointRequest = None,
+ request: Union[endpoint_service.DeleteEndpointRequest, dict] = None,
*,
name: str = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> gac_operation.Operation:
r"""Deletes an Endpoint.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ def sample_delete_endpoint():
+ # Create a client
+ client = aiplatform_v1.EndpointServiceClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.DeleteEndpointRequest(
+ name="name_value",
+ )
+
+ # Make the request
+ operation = client.delete_endpoint(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = operation.result()
+
+ # Handle the response
+ print(response)
+
Args:
- request (google.cloud.aiplatform_v1.types.DeleteEndpointRequest):
+ request (Union[google.cloud.aiplatform_v1.types.DeleteEndpointRequest, dict]):
The request object. Request message for
[EndpointService.DeleteEndpoint][google.cloud.aiplatform.v1.EndpointService.DeleteEndpoint].
name (str):
@@ -741,7 +1030,7 @@ def delete_endpoint(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([name])
if request is not None and has_flattened_params:
@@ -772,7 +1061,12 @@ def delete_endpoint(
)
# Send the request.
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Wrap the response in an operation future.
response = gac_operation.from_gapic(
@@ -787,22 +1081,48 @@ def delete_endpoint(
def deploy_model(
self,
- request: endpoint_service.DeployModelRequest = None,
+ request: Union[endpoint_service.DeployModelRequest, dict] = None,
*,
endpoint: str = None,
deployed_model: gca_endpoint.DeployedModel = None,
- traffic_split: Sequence[
- endpoint_service.DeployModelRequest.TrafficSplitEntry
- ] = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ traffic_split: Mapping[str, int] = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> gac_operation.Operation:
r"""Deploys a Model into this Endpoint, creating a
DeployedModel within it.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ def sample_deploy_model():
+ # Create a client
+ client = aiplatform_v1.EndpointServiceClient()
+
+ # Initialize request argument(s)
+ deployed_model = aiplatform_v1.DeployedModel()
+ deployed_model.dedicated_resources.min_replica_count = 1803
+ deployed_model.model = "model_value"
+
+ request = aiplatform_v1.DeployModelRequest(
+ endpoint="endpoint_value",
+ deployed_model=deployed_model,
+ )
+
+ # Make the request
+ operation = client.deploy_model(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = operation.result()
+
+ # Handle the response
+ print(response)
+
Args:
- request (google.cloud.aiplatform_v1.types.DeployModelRequest):
+ request (Union[google.cloud.aiplatform_v1.types.DeployModelRequest, dict]):
The request object. Request message for
[EndpointService.DeployModel][google.cloud.aiplatform.v1.EndpointService.DeployModel].
endpoint (str):
@@ -824,7 +1144,7 @@ def deploy_model(
This corresponds to the ``deployed_model`` field
on the ``request`` instance; if ``request`` is provided, this
should not be set.
- traffic_split (Sequence[google.cloud.aiplatform_v1.types.DeployModelRequest.TrafficSplitEntry]):
+ traffic_split (Mapping[str, int]):
A map from a DeployedModel's ID to the percentage of
this Endpoint's traffic that should be forwarded to that
DeployedModel.
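
Since ``traffic_split`` is now typed as ``Mapping[str, int]``, a plain dict
can be passed to the flattened parameter. A sketch with placeholder resource
names (in a ``DeployModelRequest``, the key ``"0"`` refers to the model being
deployed):

.. code-block:: python

    from google.cloud import aiplatform_v1

    client = aiplatform_v1.EndpointServiceClient()

    deployed_model = aiplatform_v1.DeployedModel(
        model="projects/my-project/locations/us-central1/models/123",
        dedicated_resources=aiplatform_v1.DedicatedResources(
            machine_spec=aiplatform_v1.MachineSpec(machine_type="n1-standard-2"),
            min_replica_count=1,
        ),
    )

    # Send all traffic to the model being deployed ("0").
    operation = client.deploy_model(
        endpoint="projects/my-project/locations/us-central1/endpoints/456",
        deployed_model=deployed_model,
        traffic_split={"0": 100},
    )
    response = operation.result()
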
@@ -861,7 +1181,7 @@ def deploy_model(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([endpoint, deployed_model, traffic_split])
if request is not None and has_flattened_params:
@@ -896,7 +1216,12 @@ def deploy_model(
)
# Send the request.
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Wrap the response in an operation future.
response = gac_operation.from_gapic(
@@ -911,14 +1236,12 @@ def deploy_model(
def undeploy_model(
self,
- request: endpoint_service.UndeployModelRequest = None,
+ request: Union[endpoint_service.UndeployModelRequest, dict] = None,
*,
endpoint: str = None,
deployed_model_id: str = None,
- traffic_split: Sequence[
- endpoint_service.UndeployModelRequest.TrafficSplitEntry
- ] = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
+ traffic_split: Mapping[str, int] = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: float = None,
metadata: Sequence[Tuple[str, str]] = (),
) -> gac_operation.Operation:
@@ -926,8 +1249,32 @@ def undeploy_model(
DeployedModel from it, and freeing all resources it's
using.
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ def sample_undeploy_model():
+ # Create a client
+ client = aiplatform_v1.EndpointServiceClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.UndeployModelRequest(
+ endpoint="endpoint_value",
+ deployed_model_id="deployed_model_id_value",
+ )
+
+ # Make the request
+ operation = client.undeploy_model(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = operation.result()
+
+ # Handle the response
+ print(response)
+
Args:
- request (google.cloud.aiplatform_v1.types.UndeployModelRequest):
+ request (Union[google.cloud.aiplatform_v1.types.UndeployModelRequest, dict]):
The request object. Request message for
[EndpointService.UndeployModel][google.cloud.aiplatform.v1.EndpointService.UndeployModel].
endpoint (str):
@@ -945,7 +1292,7 @@ def undeploy_model(
This corresponds to the ``deployed_model_id`` field
on the ``request`` instance; if ``request`` is provided, this
should not be set.
- traffic_split (Sequence[google.cloud.aiplatform_v1.types.UndeployModelRequest.TrafficSplitEntry]):
+ traffic_split (Mapping[str, int]):
If this field is provided, then the Endpoint's
[traffic_split][google.cloud.aiplatform.v1.Endpoint.traffic_split]
will be overwritten with it. If the last DeployedModel is
@@ -976,7 +1323,7 @@ def undeploy_model(
"""
# Create or coerce a protobuf request object.
- # Sanity check: If we got a request object, we should *not* have
+ # Quick check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
has_flattened_params = any([endpoint, deployed_model_id, traffic_split])
if request is not None and has_flattened_params:
@@ -1011,7 +1358,12 @@ def undeploy_model(
)
# Send the request.
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
# Wrap the response in an operation future.
response = gac_operation.from_gapic(
@@ -1024,6 +1376,690 @@ def undeploy_model(
# Done; return the response.
return response
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ """Releases underlying transport's resources.
+
+ .. warning::
+ ONLY use as a context manager if the transport is NOT shared
+ with other clients! Exiting the with block will CLOSE the transport
+ and may cause errors in other clients!
+ """
+ self.transport.close()
+
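
A usage sketch of the new context-manager support; default credentials are
assumed and the resource name is a placeholder. Use this only when the client
owns its transport:

.. code-block:: python

    from google.cloud import aiplatform_v1

    with aiplatform_v1.EndpointServiceClient() as client:
        endpoint = client.get_endpoint(
            name="projects/my-project/locations/us-central1/endpoints/456"
        )
        print(endpoint.display_name)
    # The transport is closed on exit; the client must not be reused.
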
+ def list_operations(
+ self,
+ request: operations_pb2.ListOperationsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.ListOperationsResponse:
+ r"""Lists operations that match the specified filter in the request.
+
+ Args:
+ request (:class:`~.operations_pb2.ListOperationsRequest`):
+ The request object. Request message for
+ `ListOperations` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.ListOperationsResponse:
+ Response message for ``ListOperations`` method.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.ListOperationsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.list_operations,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def get_operation(
+ self,
+ request: operations_pb2.GetOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.Operation:
+ r"""Gets the latest state of a long-running operation.
+
+ Args:
+ request (:class:`~.operations_pb2.GetOperationRequest`):
+ The request object. Request message for
+ `GetOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.Operation:
+ An ``Operation`` object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.GetOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.get_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def delete_operation(
+ self,
+ request: operations_pb2.DeleteOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> None:
+ r"""Deletes a long-running operation.
+
+ This method indicates that the client is no longer interested
+ in the operation result. It does not cancel the operation.
+ If the server doesn't support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
+ Args:
+ request (:class:`~.operations_pb2.DeleteOperationRequest`):
+ The request object. Request message for
+ `DeleteOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ None
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.DeleteOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.delete_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ def cancel_operation(
+ self,
+ request: operations_pb2.CancelOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> None:
+ r"""Starts asynchronous cancellation on a long-running operation.
+
+ The server makes a best effort to cancel the operation, but success
+ is not guaranteed. If the server doesn't support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
+ Args:
+ request (:class:`~.operations_pb2.CancelOperationRequest`):
+ The request object. Request message for
+ `CancelOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ None
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.CancelOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.cancel_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ def wait_operation(
+ self,
+ request: operations_pb2.WaitOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.Operation:
+ r"""Waits until the specified long-running operation is done or reaches at most
+ a specified timeout, returning the latest state.
+
+ If the operation is already done, the latest state is immediately returned.
+ If the timeout specified is greater than the default HTTP/RPC timeout, the HTTP/RPC
+ timeout is used. If the server does not support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
+ Args:
+ request (:class:`~.operations_pb2.WaitOperationRequest`):
+ The request object. Request message for
+ `WaitOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.Operation:
+ An ``Operation`` object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.WaitOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.wait_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
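
Because these mixins expand a plain dict into the corresponding
``operations_pb2`` request, they can be called without importing the protobuf
types. A sketch with placeholder resource names:

.. code-block:: python

    from google.cloud import aiplatform_v1

    client = aiplatform_v1.EndpointServiceClient()

    # The dict is expanded into an operations_pb2.ListOperationsRequest.
    ops = client.list_operations(
        request={"name": "projects/my-project/locations/us-central1"}
    )
    for op in ops.operations:
        print(op.name, op.done)

    # Poll a single operation by its fully qualified name.
    op = client.get_operation(
        request={"name": "projects/my-project/locations/us-central1/operations/789"}
    )
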
+ def set_iam_policy(
+ self,
+ request: iam_policy_pb2.SetIamPolicyRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> policy_pb2.Policy:
+ r"""Sets the IAM access control policy on the specified function.
+
+ Replaces any existing policy.
+
+ Args:
+ request (:class:`~.iam_policy_pb2.SetIamPolicyRequest`):
+ The request object. Request message for `SetIamPolicy`
+ method.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.policy_pb2.Policy:
+ Defines an Identity and Access Management (IAM) policy.
+ It is used to specify access control policies for Cloud
+ Platform resources.
+ A ``Policy`` is a collection of ``bindings``. A
+ ``binding`` binds one or more ``members`` to a single
+ ``role``. Members can be user accounts, service
+ accounts, Google groups, and domains (such as G Suite).
+ A ``role`` is a named list of permissions (defined by
+ IAM or configured by users). A ``binding`` can
+ optionally specify a ``condition``, which is a logic
+ expression that further constrains the role binding
+ based on attributes about the request and/or target
+ resource.
+ **JSON Example**
+ ::
+ {
+ "bindings": [
+ {
+ "role": "roles/resourcemanager.organizationAdmin",
+ "members": [
+ "user:mike@example.com",
+ "group:admins@example.com",
+ "domain:google.com",
+ "serviceAccount:my-project-id@appspot.gserviceaccount.com"
+ ]
+ },
+ {
+ "role": "roles/resourcemanager.organizationViewer",
+ "members": ["user:eve@example.com"],
+ "condition": {
+ "title": "expirable access",
+ "description": "Does not grant access after Sep 2020",
+ "expression": "request.time <
+ timestamp('2020-10-01T00:00:00.000Z')",
+ }
+ }
+ ]
+ }
+ **YAML Example**
+ ::
+ bindings:
+ - members:
+ - user:mike@example.com
+ - group:admins@example.com
+ - domain:google.com
+ - serviceAccount:my-project-id@appspot.gserviceaccount.com
+ role: roles/resourcemanager.organizationAdmin
+ - members:
+ - user:eve@example.com
+ role: roles/resourcemanager.organizationViewer
+ condition:
+ title: expirable access
+ description: Does not grant access after Sep 2020
+ expression: request.time < timestamp('2020-10-01T00:00:00.000Z')
+ For a description of IAM and its features, see the `IAM
+ developer's
+ guide <https://cloud.google.com/iam/docs>`__.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.SetIamPolicyRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.set_iam_policy,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def get_iam_policy(
+ self,
+ request: iam_policy_pb2.GetIamPolicyRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> policy_pb2.Policy:
+ r"""Gets the IAM access control policy for a function.
+
+ Returns an empty policy if the function exists and does not have a
+ policy set.
+
+ Args:
+ request (:class:`~.iam_policy_pb2.GetIamPolicyRequest`):
+ The request object. Request message for `GetIamPolicy`
+ method.
+ retry (google.api_core.retry.Retry): Designation of what errors, if
+ any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.policy_pb2.Policy:
+ Defines an Identity and Access Management (IAM) policy.
+ It is used to specify access control policies for Cloud
+ Platform resources.
+ A ``Policy`` is a collection of ``bindings``. A
+ ``binding`` binds one or more ``members`` to a single
+ ``role``. Members can be user accounts, service
+ accounts, Google groups, and domains (such as G Suite).
+ A ``role`` is a named list of permissions (defined by
+ IAM or configured by users). A ``binding`` can
+ optionally specify a ``condition``, which is a logic
+ expression that further constrains the role binding
+ based on attributes about the request and/or target
+ resource.
+ **JSON Example**
+ ::
+ {
+ "bindings": [
+ {
+ "role": "roles/resourcemanager.organizationAdmin",
+ "members": [
+ "user:mike@example.com",
+ "group:admins@example.com",
+ "domain:google.com",
+ "serviceAccount:my-project-id@appspot.gserviceaccount.com"
+ ]
+ },
+ {
+ "role": "roles/resourcemanager.organizationViewer",
+ "members": ["user:eve@example.com"],
+ "condition": {
+ "title": "expirable access",
+ "description": "Does not grant access after Sep 2020",
+ "expression": "request.time <
+ timestamp('2020-10-01T00:00:00.000Z')",
+ }
+ }
+ ]
+ }
+ **YAML Example**
+ ::
+ bindings:
+ - members:
+ - user:mike@example.com
+ - group:admins@example.com
+ - domain:google.com
+ - serviceAccount:my-project-id@appspot.gserviceaccount.com
+ role: roles/resourcemanager.organizationAdmin
+ - members:
+ - user:eve@example.com
+ role: roles/resourcemanager.organizationViewer
+ condition:
+ title: expirable access
+ description: Does not grant access after Sep 2020
+ expression: request.time < timestamp('2020-10-01T00:00:00.000Z')
+ For a description of IAM and its features, see the `IAM
+ developer's
+ guide <https://cloud.google.com/iam/docs>`__.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.GetIamPolicyRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.get_iam_policy,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def test_iam_permissions(
+ self,
+ request: iam_policy_pb2.TestIamPermissionsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> iam_policy_pb2.TestIamPermissionsResponse:
+ r"""Tests the specified IAM permissions against the IAM access control
+ policy for a function.
+
+ If the function does not exist, this will return an empty set
+ of permissions, not a NOT_FOUND error.
+
+ Args:
+ request (:class:`~.iam_policy_pb2.TestIamPermissionsRequest`):
+ The request object. Request message for
+ `TestIamPermissions` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.iam_policy_pb2.TestIamPermissionsResponse:
+ Response message for ``TestIamPermissions`` method.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.TestIamPermissionsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.test_iam_permissions,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
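
A sketch of the IAM mixins against an Endpoint resource; the resource name
and the permission string are placeholders:

.. code-block:: python

    from google.cloud import aiplatform_v1

    client = aiplatform_v1.EndpointServiceClient()
    resource = "projects/my-project/locations/us-central1/endpoints/456"

    policy = client.get_iam_policy(request={"resource": resource})
    print(policy.bindings)

    response = client.test_iam_permissions(
        request={
            "resource": resource,
            "permissions": ["aiplatform.endpoints.get"],
        }
    )
    print(response.permissions)
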
+ def get_location(
+ self,
+ request: locations_pb2.GetLocationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> locations_pb2.Location:
+ r"""Gets information about a location.
+
+ Args:
+ request (:class:`~.locations_pb2.GetLocationRequest`):
+ The request object. Request message for
+ `GetLocation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.locations_pb2.Location:
+ Location object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = locations_pb2.GetLocationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.get_location,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def list_locations(
+ self,
+ request: locations_pb2.ListLocationsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> locations_pb2.ListLocationsResponse:
+ r"""Lists information about the supported locations for this service.
+
+ Args:
+ request (:class:`~.locations_pb2.ListLocationsRequest`):
+ The request object. Request message for
+ `ListLocations` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.locations_pb2.ListLocationsResponse:
+ Response message for ``ListLocations`` method.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = locations_pb2.ListLocationsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.list_locations,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
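
A sketch of the locations mixins; the project ID is a placeholder:

.. code-block:: python

    from google.cloud import aiplatform_v1

    client = aiplatform_v1.EndpointServiceClient()

    locations = client.list_locations(request={"name": "projects/my-project"})
    for location in locations.locations:
        print(location.location_id)
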
try:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py b/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py
index 0b222aee01..4b65110ad5 100644
--- a/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py
+++ b/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,13 +15,13 @@
#
from typing import (
Any,
- AsyncIterable,
+ AsyncIterator,
Awaitable,
Callable,
- Iterable,
Sequence,
Tuple,
Optional,
+ Iterator,
)
from google.cloud.aiplatform_v1.types import endpoint
@@ -75,14 +75,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)
@property
- def pages(self) -> Iterable[endpoint_service.ListEndpointsResponse]:
+ def pages(self) -> Iterator[endpoint_service.ListEndpointsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = self._method(self._request, metadata=self._metadata)
yield self._response
- def __iter__(self) -> Iterable[endpoint.Endpoint]:
+ def __iter__(self) -> Iterator[endpoint.Endpoint]:
for page in self.pages:
yield from page.endpoints
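
The ``Iterable`` to ``Iterator`` change is typing-only; runtime behavior is
unchanged. A sketch of page-wise iteration over the pager (the parent is a
placeholder):

.. code-block:: python

    from google.cloud import aiplatform_v1

    client = aiplatform_v1.EndpointServiceClient()
    pager = client.list_endpoints(
        parent="projects/my-project/locations/us-central1"
    )

    # Each page is a ListEndpointsResponse whose `endpoints`
    # field holds the items; iterating the pager directly
    # yields the items across pages instead.
    for page in pager.pages:
        for endpoint in page.endpoints:
            print(endpoint.name)
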
@@ -116,7 +116,7 @@ def __init__(
*,
metadata: Sequence[Tuple[str, str]] = ()
):
- """Instantiate the pager.
+ """Instantiates the pager.
Args:
method (Callable): The method that was originally called, and
@@ -137,14 +137,14 @@ def __getattr__(self, name: str) -> Any:
return getattr(self._response, name)
@property
- async def pages(self) -> AsyncIterable[endpoint_service.ListEndpointsResponse]:
+ async def pages(self) -> AsyncIterator[endpoint_service.ListEndpointsResponse]:
yield self._response
while self._response.next_page_token:
self._request.page_token = self._response.next_page_token
self._response = await self._method(self._request, metadata=self._metadata)
yield self._response
- def __aiter__(self) -> AsyncIterable[endpoint.Endpoint]:
+ def __aiter__(self) -> AsyncIterator[endpoint.Endpoint]:
async def async_generator():
async for page in self.pages:
for response in page.endpoints:
diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/__init__.py
index 4d336c5875..92f3485150 100644
--- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/__init__.py
+++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py
index a760eddfef..38d4ad241c 100644
--- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py
+++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,20 +15,24 @@
#
import abc
from typing import Awaitable, Callable, Dict, Optional, Sequence, Union
-import packaging.version
import pkg_resources
import google.auth # type: ignore
-import google.api_core # type: ignore
-from google.api_core import exceptions as core_exceptions # type: ignore
-from google.api_core import gapic_v1 # type: ignore
-from google.api_core import retry as retries # type: ignore
-from google.api_core import operations_v1 # type: ignore
+import google.api_core
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core import retry as retries
+from google.api_core import operations_v1
from google.auth import credentials as ga_credentials # type: ignore
+from google.oauth2 import service_account # type: ignore
from google.cloud.aiplatform_v1.types import endpoint
from google.cloud.aiplatform_v1.types import endpoint as gca_endpoint
from google.cloud.aiplatform_v1.types import endpoint_service
+from google.cloud.location import locations_pb2 # type: ignore
+from google.iam.v1 import iam_policy_pb2 # type: ignore
+from google.iam.v1 import policy_pb2 # type: ignore
from google.longrunning import operations_pb2 # type: ignore
try:
@@ -40,17 +44,6 @@
except pkg_resources.DistributionNotFound:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
-try:
- # google.auth.__version__ was added in 1.26.0
- _GOOGLE_AUTH_VERSION = google.auth.__version__
-except AttributeError:
- try: # try pkg_resources if it is available
- _GOOGLE_AUTH_VERSION = pkg_resources.get_distribution("google-auth").version
- except pkg_resources.DistributionNotFound: # pragma: NO COVER
- _GOOGLE_AUTH_VERSION = None
-
-_API_CORE_VERSION = google.api_core.__version__
-
class EndpointServiceTransport(abc.ABC):
"""Abstract transport class for EndpointService."""
@@ -68,6 +61,7 @@ def __init__(
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ always_use_jwt_access: Optional[bool] = False,
**kwargs,
) -> None:
"""Instantiate the transport.
@@ -91,16 +85,19 @@ def __init__(
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
+ always_use_jwt_access (Optional[bool]): Whether a self-signed JWT should
+ be used for service account credentials.
"""
+
# Save the hostname. Default to port 443 (HTTPS) if none is specified.
if ":" not in host:
host += ":443"
self._host = host
- scopes_kwargs = self._get_scopes_kwargs(self._host, scopes)
+ scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES}
# Save the scopes.
- self._scopes = scopes or self.AUTH_SCOPES
+ self._scopes = scopes
# If no credentials are provided, then determine the appropriate
# defaults.
@@ -113,88 +110,73 @@ def __init__(
credentials, _ = google.auth.load_credentials_from_file(
credentials_file, **scopes_kwargs, quota_project_id=quota_project_id
)
-
elif credentials is None:
credentials, _ = google.auth.default(
**scopes_kwargs, quota_project_id=quota_project_id
)
- # Save the credentials.
- self._credentials = credentials
-
- # TODO(busunkim): These two class methods are in the base transport
- # to avoid duplicating code across the transport classes. These functions
- # should be deleted once the minimum required versions of google-api-core
- # and google-auth are increased.
-
- # TODO: Remove this function once google-auth >= 1.25.0 is required
- @classmethod
- def _get_scopes_kwargs(
- cls, host: str, scopes: Optional[Sequence[str]]
- ) -> Dict[str, Optional[Sequence[str]]]:
- """Returns scopes kwargs to pass to google-auth methods depending on the google-auth version"""
-
- scopes_kwargs = {}
-
- if _GOOGLE_AUTH_VERSION and (
- packaging.version.parse(_GOOGLE_AUTH_VERSION)
- >= packaging.version.parse("1.25.0")
+ # If the credentials are service account credentials, then always try to use a self-signed JWT.
+ if (
+ always_use_jwt_access
+ and isinstance(credentials, service_account.Credentials)
+ and hasattr(service_account.Credentials, "with_always_use_jwt_access")
):
- scopes_kwargs = {"scopes": scopes, "default_scopes": cls.AUTH_SCOPES}
- else:
- scopes_kwargs = {"scopes": scopes or cls.AUTH_SCOPES}
-
- return scopes_kwargs
+ credentials = credentials.with_always_use_jwt_access(True)
- # TODO: Remove this function once google-api-core >= 1.26.0 is required
- @classmethod
- def _get_self_signed_jwt_kwargs(
- cls, host: str, scopes: Optional[Sequence[str]]
- ) -> Dict[str, Union[Optional[Sequence[str]], str]]:
- """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version"""
-
- self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {}
-
- if _API_CORE_VERSION and (
- packaging.version.parse(_API_CORE_VERSION)
- >= packaging.version.parse("1.26.0")
- ):
- self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES
- self_signed_jwt_kwargs["scopes"] = scopes
- self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST
- else:
- self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES
-
- return self_signed_jwt_kwargs
+ # Save the credentials.
+ self._credentials = credentials
def _prep_wrapped_messages(self, client_info):
# Precompute the wrapped methods.
self._wrapped_methods = {
self.create_endpoint: gapic_v1.method.wrap_method(
- self.create_endpoint, default_timeout=5.0, client_info=client_info,
+ self.create_endpoint,
+ default_timeout=None,
+ client_info=client_info,
),
self.get_endpoint: gapic_v1.method.wrap_method(
- self.get_endpoint, default_timeout=5.0, client_info=client_info,
+ self.get_endpoint,
+ default_timeout=None,
+ client_info=client_info,
),
self.list_endpoints: gapic_v1.method.wrap_method(
- self.list_endpoints, default_timeout=5.0, client_info=client_info,
+ self.list_endpoints,
+ default_timeout=None,
+ client_info=client_info,
),
self.update_endpoint: gapic_v1.method.wrap_method(
- self.update_endpoint, default_timeout=5.0, client_info=client_info,
+ self.update_endpoint,
+ default_timeout=None,
+ client_info=client_info,
),
self.delete_endpoint: gapic_v1.method.wrap_method(
- self.delete_endpoint, default_timeout=5.0, client_info=client_info,
+ self.delete_endpoint,
+ default_timeout=None,
+ client_info=client_info,
),
self.deploy_model: gapic_v1.method.wrap_method(
- self.deploy_model, default_timeout=5.0, client_info=client_info,
+ self.deploy_model,
+ default_timeout=None,
+ client_info=client_info,
),
self.undeploy_model: gapic_v1.method.wrap_method(
- self.undeploy_model, default_timeout=5.0, client_info=client_info,
+ self.undeploy_model,
+ default_timeout=None,
+ client_info=client_info,
),
}
+ def close(self):
+ """Closes resources associated with the transport.
+
+ .. warning::
+ Only call this method if the transport is NOT shared
+ with other clients - this may cause errors in other clients!
+ """
+ raise NotImplementedError()
+
@property
- def operations_client(self) -> operations_v1.OperationsClient:
+ def operations_client(self):
"""Return the client designed to process long-running operations."""
raise NotImplementedError()
@@ -264,5 +246,102 @@ def undeploy_model(
]:
raise NotImplementedError()
+ @property
+ def list_operations(
+ self,
+ ) -> Callable[
+ [operations_pb2.ListOperationsRequest],
+ Union[
+ operations_pb2.ListOperationsResponse,
+ Awaitable[operations_pb2.ListOperationsResponse],
+ ],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def get_operation(
+ self,
+ ) -> Callable[
+ [operations_pb2.GetOperationRequest],
+ Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def cancel_operation(
+ self,
+ ) -> Callable[[operations_pb2.CancelOperationRequest], None,]:
+ raise NotImplementedError()
+
+ @property
+ def delete_operation(
+ self,
+ ) -> Callable[[operations_pb2.DeleteOperationRequest], None,]:
+ raise NotImplementedError()
+
+ @property
+ def wait_operation(
+ self,
+ ) -> Callable[
+ [operations_pb2.WaitOperationRequest],
+ Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def set_iam_policy(
+ self,
+ ) -> Callable[
+ [iam_policy_pb2.SetIamPolicyRequest],
+ Union[policy_pb2.Policy, Awaitable[policy_pb2.Policy]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def get_iam_policy(
+ self,
+ ) -> Callable[
+ [iam_policy_pb2.GetIamPolicyRequest],
+ Union[policy_pb2.Policy, Awaitable[policy_pb2.Policy]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def test_iam_permissions(
+ self,
+ ) -> Callable[
+ [iam_policy_pb2.TestIamPermissionsRequest],
+ Union[
+ iam_policy_pb2.TestIamPermissionsResponse,
+ Awaitable[iam_policy_pb2.TestIamPermissionsResponse],
+ ],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def get_location(
+ self,
+ ) -> Callable[
+ [locations_pb2.GetLocationRequest],
+ Union[locations_pb2.Location, Awaitable[locations_pb2.Location]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def list_locations(
+ self,
+ ) -> Callable[
+ [locations_pb2.ListLocationsRequest],
+ Union[
+ locations_pb2.ListLocationsResponse,
+ Awaitable[locations_pb2.ListLocationsResponse],
+ ],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def kind(self) -> str:
+ raise NotImplementedError()
+
__all__ = ("EndpointServiceTransport",)
diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py
index d81853d560..f6de887c2a 100644
--- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py
+++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,9 +16,9 @@
import warnings
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
-from google.api_core import grpc_helpers # type: ignore
-from google.api_core import operations_v1 # type: ignore
-from google.api_core import gapic_v1 # type: ignore
+from google.api_core import grpc_helpers
+from google.api_core import operations_v1
+from google.api_core import gapic_v1
import google.auth # type: ignore
from google.auth import credentials as ga_credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
@@ -28,6 +28,10 @@
from google.cloud.aiplatform_v1.types import endpoint
from google.cloud.aiplatform_v1.types import endpoint as gca_endpoint
from google.cloud.aiplatform_v1.types import endpoint_service
+from google.cloud.location import locations_pb2 # type: ignore
+from google.iam.v1 import iam_policy_pb2 # type: ignore
+from google.iam.v1 import policy_pb2 # type: ignore
from google.longrunning import operations_pb2 # type: ignore
from .base import EndpointServiceTransport, DEFAULT_CLIENT_INFO
@@ -35,6 +39,8 @@
class EndpointServiceGrpcTransport(EndpointServiceTransport):
"""gRPC backend transport for EndpointService.
+ A service for managing Vertex AI's Endpoints.
+
This class defines the same methods as the primary client, so the
primary client can load the underlying transport implementation
and call it.
@@ -59,6 +65,7 @@ def __init__(
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ always_use_jwt_access: Optional[bool] = False,
) -> None:
"""Instantiate the transport.
@@ -81,16 +88,16 @@ def __init__(
api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
If provided, it overrides the ``host`` argument and tries to create
a mutual TLS channel with client SSL credentials from
- ``client_cert_source`` or applicatin default SSL credentials.
+ ``client_cert_source`` or application default SSL credentials.
client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
Deprecated. A callback to provide client SSL certificate bytes and
private key bytes, both in PEM format. It is ignored if
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
- for grpc channel. It is ignored if ``channel`` is provided.
+ for the grpc channel. It is ignored if ``channel`` is provided.
client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
A callback to provide client certificate bytes and private key bytes,
- both in PEM format. It is used to configure mutual TLS channel. It is
+ both in PEM format. It is used to configure a mutual TLS channel. It is
ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
@@ -99,6 +106,8 @@ def __init__(
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
+ always_use_jwt_access (Optional[bool]): Whether a self-signed JWT should
+ be used for service account credentials.
Raises:
google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
@@ -109,7 +118,7 @@ def __init__(
self._grpc_channel = None
self._ssl_channel_credentials = ssl_channel_credentials
self._stubs: Dict[str, Callable] = {}
- self._operations_client = None
+ self._operations_client: Optional[operations_v1.OperationsClient] = None
if api_mtls_endpoint:
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
@@ -152,13 +161,17 @@ def __init__(
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
+ always_use_jwt_access=always_use_jwt_access,
)
if not self._grpc_channel:
self._grpc_channel = type(self).create_channel(
self._host,
+ # use the credentials which are saved
credentials=self._credentials,
- credentials_file=credentials_file,
+ # Set ``credentials_file`` to ``None`` here as
+ # the credentials that we saved earlier should be used.
+ credentials_file=None,
scopes=self._scopes,
ssl_credentials=self._ssl_channel_credentials,
quota_project_id=quota_project_id,
@@ -207,21 +220,20 @@ def create_channel(
and ``credentials_file`` are passed.
"""
- self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes)
-
return grpc_helpers.create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
quota_project_id=quota_project_id,
- **self_signed_jwt_kwargs,
+ default_scopes=cls.AUTH_SCOPES,
+ scopes=scopes,
+ default_host=cls.DEFAULT_HOST,
**kwargs,
)
@property
def grpc_channel(self) -> grpc.Channel:
- """Return the channel designed to connect to this service.
- """
+ """Return the channel designed to connect to this service."""
return self._grpc_channel
@property
@@ -231,7 +243,7 @@ def operations_client(self) -> operations_v1.OperationsClient:
This property caches on the instance; repeated calls return the same
client.
"""
- # Sanity check: Only create a new client if we do not already have one.
+ # Quick check: Only create a new client if we do not already have one.
if self._operations_client is None:
self._operations_client = operations_v1.OperationsClient(self.grpc_channel)
@@ -425,5 +437,215 @@ def undeploy_model(
)
return self._stubs["undeploy_model"]
+ def close(self):
+ self.grpc_channel.close()
+
+ @property
+ def delete_operation(
+ self,
+ ) -> Callable[[operations_pb2.DeleteOperationRequest], None]:
+ r"""Return a callable for the delete_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "delete_operation" not in self._stubs:
+ self._stubs["delete_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/DeleteOperation",
+ request_serializer=operations_pb2.DeleteOperationRequest.SerializeToString,
+ response_deserializer=None,
+ )
+ return self._stubs["delete_operation"]
+
+ @property
+ def cancel_operation(
+ self,
+ ) -> Callable[[operations_pb2.CancelOperationRequest], None]:
+ r"""Return a callable for the cancel_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "cancel_operation" not in self._stubs:
+ self._stubs["cancel_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/CancelOperation",
+ request_serializer=operations_pb2.CancelOperationRequest.SerializeToString,
+ response_deserializer=None,
+ )
+ return self._stubs["cancel_operation"]
+
+ @property
+ def wait_operation(
+ self,
+ ) -> Callable[[operations_pb2.WaitOperationRequest], operations_pb2.Operation]:
+ r"""Return a callable for the wait_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "delete_operation" not in self._stubs:
+ self._stubs["wait_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/WaitOperation",
+ request_serializer=operations_pb2.WaitOperationRequest.SerializeToString,
+ response_deserializer=operations_pb2.Operation.FromString,
+ )
+ return self._stubs["wait_operation"]
+
+ @property
+ def get_operation(
+ self,
+ ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]:
+ r"""Return a callable for the get_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_operation" not in self._stubs:
+ self._stubs["get_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/GetOperation",
+ request_serializer=operations_pb2.GetOperationRequest.SerializeToString,
+ response_deserializer=operations_pb2.Operation.FromString,
+ )
+ return self._stubs["get_operation"]
+
+ @property
+ def list_operations(
+ self,
+ ) -> Callable[
+ [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse
+ ]:
+ r"""Return a callable for the list_operations method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "list_operations" not in self._stubs:
+ self._stubs["list_operations"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/ListOperations",
+ request_serializer=operations_pb2.ListOperationsRequest.SerializeToString,
+ response_deserializer=operations_pb2.ListOperationsResponse.FromString,
+ )
+ return self._stubs["list_operations"]
+
+ @property
+ def list_locations(
+ self,
+ ) -> Callable[
+ [locations_pb2.ListLocationsRequest], locations_pb2.ListLocationsResponse
+ ]:
+ r"""Return a callable for the list locations method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "list_locations" not in self._stubs:
+ self._stubs["list_locations"] = self.grpc_channel.unary_unary(
+ "/google.cloud.location.Locations/ListLocations",
+ request_serializer=locations_pb2.ListLocationsRequest.SerializeToString,
+ response_deserializer=locations_pb2.ListLocationsResponse.FromString,
+ )
+ return self._stubs["list_locations"]
+
+ @property
+ def get_location(
+ self,
+ ) -> Callable[[locations_pb2.GetLocationRequest], locations_pb2.Location]:
+ r"""Return a callable for the list locations method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_location" not in self._stubs:
+ self._stubs["get_location"] = self.grpc_channel.unary_unary(
+ "/google.cloud.location.Locations/GetLocation",
+ request_serializer=locations_pb2.GetLocationRequest.SerializeToString,
+ response_deserializer=locations_pb2.Location.FromString,
+ )
+ return self._stubs["get_location"]
+
+ @property
+ def set_iam_policy(
+ self,
+ ) -> Callable[[iam_policy_pb2.SetIamPolicyRequest], policy_pb2.Policy]:
+ r"""Return a callable for the set iam policy method over gRPC.
+ Sets the IAM access control policy on the specified
+ resource. Replaces any existing policy.
+ Returns:
+ Callable[[~.SetIamPolicyRequest],
+ ~.Policy]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "set_iam_policy" not in self._stubs:
+ self._stubs["set_iam_policy"] = self.grpc_channel.unary_unary(
+ "/google.iam.v1.IAMPolicy/SetIamPolicy",
+ request_serializer=iam_policy_pb2.SetIamPolicyRequest.SerializeToString,
+ response_deserializer=policy_pb2.Policy.FromString,
+ )
+ return self._stubs["set_iam_policy"]
+
+ @property
+ def get_iam_policy(
+ self,
+ ) -> Callable[[iam_policy_pb2.GetIamPolicyRequest], policy_pb2.Policy]:
+ r"""Return a callable for the get iam policy method over gRPC.
+ Gets the IAM access control policy for a function.
+ Returns an empty policy if the function exists and does
+ not have a policy set.
+ Returns:
+ Callable[[~.GetIamPolicyRequest],
+ ~.Policy]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_iam_policy" not in self._stubs:
+ self._stubs["get_iam_policy"] = self.grpc_channel.unary_unary(
+ "/google.iam.v1.IAMPolicy/GetIamPolicy",
+ request_serializer=iam_policy_pb2.GetIamPolicyRequest.SerializeToString,
+ response_deserializer=policy_pb2.Policy.FromString,
+ )
+ return self._stubs["get_iam_policy"]
+
+ @property
+ def test_iam_permissions(
+ self,
+ ) -> Callable[
+ [iam_policy_pb2.TestIamPermissionsRequest],
+ iam_policy_pb2.TestIamPermissionsResponse,
+ ]:
+ r"""Return a callable for the test iam permissions method over gRPC.
+ Tests the specified permissions against the IAM access control
+ policy for a function. If the function does not exist, this will
+ return an empty set of permissions, not a NOT_FOUND error.
+ Returns:
+ Callable[[~.TestIamPermissionsRequest],
+ ~.TestIamPermissionsResponse]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "test_iam_permissions" not in self._stubs:
+ self._stubs["test_iam_permissions"] = self.grpc_channel.unary_unary(
+ "/google.iam.v1.IAMPolicy/TestIamPermissions",
+ request_serializer=iam_policy_pb2.TestIamPermissionsRequest.SerializeToString,
+ response_deserializer=iam_policy_pb2.TestIamPermissionsResponse.FromString,
+ )
+ return self._stubs["test_iam_permissions"]
+
+ @property
+ def kind(self) -> str:
+ return "grpc"
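+
+    # Editorial note (a sketch, not generated code): each property above
+    # returns a raw gRPC callable that is created once and cached in
+    # ``self._stubs``. A caller holding a transport instance could invoke one
+    # directly, for example (the operation name is a placeholder):
+    #
+    #   op = transport.get_operation(
+    #       operations_pb2.GetOperationRequest(name="projects/p/operations/123")
+    #   )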
+
__all__ = ("EndpointServiceGrpcTransport",)
diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py
index 41f295e135..f2c05e7bbc 100644
--- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py
+++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 Google LLC
+# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,12 +16,11 @@
import warnings
from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union
-from google.api_core import gapic_v1 # type: ignore
-from google.api_core import grpc_helpers_async # type: ignore
-from google.api_core import operations_v1 # type: ignore
+from google.api_core import gapic_v1
+from google.api_core import grpc_helpers_async
+from google.api_core import operations_v1
from google.auth import credentials as ga_credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
-import packaging.version
import grpc # type: ignore
from grpc.experimental import aio # type: ignore
@@ -29,6 +28,10 @@
from google.cloud.aiplatform_v1.types import endpoint
from google.cloud.aiplatform_v1.types import endpoint as gca_endpoint
from google.cloud.aiplatform_v1.types import endpoint_service
+from google.cloud.location import locations_pb2 # type: ignore
+from google.iam.v1 import iam_policy_pb2 # type: ignore
+from google.iam.v1 import policy_pb2 # type: ignore
from google.longrunning import operations_pb2 # type: ignore
from .base import EndpointServiceTransport, DEFAULT_CLIENT_INFO
from .grpc import EndpointServiceGrpcTransport
@@ -37,6 +40,8 @@
class EndpointServiceGrpcAsyncIOTransport(EndpointServiceTransport):
"""gRPC AsyncIO backend transport for EndpointService.
+ A service for managing Vertex AI's Endpoints.
+
This class defines the same methods as the primary client, so the
primary client can load the underlying transport implementation
and call it.
@@ -80,14 +85,14 @@ def create_channel(
aio.Channel: A gRPC AsyncIO channel object.
"""
- self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes)
-
return grpc_helpers_async.create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
quota_project_id=quota_project_id,
- **self_signed_jwt_kwargs,
+ default_scopes=cls.AUTH_SCOPES,
+ scopes=scopes,
+ default_host=cls.DEFAULT_HOST,
**kwargs,
)
@@ -105,6 +110,7 @@ def __init__(
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id=None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ always_use_jwt_access: Optional[bool] = False,
) -> None:
"""Instantiate the transport.
@@ -128,16 +134,16 @@ def __init__(
api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
If provided, it overrides the ``host`` argument and tries to create
a mutual TLS channel with client SSL credentials from
- ``client_cert_source`` or applicatin default SSL credentials.
+ ``client_cert_source`` or application default SSL credentials.
client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
Deprecated. A callback to provide client SSL certificate bytes and
private key bytes, both in PEM format. It is ignored if
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
- for grpc channel. It is ignored if ``channel`` is provided.
+ for the grpc channel. It is ignored if ``channel`` is provided.
client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
A callback to provide client certificate bytes and private key bytes,
- both in PEM format. It is used to configure mutual TLS channel. It is
+ both in PEM format. It is used to configure a mutual TLS channel. It is
ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
@@ -146,6 +152,8 @@ def __init__(
API requests. If ``None``, then default info will be used.
Generally, you only need to set this if you're developing
your own client library.
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+ be used for service account credentials.
Raises:
google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
@@ -156,7 +164,7 @@ def __init__(
self._grpc_channel = None
self._ssl_channel_credentials = ssl_channel_credentials
self._stubs: Dict[str, Callable] = {}
- self._operations_client = None
+ self._operations_client: Optional[operations_v1.OperationsAsyncClient] = None
if api_mtls_endpoint:
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
@@ -198,13 +206,17 @@ def __init__(
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
+ always_use_jwt_access=always_use_jwt_access,
)
if not self._grpc_channel:
self._grpc_channel = type(self).create_channel(
self._host,
+ # use the credentials which are saved
credentials=self._credentials,
- credentials_file=credentials_file,
+ # Set ``credentials_file`` to ``None`` here as
+ # the credentials that we saved earlier should be used.
+ credentials_file=None,
scopes=self._scopes,
ssl_credentials=self._ssl_channel_credentials,
quota_project_id=quota_project_id,
@@ -234,7 +246,7 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient:
This property caches on the instance; repeated calls return the same
client.
"""
- # Sanity check: Only create a new client if we do not already have one.
+ # Quick check: Only create a new client if we do not already have one.
if self._operations_client is None:
self._operations_client = operations_v1.OperationsAsyncClient(
self.grpc_channel
@@ -441,5 +453,211 @@ def undeploy_model(
)
return self._stubs["undeploy_model"]
+ def close(self):
+ return self.grpc_channel.close()
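+
+    # Editorial note: for this AsyncIO transport ``grpc_channel.close()``
+    # returns a coroutine, so callers are expected to ``await
+    # transport.close()``; the async client's ``__aexit__`` does exactly that.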
+
+ @property
+ def delete_operation(
+ self,
+ ) -> Callable[[operations_pb2.DeleteOperationRequest], None]:
+ r"""Return a callable for the delete_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "delete_operation" not in self._stubs:
+ self._stubs["delete_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/DeleteOperation",
+ request_serializer=operations_pb2.DeleteOperationRequest.SerializeToString,
+ response_deserializer=None,
+ )
+ return self._stubs["delete_operation"]
+
+ @property
+ def cancel_operation(
+ self,
+ ) -> Callable[[operations_pb2.CancelOperationRequest], None]:
+ r"""Return a callable for the cancel_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "cancel_operation" not in self._stubs:
+ self._stubs["cancel_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/CancelOperation",
+ request_serializer=operations_pb2.CancelOperationRequest.SerializeToString,
+ response_deserializer=None,
+ )
+ return self._stubs["cancel_operation"]
+
+ @property
+ def wait_operation(
+ self,
+ ) -> Callable[[operations_pb2.WaitOperationRequest], None]:
+ r"""Return a callable for the wait_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+        if "wait_operation" not in self._stubs:
+ self._stubs["wait_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/WaitOperation",
+ request_serializer=operations_pb2.WaitOperationRequest.SerializeToString,
+ response_deserializer=None,
+ )
+ return self._stubs["wait_operation"]
+
+ @property
+ def get_operation(
+ self,
+ ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]:
+ r"""Return a callable for the get_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_operation" not in self._stubs:
+ self._stubs["get_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/GetOperation",
+ request_serializer=operations_pb2.GetOperationRequest.SerializeToString,
+ response_deserializer=operations_pb2.Operation.FromString,
+ )
+ return self._stubs["get_operation"]
+
+ @property
+ def list_operations(
+ self,
+ ) -> Callable[
+ [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse
+ ]:
+ r"""Return a callable for the list_operations method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "list_operations" not in self._stubs:
+ self._stubs["list_operations"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/ListOperations",
+ request_serializer=operations_pb2.ListOperationsRequest.SerializeToString,
+ response_deserializer=operations_pb2.ListOperationsResponse.FromString,
+ )
+ return self._stubs["list_operations"]
+
+ @property
+ def list_locations(
+ self,
+ ) -> Callable[
+ [locations_pb2.ListLocationsRequest], locations_pb2.ListLocationsResponse
+ ]:
+ r"""Return a callable for the list locations method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "list_locations" not in self._stubs:
+ self._stubs["list_locations"] = self.grpc_channel.unary_unary(
+ "/google.cloud.location.Locations/ListLocations",
+ request_serializer=locations_pb2.ListLocationsRequest.SerializeToString,
+ response_deserializer=locations_pb2.ListLocationsResponse.FromString,
+ )
+ return self._stubs["list_locations"]
+
+ @property
+ def get_location(
+ self,
+ ) -> Callable[[locations_pb2.GetLocationRequest], locations_pb2.Location]:
+        r"""Return a callable for the get location method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_location" not in self._stubs:
+ self._stubs["get_location"] = self.grpc_channel.unary_unary(
+ "/google.cloud.location.Locations/GetLocation",
+ request_serializer=locations_pb2.GetLocationRequest.SerializeToString,
+ response_deserializer=locations_pb2.Location.FromString,
+ )
+ return self._stubs["get_location"]
+
+ @property
+ def set_iam_policy(
+ self,
+ ) -> Callable[[iam_policy_pb2.SetIamPolicyRequest], policy_pb2.Policy]:
+ r"""Return a callable for the set iam policy method over gRPC.
+ Sets the IAM access control policy on the specified
+ function. Replaces any existing policy.
+ Returns:
+ Callable[[~.SetIamPolicyRequest],
+ ~.Policy]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "set_iam_policy" not in self._stubs:
+ self._stubs["set_iam_policy"] = self.grpc_channel.unary_unary(
+ "/google.iam.v1.IAMPolicy/SetIamPolicy",
+ request_serializer=iam_policy_pb2.SetIamPolicyRequest.SerializeToString,
+ response_deserializer=policy_pb2.Policy.FromString,
+ )
+ return self._stubs["set_iam_policy"]
+
+ @property
+ def get_iam_policy(
+ self,
+ ) -> Callable[[iam_policy_pb2.GetIamPolicyRequest], policy_pb2.Policy]:
+ r"""Return a callable for the get iam policy method over gRPC.
+ Gets the IAM access control policy for a function.
+ Returns an empty policy if the function exists and does
+ not have a policy set.
+ Returns:
+ Callable[[~.GetIamPolicyRequest],
+ ~.Policy]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_iam_policy" not in self._stubs:
+ self._stubs["get_iam_policy"] = self.grpc_channel.unary_unary(
+ "/google.iam.v1.IAMPolicy/GetIamPolicy",
+ request_serializer=iam_policy_pb2.GetIamPolicyRequest.SerializeToString,
+ response_deserializer=policy_pb2.Policy.FromString,
+ )
+ return self._stubs["get_iam_policy"]
+
+ @property
+ def test_iam_permissions(
+ self,
+ ) -> Callable[
+ [iam_policy_pb2.TestIamPermissionsRequest],
+ iam_policy_pb2.TestIamPermissionsResponse,
+ ]:
+ r"""Return a callable for the test iam permissions method over gRPC.
+ Tests the specified permissions against the IAM access control
+ policy for a function. If the function does not exist, this will
+ return an empty set of permissions, not a NOT_FOUND error.
+ Returns:
+ Callable[[~.TestIamPermissionsRequest],
+ ~.TestIamPermissionsResponse]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "test_iam_permissions" not in self._stubs:
+ self._stubs["test_iam_permissions"] = self.grpc_channel.unary_unary(
+ "/google.iam.v1.IAMPolicy/TestIamPermissions",
+ request_serializer=iam_policy_pb2.TestIamPermissionsRequest.SerializeToString,
+ response_deserializer=iam_policy_pb2.TestIamPermissionsResponse.FromString,
+ )
+ return self._stubs["test_iam_permissions"]
+
__all__ = ("EndpointServiceGrpcAsyncIOTransport",)
diff --git a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/__init__.py b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/__init__.py
new file mode 100644
index 0000000000..a2f6b5fa66
--- /dev/null
+++ b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/__init__.py
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from .client import FeaturestoreOnlineServingServiceClient
+from .async_client import FeaturestoreOnlineServingServiceAsyncClient
+
+__all__ = (
+ "FeaturestoreOnlineServingServiceClient",
+ "FeaturestoreOnlineServingServiceAsyncClient",
+)
diff --git a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/async_client.py b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/async_client.py
new file mode 100644
index 0000000000..37cd57510a
--- /dev/null
+++ b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/async_client.py
@@ -0,0 +1,1148 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from collections import OrderedDict
+import functools
+import re
+from typing import (
+ Dict,
+ Mapping,
+ Optional,
+ AsyncIterable,
+ Awaitable,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+)
+import pkg_resources
+
+from google.api_core.client_options import ClientOptions
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core import retry as retries
+from google.auth import credentials as ga_credentials # type: ignore
+from google.oauth2 import service_account # type: ignore
+
+try:
+ OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
+except AttributeError: # pragma: NO COVER
+ OptionalRetry = Union[retries.Retry, object] # type: ignore
+
+from google.cloud.aiplatform_v1.types import featurestore_online_service
+from google.cloud.location import locations_pb2 # type: ignore
+from google.iam.v1 import iam_policy_pb2 # type: ignore
+from google.iam.v1 import policy_pb2 # type: ignore
+from google.longrunning import operations_pb2
+from .transports.base import (
+ FeaturestoreOnlineServingServiceTransport,
+ DEFAULT_CLIENT_INFO,
+)
+from .transports.grpc_asyncio import (
+ FeaturestoreOnlineServingServiceGrpcAsyncIOTransport,
+)
+from .client import FeaturestoreOnlineServingServiceClient
+
+
+class FeaturestoreOnlineServingServiceAsyncClient:
+ """A service for serving online feature values."""
+
+ _client: FeaturestoreOnlineServingServiceClient
+
+ DEFAULT_ENDPOINT = FeaturestoreOnlineServingServiceClient.DEFAULT_ENDPOINT
+ DEFAULT_MTLS_ENDPOINT = FeaturestoreOnlineServingServiceClient.DEFAULT_MTLS_ENDPOINT
+
+ entity_type_path = staticmethod(
+ FeaturestoreOnlineServingServiceClient.entity_type_path
+ )
+ parse_entity_type_path = staticmethod(
+ FeaturestoreOnlineServingServiceClient.parse_entity_type_path
+ )
+ common_billing_account_path = staticmethod(
+ FeaturestoreOnlineServingServiceClient.common_billing_account_path
+ )
+ parse_common_billing_account_path = staticmethod(
+ FeaturestoreOnlineServingServiceClient.parse_common_billing_account_path
+ )
+ common_folder_path = staticmethod(
+ FeaturestoreOnlineServingServiceClient.common_folder_path
+ )
+ parse_common_folder_path = staticmethod(
+ FeaturestoreOnlineServingServiceClient.parse_common_folder_path
+ )
+ common_organization_path = staticmethod(
+ FeaturestoreOnlineServingServiceClient.common_organization_path
+ )
+ parse_common_organization_path = staticmethod(
+ FeaturestoreOnlineServingServiceClient.parse_common_organization_path
+ )
+ common_project_path = staticmethod(
+ FeaturestoreOnlineServingServiceClient.common_project_path
+ )
+ parse_common_project_path = staticmethod(
+ FeaturestoreOnlineServingServiceClient.parse_common_project_path
+ )
+ common_location_path = staticmethod(
+ FeaturestoreOnlineServingServiceClient.common_location_path
+ )
+ parse_common_location_path = staticmethod(
+ FeaturestoreOnlineServingServiceClient.parse_common_location_path
+ )
+
+ @classmethod
+ def from_service_account_info(cls, info: dict, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials
+ info.
+
+ Args:
+ info (dict): The service account private key info.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ FeaturestoreOnlineServingServiceAsyncClient: The constructed client.
+ """
+ return FeaturestoreOnlineServingServiceClient.from_service_account_info.__func__(FeaturestoreOnlineServingServiceAsyncClient, info, *args, **kwargs) # type: ignore
+
+ @classmethod
+ def from_service_account_file(cls, filename: str, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials
+ file.
+
+ Args:
+ filename (str): The path to the service account private key json
+ file.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ FeaturestoreOnlineServingServiceAsyncClient: The constructed client.
+ """
+ return FeaturestoreOnlineServingServiceClient.from_service_account_file.__func__(FeaturestoreOnlineServingServiceAsyncClient, filename, *args, **kwargs) # type: ignore
+
+ from_service_account_json = from_service_account_file
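+
+    # Editorial sketch (the path below is a placeholder): the constructors
+    # above are interchangeable ways to build a client from a key file, e.g.:
+    #
+    #   client = FeaturestoreOnlineServingServiceAsyncClient.from_service_account_file(
+    #       "/path/to/service-account.json"
+    #   )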
+
+ @classmethod
+ def get_mtls_endpoint_and_cert_source(
+ cls, client_options: Optional[ClientOptions] = None
+ ):
+ """Return the API endpoint and client cert source for mutual TLS.
+
+ The client cert source is determined in the following order:
+ (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
+ client cert source is None.
+ (2) if `client_options.client_cert_source` is provided, use the provided one; if the
+ default client cert source exists, use the default one; otherwise the client cert
+ source is None.
+
+ The API endpoint is determined in the following order:
+        (1) if `client_options.api_endpoint` is provided, use the provided one.
+        (2) if the `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the
+        default mTLS endpoint; if the environment variable is "never", use the default API
+        endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
+        use the default API endpoint.
+
+ More details can be found at https://google.aip.dev/auth/4114.
+
+ Args:
+ client_options (google.api_core.client_options.ClientOptions): Custom options for the
+ client. Only the `api_endpoint` and `client_cert_source` properties may be used
+ in this method.
+
+ Returns:
+ Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
+ client cert source to use.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If any errors happen.
+ """
+ return FeaturestoreOnlineServingServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore
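+
+    # Editorial sketch: a caller could resolve the effective endpoint and
+    # client cert source up front; the result depends on the
+    # GOOGLE_API_USE_CLIENT_CERTIFICATE and GOOGLE_API_USE_MTLS_ENDPOINT
+    # environment variables described above:
+    #
+    #   endpoint, cert_source = (
+    #       FeaturestoreOnlineServingServiceAsyncClient
+    #       .get_mtls_endpoint_and_cert_source(ClientOptions())
+    #   )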
+
+ @property
+ def transport(self) -> FeaturestoreOnlineServingServiceTransport:
+ """Returns the transport used by the client instance.
+
+ Returns:
+ FeaturestoreOnlineServingServiceTransport: The transport used by the client instance.
+ """
+ return self._client.transport
+
+ get_transport_class = functools.partial(
+ type(FeaturestoreOnlineServingServiceClient).get_transport_class,
+ type(FeaturestoreOnlineServingServiceClient),
+ )
+
+ def __init__(
+ self,
+ *,
+ credentials: ga_credentials.Credentials = None,
+ transport: Union[
+ str, FeaturestoreOnlineServingServiceTransport
+ ] = "grpc_asyncio",
+ client_options: ClientOptions = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ ) -> None:
+ """Instantiates the featurestore online serving service client.
+
+ Args:
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ transport (Union[str, ~.FeaturestoreOnlineServingServiceTransport]): The
+ transport to use. If set to None, a transport is chosen
+ automatically.
+ client_options (ClientOptions): Custom options for the client. It
+ won't take effect if a ``transport`` instance is provided.
+ (1) The ``api_endpoint`` property can be used to override the
+ default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+ environment variable can also be used to override the endpoint:
+ "always" (always use the default mTLS endpoint), "never" (always
+ use the default regular endpoint) and "auto" (auto switch to the
+ default mTLS endpoint if client certificate is present, this is
+ the default value). However, the ``api_endpoint`` property takes
+ precedence if provided.
+ (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+ is "true", then the ``client_cert_source`` property can be used
+ to provide client certificate for mutual TLS transport. If
+ not provided, the default SSL client certificate will be used if
+ present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+ set, no client certificate will be used.
+
+ Raises:
+ google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
+ creation failed for any reason.
+ """
+ self._client = FeaturestoreOnlineServingServiceClient(
+ credentials=credentials,
+ transport=transport,
+ client_options=client_options,
+ client_info=client_info,
+ )
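+
+    # Editorial sketch: a minimal construction against a regional endpoint.
+    # The endpoint string is an assumption for illustration; credentials are
+    # resolved from the environment when not passed explicitly:
+    #
+    #   client = FeaturestoreOnlineServingServiceAsyncClient(
+    #       client_options=ClientOptions(
+    #           api_endpoint="us-central1-aiplatform.googleapis.com"
+    #       )
+    #   )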
+
+ async def read_feature_values(
+ self,
+ request: Union[
+ featurestore_online_service.ReadFeatureValuesRequest, dict
+ ] = None,
+ *,
+ entity_type: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> featurestore_online_service.ReadFeatureValuesResponse:
+ r"""Reads Feature values of a specific entity of an
+ EntityType. For reading feature values of multiple
+ entities of an EntityType, please use
+ StreamingReadFeatureValues.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_read_feature_values():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreOnlineServingServiceAsyncClient()
+
+ # Initialize request argument(s)
+ feature_selector = aiplatform_v1.FeatureSelector()
+ feature_selector.id_matcher.ids = ['ids_value_1', 'ids_value_2']
+
+ request = aiplatform_v1.ReadFeatureValuesRequest(
+ entity_type="entity_type_value",
+ entity_id="entity_id_value",
+ feature_selector=feature_selector,
+ )
+
+ # Make the request
+ response = await client.read_feature_values(request=request)
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.ReadFeatureValuesRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreOnlineServingService.ReadFeatureValues][google.cloud.aiplatform.v1.FeaturestoreOnlineServingService.ReadFeatureValues].
+ entity_type (:class:`str`):
+ Required. The resource name of the EntityType for the
+ entity being read. Value format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entityType}``.
+ For example, for a machine learning model predicting
+ user clicks on a website, an EntityType ID could be
+ ``user``.
+
+ This corresponds to the ``entity_type`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.aiplatform_v1.types.ReadFeatureValuesResponse:
+ Response message for
+ [FeaturestoreOnlineServingService.ReadFeatureValues][google.cloud.aiplatform.v1.FeaturestoreOnlineServingService.ReadFeatureValues].
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([entity_type])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_online_service.ReadFeatureValuesRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if entity_type is not None:
+ request.entity_type = entity_type
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.read_feature_values,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata(
+ (("entity_type", request.entity_type),)
+ ),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def streaming_read_feature_values(
+ self,
+ request: Union[
+ featurestore_online_service.StreamingReadFeatureValuesRequest, dict
+ ] = None,
+ *,
+ entity_type: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> Awaitable[
+ AsyncIterable[featurestore_online_service.ReadFeatureValuesResponse]
+ ]:
+ r"""Reads Feature values for multiple entities. Depending
+ on their size, data for different entities may be broken
+ up across multiple responses.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_streaming_read_feature_values():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreOnlineServingServiceAsyncClient()
+
+ # Initialize request argument(s)
+ feature_selector = aiplatform_v1.FeatureSelector()
+ feature_selector.id_matcher.ids = ['ids_value_1', 'ids_value_2']
+
+ request = aiplatform_v1.StreamingReadFeatureValuesRequest(
+ entity_type="entity_type_value",
+ entity_ids=['entity_ids_value_1', 'entity_ids_value_2'],
+ feature_selector=feature_selector,
+ )
+
+ # Make the request
+ stream = await client.streaming_read_feature_values(request=request)
+
+ # Handle the response
+ async for response in stream:
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.StreamingReadFeatureValuesRequest, dict]):
+ The request object. Request message for
+                [FeaturestoreOnlineServingService.StreamingReadFeatureValues][].
+ entity_type (:class:`str`):
+ Required. The resource name of the entities' type. Value
+ format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entityType}``.
+ For example, for a machine learning model predicting
+ user clicks on a website, an EntityType ID could be
+ ``user``.
+
+ This corresponds to the ``entity_type`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ AsyncIterable[google.cloud.aiplatform_v1.types.ReadFeatureValuesResponse]:
+ Response message for
+ [FeaturestoreOnlineServingService.ReadFeatureValues][google.cloud.aiplatform.v1.FeaturestoreOnlineServingService.ReadFeatureValues].
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([entity_type])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_online_service.StreamingReadFeatureValuesRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if entity_type is not None:
+ request.entity_type = entity_type
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.streaming_read_feature_values,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata(
+ (("entity_type", request.entity_type),)
+ ),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def list_operations(
+ self,
+ request: operations_pb2.ListOperationsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.ListOperationsResponse:
+ r"""Lists operations that match the specified filter in the request.
+
+ Args:
+ request (:class:`~.operations_pb2.ListOperationsRequest`):
+ The request object. Request message for
+ `ListOperations` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.ListOperationsResponse:
+ Response message for ``ListOperations`` method.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.ListOperationsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.list_operations,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def get_operation(
+ self,
+ request: operations_pb2.GetOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.Operation:
+ r"""Gets the latest state of a long-running operation.
+
+ Args:
+ request (:class:`~.operations_pb2.GetOperationRequest`):
+ The request object. Request message for
+ `GetOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.Operation:
+ An ``Operation`` object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.GetOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.get_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def delete_operation(
+ self,
+ request: operations_pb2.DeleteOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> None:
+ r"""Deletes a long-running operation.
+
+ This method indicates that the client is no longer interested
+ in the operation result. It does not cancel the operation.
+ If the server doesn't support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
+ Args:
+ request (:class:`~.operations_pb2.DeleteOperationRequest`):
+ The request object. Request message for
+ `DeleteOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ None
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.DeleteOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.delete_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ async def cancel_operation(
+ self,
+ request: operations_pb2.CancelOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> None:
+ r"""Starts asynchronous cancellation on a long-running operation.
+
+ The server makes a best effort to cancel the operation, but success
+ is not guaranteed. If the server doesn't support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
+ Args:
+ request (:class:`~.operations_pb2.CancelOperationRequest`):
+ The request object. Request message for
+ `CancelOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ None
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.CancelOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.cancel_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ async def wait_operation(
+ self,
+ request: operations_pb2.WaitOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.Operation:
+ r"""Waits until the specified long-running operation is done or reaches at most
+ a specified timeout, returning the latest state.
+
+ If the operation is already done, the latest state is immediately returned.
+ If the timeout specified is greater than the default HTTP/RPC timeout, the HTTP/RPC
+ timeout is used. If the server does not support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
+ Args:
+ request (:class:`~.operations_pb2.WaitOperationRequest`):
+ The request object. Request message for
+ `WaitOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.Operation:
+ An ``Operation`` object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.WaitOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.wait_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
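+
+    # Editorial sketch: these operations mixin methods accept either the raw
+    # protobuf request or a dict that is expanded into one, so a call could
+    # look like (the operation name is a placeholder):
+    #
+    #   op = await client.wait_operation(
+    #       {"name": "projects/p/locations/l/operations/123"}
+    #   )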
+
+ async def set_iam_policy(
+ self,
+ request: iam_policy_pb2.SetIamPolicyRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> policy_pb2.Policy:
+ r"""Sets the IAM access control policy on the specified function.
+
+ Replaces any existing policy.
+
+ Args:
+ request (:class:`~.iam_policy_pb2.SetIamPolicyRequest`):
+ The request object. Request message for `SetIamPolicy`
+ method.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.policy_pb2.Policy:
+ Defines an Identity and Access Management (IAM) policy.
+ It is used to specify access control policies for Cloud
+ Platform resources.
+ A ``Policy`` is a collection of ``bindings``. A
+ ``binding`` binds one or more ``members`` to a single
+ ``role``. Members can be user accounts, service
+ accounts, Google groups, and domains (such as G Suite).
+ A ``role`` is a named list of permissions (defined by
+ IAM or configured by users). A ``binding`` can
+ optionally specify a ``condition``, which is a logic
+ expression that further constrains the role binding
+ based on attributes about the request and/or target
+ resource.
+ **JSON Example**
+ ::
+ {
+ "bindings": [
+ {
+ "role": "roles/resourcemanager.organizationAdmin",
+ "members": [
+ "user:mike@example.com",
+ "group:admins@example.com",
+ "domain:google.com",
+ "serviceAccount:my-project-id@appspot.gserviceaccount.com"
+ ]
+ },
+ {
+ "role": "roles/resourcemanager.organizationViewer",
+ "members": ["user:eve@example.com"],
+ "condition": {
+ "title": "expirable access",
+ "description": "Does not grant access after Sep 2020",
+ "expression": "request.time <
+ timestamp('2020-10-01T00:00:00.000Z')",
+ }
+ }
+ ]
+ }
+ **YAML Example**
+ ::
+ bindings:
+ - members:
+ - user:mike@example.com
+ - group:admins@example.com
+ - domain:google.com
+ - serviceAccount:my-project-id@appspot.gserviceaccount.com
+ role: roles/resourcemanager.organizationAdmin
+ - members:
+ - user:eve@example.com
+ role: roles/resourcemanager.organizationViewer
+ condition:
+ title: expirable access
+ description: Does not grant access after Sep 2020
+ expression: request.time < timestamp('2020-10-01T00:00:00.000Z')
+ For a description of IAM and its features, see the `IAM
+ developer's
+                guide <https://cloud.google.com/iam/docs>`__.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.SetIamPolicyRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.set_iam_policy,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def get_iam_policy(
+ self,
+ request: iam_policy_pb2.GetIamPolicyRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> policy_pb2.Policy:
+ r"""Gets the IAM access control policy for a function.
+
+ Returns an empty policy if the function exists and does not have a
+ policy set.
+
+ Args:
+ request (:class:`~.iam_policy_pb2.GetIamPolicyRequest`):
+ The request object. Request message for `GetIamPolicy`
+ method.
+ retry (google.api_core.retry.Retry): Designation of what errors, if
+ any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.policy_pb2.Policy:
+ Defines an Identity and Access Management (IAM) policy.
+ It is used to specify access control policies for Cloud
+ Platform resources.
+ A ``Policy`` is a collection of ``bindings``. A
+ ``binding`` binds one or more ``members`` to a single
+ ``role``. Members can be user accounts, service
+ accounts, Google groups, and domains (such as G Suite).
+ A ``role`` is a named list of permissions (defined by
+ IAM or configured by users). A ``binding`` can
+ optionally specify a ``condition``, which is a logic
+ expression that further constrains the role binding
+ based on attributes about the request and/or target
+ resource.
+ **JSON Example**
+ ::
+ {
+ "bindings": [
+ {
+ "role": "roles/resourcemanager.organizationAdmin",
+ "members": [
+ "user:mike@example.com",
+ "group:admins@example.com",
+ "domain:google.com",
+ "serviceAccount:my-project-id@appspot.gserviceaccount.com"
+ ]
+ },
+ {
+ "role": "roles/resourcemanager.organizationViewer",
+ "members": ["user:eve@example.com"],
+ "condition": {
+ "title": "expirable access",
+ "description": "Does not grant access after Sep 2020",
+ "expression": "request.time <
+ timestamp('2020-10-01T00:00:00.000Z')",
+ }
+ }
+ ]
+ }
+ **YAML Example**
+ ::
+ bindings:
+ - members:
+ - user:mike@example.com
+ - group:admins@example.com
+ - domain:google.com
+ - serviceAccount:my-project-id@appspot.gserviceaccount.com
+ role: roles/resourcemanager.organizationAdmin
+ - members:
+ - user:eve@example.com
+ role: roles/resourcemanager.organizationViewer
+ condition:
+ title: expirable access
+ description: Does not grant access after Sep 2020
+ expression: request.time < timestamp('2020-10-01T00:00:00.000Z')
+ For a description of IAM and its features, see the `IAM
+ developer's
+                guide <https://cloud.google.com/iam/docs>`__.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.GetIamPolicyRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.get_iam_policy,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def test_iam_permissions(
+ self,
+ request: iam_policy_pb2.TestIamPermissionsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> iam_policy_pb2.TestIamPermissionsResponse:
+ r"""Tests the specified IAM permissions against the IAM access control
+ policy for a function.
+
+ If the function does not exist, this will return an empty set
+ of permissions, not a NOT_FOUND error.
+
+ Args:
+ request (:class:`~.iam_policy_pb2.TestIamPermissionsRequest`):
+ The request object. Request message for
+ `TestIamPermissions` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.iam_policy_pb2.TestIamPermissionsResponse:
+ Response message for ``TestIamPermissions`` method.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.TestIamPermissionsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.test_iam_permissions,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def get_location(
+ self,
+ request: locations_pb2.GetLocationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> locations_pb2.Location:
+ r"""Gets information about a location.
+
+ Args:
+            request (:class:`~.locations_pb2.GetLocationRequest`):
+ The request object. Request message for
+ `GetLocation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+            ~.locations_pb2.Location:
+ Location object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = locations_pb2.GetLocationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.get_location,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def list_locations(
+ self,
+ request: locations_pb2.ListLocationsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> locations_pb2.ListLocationsResponse:
+ r"""Lists information about the supported locations for this service.
+
+ Args:
+            request (:class:`~.locations_pb2.ListLocationsRequest`):
+ The request object. Request message for
+ `ListLocations` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+            ~.locations_pb2.ListLocationsResponse:
+ Response message for ``ListLocations`` method.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = locations_pb2.ListLocationsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.list_locations,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ await self.transport.close()
+
+
+try:
+ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+ gapic_version=pkg_resources.get_distribution(
+ "google-cloud-aiplatform",
+ ).version,
+ )
+except pkg_resources.DistributionNotFound:
+ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+
+__all__ = ("FeaturestoreOnlineServingServiceAsyncClient",)
diff --git a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/client.py b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/client.py
new file mode 100644
index 0000000000..4e64bc8a22
--- /dev/null
+++ b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/client.py
@@ -0,0 +1,1368 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from collections import OrderedDict
+import os
+import re
+from typing import Dict, Mapping, Optional, Iterable, Sequence, Tuple, Type, Union
+import pkg_resources
+
+from google.api_core import client_options as client_options_lib
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core import retry as retries
+from google.auth import credentials as ga_credentials # type: ignore
+from google.auth.transport import mtls # type: ignore
+from google.auth.transport.grpc import SslCredentials # type: ignore
+from google.auth.exceptions import MutualTLSChannelError # type: ignore
+from google.oauth2 import service_account # type: ignore
+
+try:
+ OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
+except AttributeError: # pragma: NO COVER
+ OptionalRetry = Union[retries.Retry, object] # type: ignore
+
+from google.cloud.aiplatform_v1.types import featurestore_online_service
+from google.cloud.location import locations_pb2 # type: ignore
+from google.iam.v1 import iam_policy_pb2 # type: ignore
+from google.iam.v1 import policy_pb2 # type: ignore
+from google.longrunning import operations_pb2
+from .transports.base import (
+ FeaturestoreOnlineServingServiceTransport,
+ DEFAULT_CLIENT_INFO,
+)
+from .transports.grpc import FeaturestoreOnlineServingServiceGrpcTransport
+from .transports.grpc_asyncio import (
+ FeaturestoreOnlineServingServiceGrpcAsyncIOTransport,
+)
+
+
+class FeaturestoreOnlineServingServiceClientMeta(type):
+ """Metaclass for the FeaturestoreOnlineServingService client.
+
+ This provides class-level methods for building and retrieving
+ support objects (e.g. transport) without polluting the client instance
+ objects.
+ """
+
+ _transport_registry = (
+ OrderedDict()
+ ) # type: Dict[str, Type[FeaturestoreOnlineServingServiceTransport]]
+ _transport_registry["grpc"] = FeaturestoreOnlineServingServiceGrpcTransport
+ _transport_registry[
+ "grpc_asyncio"
+ ] = FeaturestoreOnlineServingServiceGrpcAsyncIOTransport
+
+ def get_transport_class(
+ cls,
+ label: str = None,
+ ) -> Type[FeaturestoreOnlineServingServiceTransport]:
+ """Returns an appropriate transport class.
+
+ Args:
+ label: The name of the desired transport. If none is
+ provided, then the first transport in the registry is used.
+
+ Returns:
+ The transport class to use.
+ """
+ # If a specific transport is requested, return that one.
+ if label:
+ return cls._transport_registry[label]
+
+ # No transport is requested; return the default (that is, the first one
+ # in the dictionary).
+ return next(iter(cls._transport_registry.values()))
+
+
+class FeaturestoreOnlineServingServiceClient(
+ metaclass=FeaturestoreOnlineServingServiceClientMeta
+):
+ """A service for serving online feature values."""
+
+ @staticmethod
+ def _get_default_mtls_endpoint(api_endpoint):
+ """Converts api endpoint to mTLS endpoint.
+
+ Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+ "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+ Args:
+ api_endpoint (Optional[str]): the api endpoint to convert.
+ Returns:
+ str: converted mTLS api endpoint.
+ """
+ if not api_endpoint:
+ return api_endpoint
+
+ mtls_endpoint_re = re.compile(
+ r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?"
+ )
+
+ m = mtls_endpoint_re.match(api_endpoint)
+ name, mtls, sandbox, googledomain = m.groups()
+ if mtls or not googledomain:
+ return api_endpoint
+
+ if sandbox:
+ return api_endpoint.replace(
+ "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
+ )
+
+ return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+
+ DEFAULT_ENDPOINT = "aiplatform.googleapis.com"
+ DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
+ DEFAULT_ENDPOINT
+ )
+
+ @classmethod
+ def from_service_account_info(cls, info: dict, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials
+ info.
+
+ Args:
+ info (dict): The service account private key info.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ FeaturestoreOnlineServingServiceClient: The constructed client.
+ """
+ credentials = service_account.Credentials.from_service_account_info(info)
+ kwargs["credentials"] = credentials
+ return cls(*args, **kwargs)
+
+ @classmethod
+ def from_service_account_file(cls, filename: str, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials
+ file.
+
+ Args:
+ filename (str): The path to the service account private key json
+ file.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ FeaturestoreOnlineServingServiceClient: The constructed client.
+ """
+ credentials = service_account.Credentials.from_service_account_file(filename)
+ kwargs["credentials"] = credentials
+ return cls(*args, **kwargs)
+
+ from_service_account_json = from_service_account_file
+
+ @property
+ def transport(self) -> FeaturestoreOnlineServingServiceTransport:
+ """Returns the transport used by the client instance.
+
+ Returns:
+ FeaturestoreOnlineServingServiceTransport: The transport used by the client
+ instance.
+ """
+ return self._transport
+
+ @staticmethod
+ def entity_type_path(
+ project: str,
+ location: str,
+ featurestore: str,
+ entity_type: str,
+ ) -> str:
+ """Returns a fully-qualified entity_type string."""
+ return "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}".format(
+ project=project,
+ location=location,
+ featurestore=featurestore,
+ entity_type=entity_type,
+ )
+
+ @staticmethod
+ def parse_entity_type_path(path: str) -> Dict[str, str]:
+ """Parses a entity_type path into its component segments."""
+ m = re.match(
+ r"^projects/(?P.+?)/locations/(?P.+?)/featurestores/(?P.+?)/entityTypes/(?P.+?)$",
+ path,
+ )
+ return m.groupdict() if m else {}
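+
+ # Round-trip example for the two helpers above (illustrative comment only):
+ # entity_type_path("p", "l", "f", "e") returns
+ # "projects/p/locations/l/featurestores/f/entityTypes/e", and
+ # parse_entity_type_path on that string returns
+ # {"project": "p", "location": "l", "featurestore": "f", "entity_type": "e"}.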
+
+ @staticmethod
+ def common_billing_account_path(
+ billing_account: str,
+ ) -> str:
+ """Returns a fully-qualified billing_account string."""
+ return "billingAccounts/{billing_account}".format(
+ billing_account=billing_account,
+ )
+
+ @staticmethod
+ def parse_common_billing_account_path(path: str) -> Dict[str, str]:
+ """Parse a billing_account path into its component segments."""
+ m = re.match(r"^billingAccounts/(?P.+?)$", path)
+ return m.groupdict() if m else {}
+
+ @staticmethod
+ def common_folder_path(
+ folder: str,
+ ) -> str:
+ """Returns a fully-qualified folder string."""
+ return "folders/{folder}".format(
+ folder=folder,
+ )
+
+ @staticmethod
+ def parse_common_folder_path(path: str) -> Dict[str, str]:
+ """Parse a folder path into its component segments."""
+ m = re.match(r"^folders/(?P.+?)$", path)
+ return m.groupdict() if m else {}
+
+ @staticmethod
+ def common_organization_path(
+ organization: str,
+ ) -> str:
+ """Returns a fully-qualified organization string."""
+ return "organizations/{organization}".format(
+ organization=organization,
+ )
+
+ @staticmethod
+ def parse_common_organization_path(path: str) -> Dict[str, str]:
+ """Parse a organization path into its component segments."""
+ m = re.match(r"^organizations/(?P.+?)$", path)
+ return m.groupdict() if m else {}
+
+ @staticmethod
+ def common_project_path(
+ project: str,
+ ) -> str:
+ """Returns a fully-qualified project string."""
+ return "projects/{project}".format(
+ project=project,
+ )
+
+ @staticmethod
+ def parse_common_project_path(path: str) -> Dict[str, str]:
+ """Parse a project path into its component segments."""
+ m = re.match(r"^projects/(?P.+?)$", path)
+ return m.groupdict() if m else {}
+
+ @staticmethod
+ def common_location_path(
+ project: str,
+ location: str,
+ ) -> str:
+ """Returns a fully-qualified location string."""
+ return "projects/{project}/locations/{location}".format(
+ project=project,
+ location=location,
+ )
+
+ @staticmethod
+ def parse_common_location_path(path: str) -> Dict[str, str]:
+ """Parse a location path into its component segments."""
+ m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path)
+ return m.groupdict() if m else {}
+
+ @classmethod
+ def get_mtls_endpoint_and_cert_source(
+ cls, client_options: Optional[client_options_lib.ClientOptions] = None
+ ):
+ """Return the API endpoint and client cert source for mutual TLS.
+
+ The client cert source is determined in the following order:
+ (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
+ client cert source is None.
+ (2) if `client_options.client_cert_source` is provided, use the provided one; if the
+ default client cert source exists, use the default one; otherwise the client cert
+ source is None.
+
+ The API endpoint is determined in the following order:
+ (1) if `client_options.api_endpoint` is provided, use the provided one.
+ (2) if the `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the
+ default mTLS endpoint; if the environment variable is "never", use the default API
+ endpoint; otherwise, if a client cert source exists, use the default mTLS endpoint;
+ otherwise use the default API endpoint.
+
+ More details can be found at https://google.aip.dev/auth/4114.
+
+ Args:
+ client_options (google.api_core.client_options.ClientOptions): Custom options for the
+ client. Only the `api_endpoint` and `client_cert_source` properties may be used
+ in this method.
+
+ Returns:
+ Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
+ client cert source to use.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If any errors happen.
+ """
+ if client_options is None:
+ client_options = client_options_lib.ClientOptions()
+ use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
+ use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
+ if use_client_cert not in ("true", "false"):
+ raise ValueError(
+ "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
+ )
+ if use_mtls_endpoint not in ("auto", "never", "always"):
+ raise MutualTLSChannelError(
+ "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
+ )
+
+ # Figure out the client cert source to use.
+ client_cert_source = None
+ if use_client_cert == "true":
+ if client_options.client_cert_source:
+ client_cert_source = client_options.client_cert_source
+ elif mtls.has_default_client_cert_source():
+ client_cert_source = mtls.default_client_cert_source()
+
+ # Figure out which api endpoint to use.
+ if client_options.api_endpoint is not None:
+ api_endpoint = client_options.api_endpoint
+ elif use_mtls_endpoint == "always" or (
+ use_mtls_endpoint == "auto" and client_cert_source
+ ):
+ api_endpoint = cls.DEFAULT_MTLS_ENDPOINT
+ else:
+ api_endpoint = cls.DEFAULT_ENDPOINT
+
+ return api_endpoint, client_cert_source
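+
+ # Example outcomes of the resolution above (illustrative comment only):
+ # with neither environment variable set and no client certificate, this
+ # returns (DEFAULT_ENDPOINT, None); with GOOGLE_API_USE_CLIENT_CERTIFICATE
+ # set to "true" and a default certificate available, it returns
+ # (DEFAULT_MTLS_ENDPOINT, the default cert source callable).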
+
+ def __init__(
+ self,
+ *,
+ credentials: Optional[ga_credentials.Credentials] = None,
+ transport: Union[str, FeaturestoreOnlineServingServiceTransport, None] = None,
+ client_options: Optional[client_options_lib.ClientOptions] = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ ) -> None:
+ """Instantiates the featurestore online serving service client.
+
+ Args:
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ transport (Union[str, FeaturestoreOnlineServingServiceTransport]): The
+ transport to use. If set to None, a transport is chosen
+ automatically.
+ client_options (google.api_core.client_options.ClientOptions): Custom options for the
+ client. It won't take effect if a ``transport`` instance is provided.
+ (1) The ``api_endpoint`` property can be used to override the
+ default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+ environment variable can also be used to override the endpoint:
+ "always" (always use the default mTLS endpoint), "never" (always
+ use the default regular endpoint) and "auto" (auto switch to the
+ default mTLS endpoint if client certificate is present, this is
+ the default value). However, the ``api_endpoint`` property takes
+ precedence if provided.
+ (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+ is "true", then the ``client_cert_source`` property can be used
+ to provide client certificate for mutual TLS transport. If
+ not provided, the default SSL client certificate will be used if
+ present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+ set, no client certificate will be used.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+ creation failed for any reason.
+ """
+ if isinstance(client_options, dict):
+ client_options = client_options_lib.from_dict(client_options)
+ if client_options is None:
+ client_options = client_options_lib.ClientOptions()
+
+ api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(
+ client_options
+ )
+
+ api_key_value = getattr(client_options, "api_key", None)
+ if api_key_value and credentials:
+ raise ValueError(
+ "client_options.api_key and credentials are mutually exclusive"
+ )
+
+ # Save or instantiate the transport.
+ # Ordinarily, we provide the transport, but allowing a custom transport
+ # instance provides an extensibility point for unusual situations.
+ if isinstance(transport, FeaturestoreOnlineServingServiceTransport):
+ # transport is a FeaturestoreOnlineServingServiceTransport instance.
+ if credentials or client_options.credentials_file or api_key_value:
+ raise ValueError(
+ "When providing a transport instance, "
+ "provide its credentials directly."
+ )
+ if client_options.scopes:
+ raise ValueError(
+ "When providing a transport instance, provide its scopes "
+ "directly."
+ )
+ self._transport = transport
+ else:
+ import google.auth._default # type: ignore
+
+ if api_key_value and hasattr(
+ google.auth._default, "get_api_key_credentials"
+ ):
+ credentials = google.auth._default.get_api_key_credentials(
+ api_key_value
+ )
+
+ Transport = type(self).get_transport_class(transport)
+ self._transport = Transport(
+ credentials=credentials,
+ credentials_file=client_options.credentials_file,
+ host=api_endpoint,
+ scopes=client_options.scopes,
+ client_cert_source_for_mtls=client_cert_source_func,
+ quota_project_id=client_options.quota_project_id,
+ client_info=client_info,
+ always_use_jwt_access=True,
+ )
+
+ def read_feature_values(
+ self,
+ request: Union[
+ featurestore_online_service.ReadFeatureValuesRequest, dict
+ ] = None,
+ *,
+ entity_type: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> featurestore_online_service.ReadFeatureValuesResponse:
+ r"""Reads Feature values of a specific entity of an
+ EntityType. For reading feature values of multiple
+ entities of an EntityType, please use
+ StreamingReadFeatureValues.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ def sample_read_feature_values():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreOnlineServingServiceClient()
+
+ # Initialize request argument(s)
+ feature_selector = aiplatform_v1.FeatureSelector()
+ feature_selector.id_matcher.ids = ['ids_value_1', 'ids_value_2']
+
+ request = aiplatform_v1.ReadFeatureValuesRequest(
+ entity_type="entity_type_value",
+ entity_id="entity_id_value",
+ feature_selector=feature_selector,
+ )
+
+ # Make the request
+ response = client.read_feature_values(request=request)
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.ReadFeatureValuesRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreOnlineServingService.ReadFeatureValues][google.cloud.aiplatform.v1.FeaturestoreOnlineServingService.ReadFeatureValues].
+ entity_type (str):
+ Required. The resource name of the EntityType for the
+ entity being read. Value format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entityType}``.
+ For example, for a machine learning model predicting
+ user clicks on a website, an EntityType ID could be
+ ``user``.
+
+ This corresponds to the ``entity_type`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.aiplatform_v1.types.ReadFeatureValuesResponse:
+ Response message for
+ [FeaturestoreOnlineServingService.ReadFeatureValues][google.cloud.aiplatform.v1.FeaturestoreOnlineServingService.ReadFeatureValues].
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([entity_type])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ # Minor optimization to avoid making a copy if the user passes
+ # in a featurestore_online_service.ReadFeatureValuesRequest.
+ # There's no risk of modifying the input as we've already verified
+ # there are no flattened fields.
+ if not isinstance(
+ request, featurestore_online_service.ReadFeatureValuesRequest
+ ):
+ request = featurestore_online_service.ReadFeatureValuesRequest(request)
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if entity_type is not None:
+ request.entity_type = entity_type
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[self._transport.read_feature_values]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata(
+ (("entity_type", request.entity_type),)
+ ),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def streaming_read_feature_values(
+ self,
+ request: Union[
+ featurestore_online_service.StreamingReadFeatureValuesRequest, dict
+ ] = None,
+ *,
+ entity_type: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> Iterable[featurestore_online_service.ReadFeatureValuesResponse]:
+ r"""Reads Feature values for multiple entities. Depending
+ on their size, data for different entities may be broken
+ up across multiple responses.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ def sample_streaming_read_feature_values():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreOnlineServingServiceClient()
+
+ # Initialize request argument(s)
+ feature_selector = aiplatform_v1.FeatureSelector()
+ feature_selector.id_matcher.ids = ['ids_value_1', 'ids_value_2']
+
+ request = aiplatform_v1.StreamingReadFeatureValuesRequest(
+ entity_type="entity_type_value",
+ entity_ids=['entity_ids_value_1', 'entity_ids_value_2'],
+ feature_selector=feature_selector,
+ )
+
+ # Make the request
+ stream = client.streaming_read_feature_values(request=request)
+
+ # Handle the response
+ for response in stream:
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.StreamingReadFeatureValuesRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreOnlineServingService.StreamingFeatureValuesRead][].
+ entity_type (str):
+ Required. The resource name of the entities' type. Value
+ format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entityType}``.
+ For example, for a machine learning model predicting
+ user clicks on a website, an EntityType ID could be
+ ``user``.
+
+ This corresponds to the ``entity_type`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ Iterable[google.cloud.aiplatform_v1.types.ReadFeatureValuesResponse]:
+ Response message for
+ [FeaturestoreOnlineServingService.ReadFeatureValues][google.cloud.aiplatform.v1.FeaturestoreOnlineServingService.ReadFeatureValues].
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([entity_type])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ # Minor optimization to avoid making a copy if the user passes
+ # in a featurestore_online_service.StreamingReadFeatureValuesRequest.
+ # There's no risk of modifying the input as we've already verified
+ # there are no flattened fields.
+ if not isinstance(
+ request, featurestore_online_service.StreamingReadFeatureValuesRequest
+ ):
+ request = featurestore_online_service.StreamingReadFeatureValuesRequest(
+ request
+ )
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if entity_type is not None:
+ request.entity_type = entity_type
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[
+ self._transport.streaming_read_feature_values
+ ]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata(
+ (("entity_type", request.entity_type),)
+ ),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ """Releases underlying transport's resources.
+
+ .. warning::
+ ONLY use as a context manager if the transport is NOT shared
+ with other clients! Exiting the with block will CLOSE the transport
+ and may cause errors in other clients!
+ """
+ self.transport.close()
+
+ def list_operations(
+ self,
+ request: operations_pb2.ListOperationsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.ListOperationsResponse:
+ r"""Lists operations that match the specified filter in the request.
+
+ Args:
+ request (:class:`~.operations_pb2.ListOperationsRequest`):
+ The request object. Request message for
+ `ListOperations` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.ListOperationsResponse:
+ Response message for ``ListOperations`` method.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.ListOperationsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.list_operations,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def get_operation(
+ self,
+ request: operations_pb2.GetOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.Operation:
+ r"""Gets the latest state of a long-running operation.
+
+ Args:
+ request (:class:`~.operations_pb2.GetOperationRequest`):
+ The request object. Request message for
+ `GetOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.Operation:
+ An ``Operation`` object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.GetOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.get_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def delete_operation(
+ self,
+ request: operations_pb2.DeleteOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> None:
+ r"""Deletes a long-running operation.
+
+ This method indicates that the client is no longer interested
+ in the operation result. It does not cancel the operation.
+ If the server doesn't support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
+ Args:
+ request (:class:`~.operations_pb2.DeleteOperationRequest`):
+ The request object. Request message for
+ `DeleteOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ None
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.DeleteOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.delete_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ def cancel_operation(
+ self,
+ request: operations_pb2.CancelOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> None:
+ r"""Starts asynchronous cancellation on a long-running operation.
+
+ The server makes a best effort to cancel the operation, but success
+ is not guaranteed. If the server doesn't support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
+ Args:
+ request (:class:`~.operations_pb2.CancelOperationRequest`):
+ The request object. Request message for
+ `CancelOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ None
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.CancelOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.cancel_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ def wait_operation(
+ self,
+ request: operations_pb2.WaitOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.Operation:
+ r"""Waits until the specified long-running operation is done or reaches at most
+ a specified timeout, returning the latest state.
+
+ If the operation is already done, the latest state is immediately returned.
+ If the timeout specified is greater than the default HTTP/RPC timeout, the HTTP/RPC
+ timeout is used. If the server does not support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
+ Args:
+ request (:class:`~.operations_pb2.WaitOperationRequest`):
+ The request object. Request message for
+ `WaitOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.Operation:
+ An ``Operation`` object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.WaitOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.wait_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def set_iam_policy(
+ self,
+ request: iam_policy_pb2.SetIamPolicyRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> policy_pb2.Policy:
+ r"""Sets the IAM access control policy on the specified function.
+
+ Replaces any existing policy.
+
+ Args:
+ request (:class:`~.iam_policy_pb2.SetIamPolicyRequest`):
+ The request object. Request message for `SetIamPolicy`
+ method.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.policy_pb2.Policy:
+ Defines an Identity and Access Management (IAM) policy.
+ It is used to specify access control policies for Cloud
+ Platform resources.
+ A ``Policy`` is a collection of ``bindings``. A
+ ``binding`` binds one or more ``members`` to a single
+ ``role``. Members can be user accounts, service
+ accounts, Google groups, and domains (such as G Suite).
+ A ``role`` is a named list of permissions (defined by
+ IAM or configured by users). A ``binding`` can
+ optionally specify a ``condition``, which is a logic
+ expression that further constrains the role binding
+ based on attributes about the request and/or target
+ resource.
+ **JSON Example**
+ ::
+ {
+ "bindings": [
+ {
+ "role": "roles/resourcemanager.organizationAdmin",
+ "members": [
+ "user:mike@example.com",
+ "group:admins@example.com",
+ "domain:google.com",
+ "serviceAccount:my-project-id@appspot.gserviceaccount.com"
+ ]
+ },
+ {
+ "role": "roles/resourcemanager.organizationViewer",
+ "members": ["user:eve@example.com"],
+ "condition": {
+ "title": "expirable access",
+ "description": "Does not grant access after Sep 2020",
+ "expression": "request.time <
+ timestamp('2020-10-01T00:00:00.000Z')",
+ }
+ }
+ ]
+ }
+ **YAML Example**
+ ::
+ bindings:
+ - members:
+ - user:mike@example.com
+ - group:admins@example.com
+ - domain:google.com
+ - serviceAccount:my-project-id@appspot.gserviceaccount.com
+ role: roles/resourcemanager.organizationAdmin
+ - members:
+ - user:eve@example.com
+ role: roles/resourcemanager.organizationViewer
+ condition:
+ title: expirable access
+ description: Does not grant access after Sep 2020
+ expression: request.time < timestamp('2020-10-01T00:00:00.000Z')
+ For a description of IAM and its features, see the `IAM
+ developer's
+ guide <https://cloud.google.com/iam/docs>`__.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.SetIamPolicyRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.set_iam_policy,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def get_iam_policy(
+ self,
+ request: iam_policy_pb2.GetIamPolicyRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> policy_pb2.Policy:
+ r"""Gets the IAM access control policy for a function.
+
+ Returns an empty policy if the function exists and does not have a
+ policy set.
+
+ Args:
+ request (:class:`~.iam_policy_pb2.GetIamPolicyRequest`):
+ The request object. Request message for `GetIamPolicy`
+ method.
+ retry (google.api_core.retry.Retry): Designation of what errors, if
+ any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.policy_pb2.Policy:
+ Defines an Identity and Access Management (IAM) policy.
+ It is used to specify access control policies for Cloud
+ Platform resources.
+ A ``Policy`` is a collection of ``bindings``. A
+ ``binding`` binds one or more ``members`` to a single
+ ``role``. Members can be user accounts, service
+ accounts, Google groups, and domains (such as G Suite).
+ A ``role`` is a named list of permissions (defined by
+ IAM or configured by users). A ``binding`` can
+ optionally specify a ``condition``, which is a logic
+ expression that further constrains the role binding
+ based on attributes about the request and/or target
+ resource.
+ **JSON Example**
+ ::
+ {
+ "bindings": [
+ {
+ "role": "roles/resourcemanager.organizationAdmin",
+ "members": [
+ "user:mike@example.com",
+ "group:admins@example.com",
+ "domain:google.com",
+ "serviceAccount:my-project-id@appspot.gserviceaccount.com"
+ ]
+ },
+ {
+ "role": "roles/resourcemanager.organizationViewer",
+ "members": ["user:eve@example.com"],
+ "condition": {
+ "title": "expirable access",
+ "description": "Does not grant access after Sep 2020",
+ "expression": "request.time <
+ timestamp('2020-10-01T00:00:00.000Z')",
+ }
+ }
+ ]
+ }
+ **YAML Example**
+ ::
+ bindings:
+ - members:
+ - user:mike@example.com
+ - group:admins@example.com
+ - domain:google.com
+ - serviceAccount:my-project-id@appspot.gserviceaccount.com
+ role: roles/resourcemanager.organizationAdmin
+ - members:
+ - user:eve@example.com
+ role: roles/resourcemanager.organizationViewer
+ condition:
+ title: expirable access
+ description: Does not grant access after Sep 2020
+ expression: request.time < timestamp('2020-10-01T00:00:00.000Z')
+ For a description of IAM and its features, see the `IAM
+ developer's
+ guide <https://cloud.google.com/iam/docs>`__.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.GetIamPolicyRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.get_iam_policy,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def test_iam_permissions(
+ self,
+ request: iam_policy_pb2.TestIamPermissionsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> iam_policy_pb2.TestIamPermissionsResponse:
+ r"""Tests the specified IAM permissions against the IAM access control
+ policy for a function.
+
+ If the function does not exist, this will return an empty set
+ of permissions, not a NOT_FOUND error.
+
+ Args:
+ request (:class:`~.iam_policy_pb2.TestIamPermissionsRequest`):
+ The request object. Request message for
+ `TestIamPermissions` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.iam_policy_pb2.TestIamPermissionsResponse:
+ Response message for ``TestIamPermissions`` method.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.TestIamPermissionsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.test_iam_permissions,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def get_location(
+ self,
+ request: locations_pb2.GetLocationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> locations_pb2.Location:
+ r"""Gets information about a location.
+
+ Args:
+ request (:class:`~.location_pb2.GetLocationRequest`):
+ The request object. Request message for
+ `GetLocation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.location_pb2.Location:
+ Location object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = locations_pb2.GetLocationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.get_location,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def list_locations(
+ self,
+ request: locations_pb2.ListLocationsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> locations_pb2.ListLocationsResponse:
+ r"""Lists information about the supported locations for this service.
+
+ Args:
+ request (:class:`~.location_pb2.ListLocationsRequest`):
+ The request object. Request message for
+ `ListLocations` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.location_pb2.ListLocationsResponse:
+ Response message for ``ListLocations`` method.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = locations_pb2.ListLocationsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._transport.list_locations,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+
+try:
+ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+ gapic_version=pkg_resources.get_distribution(
+ "google-cloud-aiplatform",
+ ).version,
+ )
+except pkg_resources.DistributionNotFound:
+ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+
+__all__ = ("FeaturestoreOnlineServingServiceClient",)
diff --git a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/__init__.py
new file mode 100644
index 0000000000..c929d350e6
--- /dev/null
+++ b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/__init__.py
@@ -0,0 +1,37 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from collections import OrderedDict
+from typing import Dict, Type
+
+from .base import FeaturestoreOnlineServingServiceTransport
+from .grpc import FeaturestoreOnlineServingServiceGrpcTransport
+from .grpc_asyncio import FeaturestoreOnlineServingServiceGrpcAsyncIOTransport
+
+
+# Compile a registry of transports.
+_transport_registry = (
+ OrderedDict()
+) # type: Dict[str, Type[FeaturestoreOnlineServingServiceTransport]]
+_transport_registry["grpc"] = FeaturestoreOnlineServingServiceGrpcTransport
+_transport_registry[
+ "grpc_asyncio"
+] = FeaturestoreOnlineServingServiceGrpcAsyncIOTransport
+
+__all__ = (
+ "FeaturestoreOnlineServingServiceTransport",
+ "FeaturestoreOnlineServingServiceGrpcTransport",
+ "FeaturestoreOnlineServingServiceGrpcAsyncIOTransport",
+)
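
A sketch of how this registry is consulted via the client metaclass defined earlier (illustrative only; the names are those defined above):

    from google.cloud import aiplatform_v1

    # Explicit label lookup.
    GrpcTransport = aiplatform_v1.FeaturestoreOnlineServingServiceClient.get_transport_class(
        "grpc"
    )

    # No label: the first registry entry ("grpc") is the default.
    DefaultTransport = (
        aiplatform_v1.FeaturestoreOnlineServingServiceClient.get_transport_class()
    )
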
diff --git a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/base.py b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/base.py
new file mode 100644
index 0000000000..e7ff1c284e
--- /dev/null
+++ b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/base.py
@@ -0,0 +1,271 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import abc
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Union
+import pkg_resources
+
+import google.auth # type: ignore
+import google.api_core
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core import retry as retries
+from google.auth import credentials as ga_credentials # type: ignore
+from google.oauth2 import service_account # type: ignore
+
+from google.cloud.aiplatform_v1.types import featurestore_online_service
+from google.cloud.location import locations_pb2 # type: ignore
+from google.iam.v1 import iam_policy_pb2 # type: ignore
+from google.iam.v1 import policy_pb2 # type: ignore
+from google.longrunning import operations_pb2
+
+try:
+ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+ gapic_version=pkg_resources.get_distribution(
+ "google-cloud-aiplatform",
+ ).version,
+ )
+except pkg_resources.DistributionNotFound:
+ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+
+class FeaturestoreOnlineServingServiceTransport(abc.ABC):
+ """Abstract transport class for FeaturestoreOnlineServingService."""
+
+ AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",)
+
+ DEFAULT_HOST: str = "aiplatform.googleapis.com"
+
+ def __init__(
+ self,
+ *,
+ host: str = DEFAULT_HOST,
+ credentials: ga_credentials.Credentials = None,
+ credentials_file: Optional[str] = None,
+ scopes: Optional[Sequence[str]] = None,
+ quota_project_id: Optional[str] = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ always_use_jwt_access: Optional[bool] = False,
+ **kwargs,
+ ) -> None:
+ """Instantiate the transport.
+
+ Args:
+ host (Optional[str]):
+ The hostname to connect to.
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is mutually exclusive with credentials.
+ scopes (Optional[Sequence[str]]): A list of scopes.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+ be used for service account credentials.
+ """
+
+ # Save the hostname. Default to port 443 (HTTPS) if none is specified.
+ if ":" not in host:
+ host += ":443"
+ self._host = host
+
+ scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES}
+
+ # Save the scopes.
+ self._scopes = scopes
+
+ # If no credentials are provided, then determine the appropriate
+ # defaults.
+ if credentials and credentials_file:
+ raise core_exceptions.DuplicateCredentialArgs(
+ "'credentials_file' and 'credentials' are mutually exclusive"
+ )
+
+ if credentials_file is not None:
+ credentials, _ = google.auth.load_credentials_from_file(
+ credentials_file, **scopes_kwargs, quota_project_id=quota_project_id
+ )
+ elif credentials is None:
+ credentials, _ = google.auth.default(
+ **scopes_kwargs, quota_project_id=quota_project_id
+ )
+
+ # If the credentials are service account credentials, then always try to use self signed JWT.
+ if (
+ always_use_jwt_access
+ and isinstance(credentials, service_account.Credentials)
+ and hasattr(service_account.Credentials, "with_always_use_jwt_access")
+ ):
+ credentials = credentials.with_always_use_jwt_access(True)
+
+ # Save the credentials.
+ self._credentials = credentials
+
+ def _prep_wrapped_messages(self, client_info):
+ # Precompute the wrapped methods.
+ self._wrapped_methods = {
+ self.read_feature_values: gapic_v1.method.wrap_method(
+ self.read_feature_values,
+ default_timeout=None,
+ client_info=client_info,
+ ),
+ self.streaming_read_feature_values: gapic_v1.method.wrap_method(
+ self.streaming_read_feature_values,
+ default_timeout=None,
+ client_info=client_info,
+ ),
+ }
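+
+ # The client consults this table keyed by the bound transport method, e.g.
+ # rpc = transport._wrapped_methods[transport.read_feature_values]
+ # as in FeaturestoreOnlineServingServiceClient.read_feature_values above
+ # (illustrative comment only).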
+
+ def close(self):
+ """Closes resources associated with the transport.
+
+ .. warning::
+ Only call this method if the transport is NOT shared
+ with other clients - this may cause errors in other clients!
+ """
+ raise NotImplementedError()
+
+ @property
+ def read_feature_values(
+ self,
+ ) -> Callable[
+ [featurestore_online_service.ReadFeatureValuesRequest],
+ Union[
+ featurestore_online_service.ReadFeatureValuesResponse,
+ Awaitable[featurestore_online_service.ReadFeatureValuesResponse],
+ ],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def streaming_read_feature_values(
+ self,
+ ) -> Callable[
+ [featurestore_online_service.StreamingReadFeatureValuesRequest],
+ Union[
+ featurestore_online_service.ReadFeatureValuesResponse,
+ Awaitable[featurestore_online_service.ReadFeatureValuesResponse],
+ ],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def list_operations(
+ self,
+ ) -> Callable[
+ [operations_pb2.ListOperationsRequest],
+ Union[
+ operations_pb2.ListOperationsResponse,
+ Awaitable[operations_pb2.ListOperationsResponse],
+ ],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def get_operation(
+ self,
+ ) -> Callable[
+ [operations_pb2.GetOperationRequest],
+ Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def cancel_operation(
+ self,
+ ) -> Callable[[operations_pb2.CancelOperationRequest], None,]:
+ raise NotImplementedError()
+
+ @property
+ def delete_operation(
+ self,
+ ) -> Callable[[operations_pb2.DeleteOperationRequest], None,]:
+ raise NotImplementedError()
+
+ @property
+ def wait_operation(
+ self,
+ ) -> Callable[
+ [operations_pb2.WaitOperationRequest],
+ Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def set_iam_policy(
+ self,
+ ) -> Callable[
+ [iam_policy_pb2.SetIamPolicyRequest],
+ Union[policy_pb2.Policy, Awaitable[policy_pb2.Policy]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def get_iam_policy(
+ self,
+ ) -> Callable[
+ [iam_policy_pb2.GetIamPolicyRequest],
+ Union[policy_pb2.Policy, Awaitable[policy_pb2.Policy]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def test_iam_permissions(
+ self,
+ ) -> Callable[
+ [iam_policy_pb2.TestIamPermissionsRequest],
+ Union[
+ iam_policy_pb2.TestIamPermissionsResponse,
+ Awaitable[iam_policy_pb2.TestIamPermissionsResponse],
+ ],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def get_location(
+ self,
+ ) -> Callable[
+ [locations_pb2.GetLocationRequest],
+ Union[locations_pb2.Location, Awaitable[locations_pb2.Location]],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def list_locations(
+ self,
+ ) -> Callable[
+ [locations_pb2.ListLocationsRequest],
+ Union[
+ locations_pb2.ListLocationsResponse,
+ Awaitable[locations_pb2.ListLocationsResponse],
+ ],
+ ]:
+ raise NotImplementedError()
+
+ @property
+ def kind(self) -> str:
+ raise NotImplementedError()
+
+
+__all__ = ("FeaturestoreOnlineServingServiceTransport",)
diff --git a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc.py
new file mode 100644
index 0000000000..150569313b
--- /dev/null
+++ b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc.py
@@ -0,0 +1,512 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import warnings
+from typing import Callable, Dict, Optional, Sequence, Tuple, Union
+
+from google.api_core import grpc_helpers
+from google.api_core import gapic_v1
+import google.auth # type: ignore
+from google.auth import credentials as ga_credentials # type: ignore
+from google.auth.transport.grpc import SslCredentials # type: ignore
+
+import grpc # type: ignore
+
+from google.cloud.aiplatform_v1.types import featurestore_online_service
+from google.cloud.location import locations_pb2 # type: ignore
+from google.iam.v1 import iam_policy_pb2 # type: ignore
+from google.iam.v1 import policy_pb2 # type: ignore
+from google.longrunning import operations_pb2
+from .base import FeaturestoreOnlineServingServiceTransport, DEFAULT_CLIENT_INFO
+
+
+class FeaturestoreOnlineServingServiceGrpcTransport(
+ FeaturestoreOnlineServingServiceTransport
+):
+ """gRPC backend transport for FeaturestoreOnlineServingService.
+
+ A service for serving online feature values.
+
+ This class defines the same methods as the primary client, so the
+ primary client can load the underlying transport implementation
+ and call it.
+
+ It sends protocol buffers over the wire using gRPC (which is built on
+ top of HTTP/2); the ``grpcio`` package must be installed.
+ """
+
+ _stubs: Dict[str, Callable]
+
+ def __init__(
+ self,
+ *,
+ host: str = "aiplatform.googleapis.com",
+ credentials: ga_credentials.Credentials = None,
+ credentials_file: str = None,
+ scopes: Sequence[str] = None,
+ channel: grpc.Channel = None,
+ api_mtls_endpoint: str = None,
+ client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+ ssl_channel_credentials: grpc.ChannelCredentials = None,
+ client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
+ quota_project_id: Optional[str] = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ always_use_jwt_access: Optional[bool] = False,
+ ) -> None:
+ """Instantiate the transport.
+
+ Args:
+ host (Optional[str]):
+ The hostname to connect to.
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ This argument is ignored if ``channel`` is provided.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+ ignored if ``channel`` is provided.
+ channel (Optional[grpc.Channel]): A ``Channel`` instance through
+ which to make calls.
+ api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+ If provided, it overrides the ``host`` argument and tries to create
+ a mutual TLS channel with client SSL credentials from
+ ``client_cert_source`` or application default SSL credentials.
+ client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ Deprecated. A callback to provide client SSL certificate bytes and
+ private key bytes, both in PEM format. It is ignored if
+ ``api_mtls_endpoint`` is None.
+ ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+ for the grpc channel. It is ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ A callback to provide client certificate bytes and private key bytes,
+ both in PEM format. It is used to configure a mutual TLS channel. It is
+ ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
+            always_use_jwt_access (Optional[bool]): Whether a self-signed JWT should
+ be used for service account credentials.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+ creation failed for any reason.
+ google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+ and ``credentials_file`` are passed.
+ """
+ self._grpc_channel = None
+ self._ssl_channel_credentials = ssl_channel_credentials
+ self._stubs: Dict[str, Callable] = {}
+
+ if api_mtls_endpoint:
+ warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+ if client_cert_source:
+ warnings.warn("client_cert_source is deprecated", DeprecationWarning)
+
+ if channel:
+ # Ignore credentials if a channel was passed.
+ credentials = False
+ # If a channel was explicitly provided, set it.
+ self._grpc_channel = channel
+ self._ssl_channel_credentials = None
+
+ else:
+ if api_mtls_endpoint:
+ host = api_mtls_endpoint
+
+ # Create SSL credentials with client_cert_source or application
+ # default SSL credentials.
+ if client_cert_source:
+ cert, key = client_cert_source()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ else:
+ self._ssl_channel_credentials = SslCredentials().ssl_credentials
+
+ else:
+ if client_cert_source_for_mtls and not ssl_channel_credentials:
+ cert, key = client_cert_source_for_mtls()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+
+ # The base transport sets the host, credentials and scopes
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ scopes=scopes,
+ quota_project_id=quota_project_id,
+ client_info=client_info,
+ always_use_jwt_access=always_use_jwt_access,
+ )
+
+ if not self._grpc_channel:
+ self._grpc_channel = type(self).create_channel(
+ self._host,
+ # use the credentials which are saved
+ credentials=self._credentials,
+ # Set ``credentials_file`` to ``None`` here as
+ # the credentials that we saved earlier should be used.
+ credentials_file=None,
+ scopes=self._scopes,
+ ssl_credentials=self._ssl_channel_credentials,
+ quota_project_id=quota_project_id,
+ options=[
+ ("grpc.max_send_message_length", -1),
+ ("grpc.max_receive_message_length", -1),
+ ],
+ )
+
+ # Wrap messages. This must be done after self._grpc_channel exists
+ self._prep_wrapped_messages(client_info)
+
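Reviewer note: the constructor above resolves channel security in a strict precedence order (explicit ``channel``, then ``ssl_channel_credentials``, then ``client_cert_source_for_mtls``, then the deprecated ``api_mtls_endpoint`` path). A hedged sketch of the mTLS callback branch; ``load_pem_pair`` and the file names are illustrative only:

.. code-block:: python

    from google.cloud.aiplatform_v1.services.featurestore_online_serving_service.transports.grpc import (
        FeaturestoreOnlineServingServiceGrpcTransport,
    )

    def load_pem_pair():
        # Hypothetical helper returning (certificate_bytes, private_key_bytes) in PEM.
        with open("client_cert.pem", "rb") as cert, open("client_key.pem", "rb") as key:
            return cert.read(), key.read()

    # The callback is consulted only because neither ``channel`` nor
    # ``ssl_channel_credentials`` is supplied.
    transport = FeaturestoreOnlineServingServiceGrpcTransport(
        client_cert_source_for_mtls=load_pem_pair,
    )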
+ @classmethod
+ def create_channel(
+ cls,
+ host: str = "aiplatform.googleapis.com",
+ credentials: ga_credentials.Credentials = None,
+ credentials_file: str = None,
+ scopes: Optional[Sequence[str]] = None,
+ quota_project_id: Optional[str] = None,
+ **kwargs,
+ ) -> grpc.Channel:
+ """Create and return a gRPC channel object.
+ Args:
+ host (Optional[str]): The host for the channel to use.
+ credentials (Optional[~.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify this application to the service. If
+ none are specified, the client will attempt to ascertain
+ the credentials from the environment.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+ service. These are only used when credentials are not specified and
+ are passed to :func:`google.auth.default`.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ kwargs (Optional[dict]): Keyword arguments, which are passed to the
+ channel creation.
+ Returns:
+ grpc.Channel: A gRPC channel object.
+
+ Raises:
+ google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+ and ``credentials_file`` are passed.
+ """
+
+ return grpc_helpers.create_channel(
+ host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ quota_project_id=quota_project_id,
+ default_scopes=cls.AUTH_SCOPES,
+ scopes=scopes,
+ default_host=cls.DEFAULT_HOST,
+ **kwargs,
+ )
+
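Reviewer note: a channel built by ``create_channel`` can also be injected back into the constructor; per the ``__init__`` above, an explicit ``channel`` causes ``credentials`` and ``ssl_channel_credentials`` to be ignored. A sketch (the quota project id is illustrative):

.. code-block:: python

    from google.cloud.aiplatform_v1.services.featurestore_online_serving_service.transports.grpc import (
        FeaturestoreOnlineServingServiceGrpcTransport,
    )

    channel = FeaturestoreOnlineServingServiceGrpcTransport.create_channel(
        "aiplatform.googleapis.com",
        quota_project_id="my-billing-project",  # illustrative
    )
    transport = FeaturestoreOnlineServingServiceGrpcTransport(channel=channel)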
+ @property
+ def grpc_channel(self) -> grpc.Channel:
+ """Return the channel designed to connect to this service."""
+ return self._grpc_channel
+
+ @property
+ def read_feature_values(
+ self,
+ ) -> Callable[
+ [featurestore_online_service.ReadFeatureValuesRequest],
+ featurestore_online_service.ReadFeatureValuesResponse,
+ ]:
+ r"""Return a callable for the read feature values method over gRPC.
+
+ Reads Feature values of a specific entity of an
+ EntityType. For reading feature values of multiple
+ entities of an EntityType, please use
+ StreamingReadFeatureValues.
+
+ Returns:
+ Callable[[~.ReadFeatureValuesRequest],
+ ~.ReadFeatureValuesResponse]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "read_feature_values" not in self._stubs:
+ self._stubs["read_feature_values"] = self.grpc_channel.unary_unary(
+ "/google.cloud.aiplatform.v1.FeaturestoreOnlineServingService/ReadFeatureValues",
+ request_serializer=featurestore_online_service.ReadFeatureValuesRequest.serialize,
+ response_deserializer=featurestore_online_service.ReadFeatureValuesResponse.deserialize,
+ )
+ return self._stubs["read_feature_values"]
+
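Reviewer note: the property above creates the stub lazily and caches it in ``self._stubs``, so repeated access is cheap. A tiny sketch of that behavior (assumes a transport built with default credentials):

.. code-block:: python

    from google.cloud.aiplatform_v1.services.featurestore_online_serving_service.transports.grpc import (
        FeaturestoreOnlineServingServiceGrpcTransport,
    )

    transport = FeaturestoreOnlineServingServiceGrpcTransport()
    rpc = transport.read_feature_values
    # Subsequent accesses return the identical cached callable.
    assert rpc is transport.read_feature_values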
+ @property
+ def streaming_read_feature_values(
+ self,
+ ) -> Callable[
+ [featurestore_online_service.StreamingReadFeatureValuesRequest],
+ featurestore_online_service.ReadFeatureValuesResponse,
+ ]:
+ r"""Return a callable for the streaming read feature values method over gRPC.
+
+ Reads Feature values for multiple entities. Depending
+ on their size, data for different entities may be broken
+ up across multiple responses.
+
+ Returns:
+ Callable[[~.StreamingReadFeatureValuesRequest],
+ ~.ReadFeatureValuesResponse]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "streaming_read_feature_values" not in self._stubs:
+ self._stubs[
+ "streaming_read_feature_values"
+ ] = self.grpc_channel.unary_stream(
+ "/google.cloud.aiplatform.v1.FeaturestoreOnlineServingService/StreamingReadFeatureValues",
+ request_serializer=featurestore_online_service.StreamingReadFeatureValuesRequest.serialize,
+ response_deserializer=featurestore_online_service.ReadFeatureValuesResponse.deserialize,
+ )
+ return self._stubs["streaming_read_feature_values"]
+
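Reviewer note: unlike the unary stub above, this one is registered with ``unary_stream``, so the callable returns an iterator of responses. A hedged sketch (resource names are placeholders and request fields are abbreviated):

.. code-block:: python

    from google.cloud.aiplatform_v1.services.featurestore_online_serving_service.transports.grpc import (
        FeaturestoreOnlineServingServiceGrpcTransport,
    )
    from google.cloud.aiplatform_v1.types import featurestore_online_service

    transport = FeaturestoreOnlineServingServiceGrpcTransport()
    request = featurestore_online_service.StreamingReadFeatureValuesRequest(
        entity_type="projects/my-project/locations/us-central1/featurestores/my-fs/entityTypes/my-type",
        entity_ids=["entity-1", "entity-2"],
    )
    # Each entity's values may arrive in a separate response message.
    for response in transport.streaming_read_feature_values(request):
        print(response)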
+ def close(self):
+ self.grpc_channel.close()
+
+ @property
+ def delete_operation(
+ self,
+ ) -> Callable[[operations_pb2.DeleteOperationRequest], None]:
+ r"""Return a callable for the delete_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "delete_operation" not in self._stubs:
+ self._stubs["delete_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/DeleteOperation",
+ request_serializer=operations_pb2.DeleteOperationRequest.SerializeToString,
+ response_deserializer=None,
+ )
+ return self._stubs["delete_operation"]
+
+ @property
+ def cancel_operation(
+ self,
+ ) -> Callable[[operations_pb2.CancelOperationRequest], None]:
+ r"""Return a callable for the cancel_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "cancel_operation" not in self._stubs:
+ self._stubs["cancel_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/CancelOperation",
+ request_serializer=operations_pb2.CancelOperationRequest.SerializeToString,
+ response_deserializer=None,
+ )
+ return self._stubs["cancel_operation"]
+
+ @property
+ def wait_operation(
+ self,
+ ) -> Callable[[operations_pb2.WaitOperationRequest], None]:
+ r"""Return a callable for the wait_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "delete_operation" not in self._stubs:
+ self._stubs["wait_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/WaitOperation",
+ request_serializer=operations_pb2.WaitOperationRequest.SerializeToString,
+ response_deserializer=None,
+ )
+ return self._stubs["wait_operation"]
+
+ @property
+ def get_operation(
+ self,
+ ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]:
+ r"""Return a callable for the get_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_operation" not in self._stubs:
+ self._stubs["get_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/GetOperation",
+ request_serializer=operations_pb2.GetOperationRequest.SerializeToString,
+ response_deserializer=operations_pb2.Operation.FromString,
+ )
+ return self._stubs["get_operation"]
+
+ @property
+ def list_operations(
+ self,
+ ) -> Callable[
+ [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse
+ ]:
+ r"""Return a callable for the list_operations method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "list_operations" not in self._stubs:
+ self._stubs["list_operations"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/ListOperations",
+ request_serializer=operations_pb2.ListOperationsRequest.SerializeToString,
+ response_deserializer=operations_pb2.ListOperationsResponse.FromString,
+ )
+ return self._stubs["list_operations"]
+
+ @property
+ def list_locations(
+ self,
+ ) -> Callable[
+ [locations_pb2.ListLocationsRequest], locations_pb2.ListLocationsResponse
+ ]:
+ r"""Return a callable for the list locations method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "list_locations" not in self._stubs:
+ self._stubs["list_locations"] = self.grpc_channel.unary_unary(
+ "/google.cloud.location.Locations/ListLocations",
+ request_serializer=locations_pb2.ListLocationsRequest.SerializeToString,
+ response_deserializer=locations_pb2.ListLocationsResponse.FromString,
+ )
+ return self._stubs["list_locations"]
+
+ @property
+ def get_location(
+ self,
+ ) -> Callable[[locations_pb2.GetLocationRequest], locations_pb2.Location]:
+ r"""Return a callable for the list locations method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_location" not in self._stubs:
+ self._stubs["get_location"] = self.grpc_channel.unary_unary(
+ "/google.cloud.location.Locations/GetLocation",
+ request_serializer=locations_pb2.GetLocationRequest.SerializeToString,
+ response_deserializer=locations_pb2.Location.FromString,
+ )
+ return self._stubs["get_location"]
+
+ @property
+ def set_iam_policy(
+ self,
+ ) -> Callable[[iam_policy_pb2.SetIamPolicyRequest], policy_pb2.Policy]:
+ r"""Return a callable for the set iam policy method over gRPC.
+ Sets the IAM access control policy on the specified
+ function. Replaces any existing policy.
+ Returns:
+ Callable[[~.SetIamPolicyRequest],
+ ~.Policy]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "set_iam_policy" not in self._stubs:
+ self._stubs["set_iam_policy"] = self.grpc_channel.unary_unary(
+ "/google.iam.v1.IAMPolicy/SetIamPolicy",
+ request_serializer=iam_policy_pb2.SetIamPolicyRequest.SerializeToString,
+ response_deserializer=policy_pb2.Policy.FromString,
+ )
+ return self._stubs["set_iam_policy"]
+
+ @property
+ def get_iam_policy(
+ self,
+ ) -> Callable[[iam_policy_pb2.GetIamPolicyRequest], policy_pb2.Policy]:
+ r"""Return a callable for the get iam policy method over gRPC.
+ Gets the IAM access control policy for a function.
+ Returns an empty policy if the function exists and does
+ not have a policy set.
+ Returns:
+ Callable[[~.GetIamPolicyRequest],
+ ~.Policy]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_iam_policy" not in self._stubs:
+ self._stubs["get_iam_policy"] = self.grpc_channel.unary_unary(
+ "/google.iam.v1.IAMPolicy/GetIamPolicy",
+ request_serializer=iam_policy_pb2.GetIamPolicyRequest.SerializeToString,
+ response_deserializer=policy_pb2.Policy.FromString,
+ )
+ return self._stubs["get_iam_policy"]
+
+ @property
+ def test_iam_permissions(
+ self,
+ ) -> Callable[
+ [iam_policy_pb2.TestIamPermissionsRequest],
+ iam_policy_pb2.TestIamPermissionsResponse,
+ ]:
+ r"""Return a callable for the test iam permissions method over gRPC.
+ Tests the specified permissions against the IAM access control
+ policy for a function. If the function does not exist, this will
+ return an empty set of permissions, not a NOT_FOUND error.
+ Returns:
+ Callable[[~.TestIamPermissionsRequest],
+ ~.TestIamPermissionsResponse]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "test_iam_permissions" not in self._stubs:
+ self._stubs["test_iam_permissions"] = self.grpc_channel.unary_unary(
+ "/google.iam.v1.IAMPolicy/TestIamPermissions",
+ request_serializer=iam_policy_pb2.TestIamPermissionsRequest.SerializeToString,
+ response_deserializer=iam_policy_pb2.TestIamPermissionsResponse.FromString,
+ )
+ return self._stubs["test_iam_permissions"]
+
+ @property
+ def kind(self) -> str:
+ return "grpc"
+
+
+__all__ = ("FeaturestoreOnlineServingServiceGrpcTransport",)
diff --git a/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc_asyncio.py
new file mode 100644
index 0000000000..9bfcb672aa
--- /dev/null
+++ b/google/cloud/aiplatform_v1/services/featurestore_online_serving_service/transports/grpc_asyncio.py
@@ -0,0 +1,511 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import warnings
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union
+
+from google.api_core import gapic_v1
+from google.api_core import grpc_helpers_async
+from google.auth import credentials as ga_credentials # type: ignore
+from google.auth.transport.grpc import SslCredentials # type: ignore
+
+import grpc # type: ignore
+from grpc.experimental import aio # type: ignore
+
+from google.cloud.aiplatform_v1.types import featurestore_online_service
+from google.cloud.location import locations_pb2 # type: ignore
+from google.iam.v1 import iam_policy_pb2 # type: ignore
+from google.iam.v1 import policy_pb2 # type: ignore
+from google.longrunning import operations_pb2
+from .base import FeaturestoreOnlineServingServiceTransport, DEFAULT_CLIENT_INFO
+from .grpc import FeaturestoreOnlineServingServiceGrpcTransport
+
+
+class FeaturestoreOnlineServingServiceGrpcAsyncIOTransport(
+ FeaturestoreOnlineServingServiceTransport
+):
+ """gRPC AsyncIO backend transport for FeaturestoreOnlineServingService.
+
+ A service for serving online feature values.
+
+ This class defines the same methods as the primary client, so the
+ primary client can load the underlying transport implementation
+ and call it.
+
+ It sends protocol buffers over the wire using gRPC (which is built on
+ top of HTTP/2); the ``grpcio`` package must be installed.
+ """
+
+ _grpc_channel: aio.Channel
+ _stubs: Dict[str, Callable] = {}
+
+ @classmethod
+ def create_channel(
+ cls,
+ host: str = "aiplatform.googleapis.com",
+ credentials: ga_credentials.Credentials = None,
+ credentials_file: Optional[str] = None,
+ scopes: Optional[Sequence[str]] = None,
+ quota_project_id: Optional[str] = None,
+ **kwargs,
+ ) -> aio.Channel:
+ """Create and return a gRPC AsyncIO channel object.
+ Args:
+ host (Optional[str]): The host for the channel to use.
+ credentials (Optional[~.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify this application to the service. If
+ none are specified, the client will attempt to ascertain
+ the credentials from the environment.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+ service. These are only used when credentials are not specified and
+ are passed to :func:`google.auth.default`.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ kwargs (Optional[dict]): Keyword arguments, which are passed to the
+ channel creation.
+ Returns:
+ aio.Channel: A gRPC AsyncIO channel object.
+ """
+
+ return grpc_helpers_async.create_channel(
+ host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ quota_project_id=quota_project_id,
+ default_scopes=cls.AUTH_SCOPES,
+ scopes=scopes,
+ default_host=cls.DEFAULT_HOST,
+ **kwargs,
+ )
+
+ def __init__(
+ self,
+ *,
+ host: str = "aiplatform.googleapis.com",
+ credentials: ga_credentials.Credentials = None,
+ credentials_file: Optional[str] = None,
+ scopes: Optional[Sequence[str]] = None,
+ channel: aio.Channel = None,
+ api_mtls_endpoint: str = None,
+ client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+ ssl_channel_credentials: grpc.ChannelCredentials = None,
+ client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
+        quota_project_id: Optional[str] = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ always_use_jwt_access: Optional[bool] = False,
+ ) -> None:
+ """Instantiate the transport.
+
+ Args:
+ host (Optional[str]):
+ The hostname to connect to.
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ This argument is ignored if ``channel`` is provided.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+ service. These are only used when credentials are not specified and
+ are passed to :func:`google.auth.default`.
+ channel (Optional[aio.Channel]): A ``Channel`` instance through
+ which to make calls.
+ api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+ If provided, it overrides the ``host`` argument and tries to create
+ a mutual TLS channel with client SSL credentials from
+ ``client_cert_source`` or application default SSL credentials.
+ client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ Deprecated. A callback to provide client SSL certificate bytes and
+ private key bytes, both in PEM format. It is ignored if
+ ``api_mtls_endpoint`` is None.
+ ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+ for the grpc channel. It is ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ A callback to provide client certificate bytes and private key bytes,
+ both in PEM format. It is used to configure a mutual TLS channel. It is
+ ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
+            always_use_jwt_access (Optional[bool]): Whether a self-signed JWT should
+ be used for service account credentials.
+
+ Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+ creation failed for any reason.
+ google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+ and ``credentials_file`` are passed.
+ """
+ self._grpc_channel = None
+ self._ssl_channel_credentials = ssl_channel_credentials
+ self._stubs: Dict[str, Callable] = {}
+
+ if api_mtls_endpoint:
+ warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+ if client_cert_source:
+ warnings.warn("client_cert_source is deprecated", DeprecationWarning)
+
+ if channel:
+ # Ignore credentials if a channel was passed.
+ credentials = False
+ # If a channel was explicitly provided, set it.
+ self._grpc_channel = channel
+ self._ssl_channel_credentials = None
+ else:
+ if api_mtls_endpoint:
+ host = api_mtls_endpoint
+
+ # Create SSL credentials with client_cert_source or application
+ # default SSL credentials.
+ if client_cert_source:
+ cert, key = client_cert_source()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ else:
+ self._ssl_channel_credentials = SslCredentials().ssl_credentials
+
+ else:
+ if client_cert_source_for_mtls and not ssl_channel_credentials:
+ cert, key = client_cert_source_for_mtls()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+
+ # The base transport sets the host, credentials and scopes
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ scopes=scopes,
+ quota_project_id=quota_project_id,
+ client_info=client_info,
+ always_use_jwt_access=always_use_jwt_access,
+ )
+
+ if not self._grpc_channel:
+ self._grpc_channel = type(self).create_channel(
+ self._host,
+ # use the credentials which are saved
+ credentials=self._credentials,
+ # Set ``credentials_file`` to ``None`` here as
+ # the credentials that we saved earlier should be used.
+ credentials_file=None,
+ scopes=self._scopes,
+ ssl_credentials=self._ssl_channel_credentials,
+ quota_project_id=quota_project_id,
+ options=[
+ ("grpc.max_send_message_length", -1),
+ ("grpc.max_receive_message_length", -1),
+ ],
+ )
+
+ # Wrap messages. This must be done after self._grpc_channel exists
+ self._prep_wrapped_messages(client_info)
+
+ @property
+ def grpc_channel(self) -> aio.Channel:
+ """Create the channel designed to connect to this service.
+
+ This property caches on the instance; repeated calls return
+ the same channel.
+ """
+ # Return the channel from cache.
+ return self._grpc_channel
+
+ @property
+ def read_feature_values(
+ self,
+ ) -> Callable[
+ [featurestore_online_service.ReadFeatureValuesRequest],
+ Awaitable[featurestore_online_service.ReadFeatureValuesResponse],
+ ]:
+ r"""Return a callable for the read feature values method over gRPC.
+
+ Reads Feature values of a specific entity of an
+ EntityType. For reading feature values of multiple
+ entities of an EntityType, please use
+ StreamingReadFeatureValues.
+
+ Returns:
+ Callable[[~.ReadFeatureValuesRequest],
+ Awaitable[~.ReadFeatureValuesResponse]]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "read_feature_values" not in self._stubs:
+ self._stubs["read_feature_values"] = self.grpc_channel.unary_unary(
+ "/google.cloud.aiplatform.v1.FeaturestoreOnlineServingService/ReadFeatureValues",
+ request_serializer=featurestore_online_service.ReadFeatureValuesRequest.serialize,
+ response_deserializer=featurestore_online_service.ReadFeatureValuesResponse.deserialize,
+ )
+ return self._stubs["read_feature_values"]
+
+ @property
+ def streaming_read_feature_values(
+ self,
+ ) -> Callable[
+ [featurestore_online_service.StreamingReadFeatureValuesRequest],
+ Awaitable[featurestore_online_service.ReadFeatureValuesResponse],
+ ]:
+ r"""Return a callable for the streaming read feature values method over gRPC.
+
+ Reads Feature values for multiple entities. Depending
+ on their size, data for different entities may be broken
+ up across multiple responses.
+
+ Returns:
+ Callable[[~.StreamingReadFeatureValuesRequest],
+ Awaitable[~.ReadFeatureValuesResponse]]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "streaming_read_feature_values" not in self._stubs:
+ self._stubs[
+ "streaming_read_feature_values"
+ ] = self.grpc_channel.unary_stream(
+ "/google.cloud.aiplatform.v1.FeaturestoreOnlineServingService/StreamingReadFeatureValues",
+ request_serializer=featurestore_online_service.StreamingReadFeatureValuesRequest.serialize,
+ response_deserializer=featurestore_online_service.ReadFeatureValuesResponse.deserialize,
+ )
+ return self._stubs["streaming_read_feature_values"]
+
+ def close(self):
+ return self.grpc_channel.close()
+
+ @property
+ def delete_operation(
+ self,
+ ) -> Callable[[operations_pb2.DeleteOperationRequest], None]:
+ r"""Return a callable for the delete_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "delete_operation" not in self._stubs:
+ self._stubs["delete_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/DeleteOperation",
+ request_serializer=operations_pb2.DeleteOperationRequest.SerializeToString,
+ response_deserializer=None,
+ )
+ return self._stubs["delete_operation"]
+
+ @property
+ def cancel_operation(
+ self,
+ ) -> Callable[[operations_pb2.CancelOperationRequest], None]:
+ r"""Return a callable for the cancel_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "cancel_operation" not in self._stubs:
+ self._stubs["cancel_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/CancelOperation",
+ request_serializer=operations_pb2.CancelOperationRequest.SerializeToString,
+ response_deserializer=None,
+ )
+ return self._stubs["cancel_operation"]
+
+ @property
+ def wait_operation(
+ self,
+ ) -> Callable[[operations_pb2.WaitOperationRequest], None]:
+ r"""Return a callable for the wait_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "delete_operation" not in self._stubs:
+ self._stubs["wait_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/WaitOperation",
+ request_serializer=operations_pb2.WaitOperationRequest.SerializeToString,
+ response_deserializer=None,
+ )
+ return self._stubs["wait_operation"]
+
+ @property
+ def get_operation(
+ self,
+ ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]:
+ r"""Return a callable for the get_operation method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_operation" not in self._stubs:
+ self._stubs["get_operation"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/GetOperation",
+ request_serializer=operations_pb2.GetOperationRequest.SerializeToString,
+ response_deserializer=operations_pb2.Operation.FromString,
+ )
+ return self._stubs["get_operation"]
+
+ @property
+ def list_operations(
+ self,
+ ) -> Callable[
+ [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse
+ ]:
+ r"""Return a callable for the list_operations method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "list_operations" not in self._stubs:
+ self._stubs["list_operations"] = self.grpc_channel.unary_unary(
+ "/google.longrunning.Operations/ListOperations",
+ request_serializer=operations_pb2.ListOperationsRequest.SerializeToString,
+ response_deserializer=operations_pb2.ListOperationsResponse.FromString,
+ )
+ return self._stubs["list_operations"]
+
+ @property
+ def list_locations(
+ self,
+ ) -> Callable[
+ [locations_pb2.ListLocationsRequest], locations_pb2.ListLocationsResponse
+ ]:
+ r"""Return a callable for the list locations method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "list_locations" not in self._stubs:
+ self._stubs["list_locations"] = self.grpc_channel.unary_unary(
+ "/google.cloud.location.Locations/ListLocations",
+ request_serializer=locations_pb2.ListLocationsRequest.SerializeToString,
+ response_deserializer=locations_pb2.ListLocationsResponse.FromString,
+ )
+ return self._stubs["list_locations"]
+
+ @property
+ def get_location(
+ self,
+ ) -> Callable[[locations_pb2.GetLocationRequest], locations_pb2.Location]:
+ r"""Return a callable for the list locations method over gRPC."""
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_location" not in self._stubs:
+ self._stubs["get_location"] = self.grpc_channel.unary_unary(
+ "/google.cloud.location.Locations/GetLocation",
+ request_serializer=locations_pb2.GetLocationRequest.SerializeToString,
+ response_deserializer=locations_pb2.Location.FromString,
+ )
+ return self._stubs["get_location"]
+
+ @property
+ def set_iam_policy(
+ self,
+ ) -> Callable[[iam_policy_pb2.SetIamPolicyRequest], policy_pb2.Policy]:
+ r"""Return a callable for the set iam policy method over gRPC.
+ Sets the IAM access control policy on the specified
+ function. Replaces any existing policy.
+ Returns:
+ Callable[[~.SetIamPolicyRequest],
+ ~.Policy]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "set_iam_policy" not in self._stubs:
+ self._stubs["set_iam_policy"] = self.grpc_channel.unary_unary(
+ "/google.iam.v1.IAMPolicy/SetIamPolicy",
+ request_serializer=iam_policy_pb2.SetIamPolicyRequest.SerializeToString,
+ response_deserializer=policy_pb2.Policy.FromString,
+ )
+ return self._stubs["set_iam_policy"]
+
+ @property
+ def get_iam_policy(
+ self,
+ ) -> Callable[[iam_policy_pb2.GetIamPolicyRequest], policy_pb2.Policy]:
+ r"""Return a callable for the get iam policy method over gRPC.
+ Gets the IAM access control policy for a function.
+ Returns an empty policy if the function exists and does
+ not have a policy set.
+ Returns:
+ Callable[[~.GetIamPolicyRequest],
+ ~.Policy]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "get_iam_policy" not in self._stubs:
+ self._stubs["get_iam_policy"] = self.grpc_channel.unary_unary(
+ "/google.iam.v1.IAMPolicy/GetIamPolicy",
+ request_serializer=iam_policy_pb2.GetIamPolicyRequest.SerializeToString,
+ response_deserializer=policy_pb2.Policy.FromString,
+ )
+ return self._stubs["get_iam_policy"]
+
+ @property
+ def test_iam_permissions(
+ self,
+ ) -> Callable[
+ [iam_policy_pb2.TestIamPermissionsRequest],
+ iam_policy_pb2.TestIamPermissionsResponse,
+ ]:
+ r"""Return a callable for the test iam permissions method over gRPC.
+ Tests the specified permissions against the IAM access control
+ policy for a function. If the function does not exist, this will
+ return an empty set of permissions, not a NOT_FOUND error.
+ Returns:
+ Callable[[~.TestIamPermissionsRequest],
+ ~.TestIamPermissionsResponse]:
+ A function that, when called, will call the underlying RPC
+ on the server.
+ """
+ # Generate a "stub function" on-the-fly which will actually make
+ # the request.
+ # gRPC handles serialization and deserialization, so we just need
+ # to pass in the functions for each.
+ if "test_iam_permissions" not in self._stubs:
+ self._stubs["test_iam_permissions"] = self.grpc_channel.unary_unary(
+ "/google.iam.v1.IAMPolicy/TestIamPermissions",
+ request_serializer=iam_policy_pb2.TestIamPermissionsRequest.SerializeToString,
+ response_deserializer=iam_policy_pb2.TestIamPermissionsResponse.FromString,
+ )
+ return self._stubs["test_iam_permissions"]
+
+
+__all__ = ("FeaturestoreOnlineServingServiceGrpcAsyncIOTransport",)
diff --git a/google/cloud/aiplatform_v1/services/featurestore_service/__init__.py b/google/cloud/aiplatform_v1/services/featurestore_service/__init__.py
new file mode 100644
index 0000000000..3f53c57568
--- /dev/null
+++ b/google/cloud/aiplatform_v1/services/featurestore_service/__init__.py
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from .client import FeaturestoreServiceClient
+from .async_client import FeaturestoreServiceAsyncClient
+
+__all__ = (
+ "FeaturestoreServiceClient",
+ "FeaturestoreServiceAsyncClient",
+)
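Reviewer note: the package ``__init__`` re-exports both clients, so callers can import them directly from the service package:

.. code-block:: python

    from google.cloud.aiplatform_v1.services.featurestore_service import (
        FeaturestoreServiceAsyncClient,
        FeaturestoreServiceClient,
    )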
diff --git a/google/cloud/aiplatform_v1/services/featurestore_service/async_client.py b/google/cloud/aiplatform_v1/services/featurestore_service/async_client.py
new file mode 100644
index 0000000000..c2f7102c03
--- /dev/null
+++ b/google/cloud/aiplatform_v1/services/featurestore_service/async_client.py
@@ -0,0 +1,3464 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from collections import OrderedDict
+import functools
+import re
+from typing import Dict, Mapping, Optional, Sequence, Tuple, Type, Union
+import pkg_resources
+
+from google.api_core.client_options import ClientOptions
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core import retry as retries
+from google.auth import credentials as ga_credentials # type: ignore
+from google.oauth2 import service_account # type: ignore
+
+try:
+ OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
+except AttributeError: # pragma: NO COVER
+ OptionalRetry = Union[retries.Retry, object] # type: ignore
+
+from google.api_core import operation as gac_operation # type: ignore
+from google.api_core import operation_async # type: ignore
+from google.cloud.aiplatform_v1.services.featurestore_service import pagers
+from google.cloud.aiplatform_v1.types import encryption_spec
+from google.cloud.aiplatform_v1.types import entity_type
+from google.cloud.aiplatform_v1.types import entity_type as gca_entity_type
+from google.cloud.aiplatform_v1.types import feature
+from google.cloud.aiplatform_v1.types import feature as gca_feature
+from google.cloud.aiplatform_v1.types import featurestore
+from google.cloud.aiplatform_v1.types import featurestore as gca_featurestore
+from google.cloud.aiplatform_v1.types import featurestore_monitoring
+from google.cloud.aiplatform_v1.types import featurestore_service
+from google.cloud.aiplatform_v1.types import operation as gca_operation
+from google.cloud.location import locations_pb2 # type: ignore
+from google.iam.v1 import iam_policy_pb2 # type: ignore
+from google.iam.v1 import policy_pb2 # type: ignore
+from google.longrunning import operations_pb2
+from google.protobuf import empty_pb2 # type: ignore
+from google.protobuf import field_mask_pb2 # type: ignore
+from google.protobuf import timestamp_pb2 # type: ignore
+from .transports.base import FeaturestoreServiceTransport, DEFAULT_CLIENT_INFO
+from .transports.grpc_asyncio import FeaturestoreServiceGrpcAsyncIOTransport
+from .client import FeaturestoreServiceClient
+
+
+class FeaturestoreServiceAsyncClient:
+ """The service that handles CRUD and List for resources for
+ Featurestore.
+ """
+
+ _client: FeaturestoreServiceClient
+
+ DEFAULT_ENDPOINT = FeaturestoreServiceClient.DEFAULT_ENDPOINT
+ DEFAULT_MTLS_ENDPOINT = FeaturestoreServiceClient.DEFAULT_MTLS_ENDPOINT
+
+ entity_type_path = staticmethod(FeaturestoreServiceClient.entity_type_path)
+ parse_entity_type_path = staticmethod(
+ FeaturestoreServiceClient.parse_entity_type_path
+ )
+ feature_path = staticmethod(FeaturestoreServiceClient.feature_path)
+ parse_feature_path = staticmethod(FeaturestoreServiceClient.parse_feature_path)
+ featurestore_path = staticmethod(FeaturestoreServiceClient.featurestore_path)
+ parse_featurestore_path = staticmethod(
+ FeaturestoreServiceClient.parse_featurestore_path
+ )
+ common_billing_account_path = staticmethod(
+ FeaturestoreServiceClient.common_billing_account_path
+ )
+ parse_common_billing_account_path = staticmethod(
+ FeaturestoreServiceClient.parse_common_billing_account_path
+ )
+ common_folder_path = staticmethod(FeaturestoreServiceClient.common_folder_path)
+ parse_common_folder_path = staticmethod(
+ FeaturestoreServiceClient.parse_common_folder_path
+ )
+ common_organization_path = staticmethod(
+ FeaturestoreServiceClient.common_organization_path
+ )
+ parse_common_organization_path = staticmethod(
+ FeaturestoreServiceClient.parse_common_organization_path
+ )
+ common_project_path = staticmethod(FeaturestoreServiceClient.common_project_path)
+ parse_common_project_path = staticmethod(
+ FeaturestoreServiceClient.parse_common_project_path
+ )
+ common_location_path = staticmethod(FeaturestoreServiceClient.common_location_path)
+ parse_common_location_path = staticmethod(
+ FeaturestoreServiceClient.parse_common_location_path
+ )
+
+ @classmethod
+ def from_service_account_info(cls, info: dict, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials
+ info.
+
+ Args:
+ info (dict): The service account private key info.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ FeaturestoreServiceAsyncClient: The constructed client.
+ """
+ return FeaturestoreServiceClient.from_service_account_info.__func__(FeaturestoreServiceAsyncClient, info, *args, **kwargs) # type: ignore
+
+ @classmethod
+ def from_service_account_file(cls, filename: str, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials
+ file.
+
+ Args:
+ filename (str): The path to the service account private key json
+ file.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ FeaturestoreServiceAsyncClient: The constructed client.
+ """
+ return FeaturestoreServiceClient.from_service_account_file.__func__(FeaturestoreServiceAsyncClient, filename, *args, **kwargs) # type: ignore
+
+ from_service_account_json = from_service_account_file
+
+ @classmethod
+ def get_mtls_endpoint_and_cert_source(
+ cls, client_options: Optional[ClientOptions] = None
+ ):
+ """Return the API endpoint and client cert source for mutual TLS.
+
+ The client cert source is determined in the following order:
+ (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
+ client cert source is None.
+ (2) if `client_options.client_cert_source` is provided, use the provided one; if the
+ default client cert source exists, use the default one; otherwise the client cert
+ source is None.
+
+        The API endpoint is determined in the following order:
+        (1) if `client_options.api_endpoint` is provided, use the provided one.
+        (2) if the `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the
+        default mTLS endpoint; if the environment variable is "never", use the default API
+        endpoint; otherwise, if a client cert source exists, use the default mTLS
+        endpoint; otherwise use the default API endpoint.
+
+ More details can be found at https://google.aip.dev/auth/4114.
+
+ Args:
+ client_options (google.api_core.client_options.ClientOptions): Custom options for the
+ client. Only the `api_endpoint` and `client_cert_source` properties may be used
+ in this method.
+
+ Returns:
+ Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
+ client cert source to use.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If any errors happen.
+ """
+ return FeaturestoreServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore
+
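Reviewer note: a sketch of calling the helper above; with an empty ``ClientOptions`` the resolution falls back entirely to the environment variables described in the docstring:

.. code-block:: python

    from google.api_core.client_options import ClientOptions
    from google.cloud.aiplatform_v1 import FeaturestoreServiceAsyncClient

    endpoint, cert_source = FeaturestoreServiceAsyncClient.get_mtls_endpoint_and_cert_source(
        ClientOptions()
    )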
+ @property
+ def transport(self) -> FeaturestoreServiceTransport:
+ """Returns the transport used by the client instance.
+
+ Returns:
+ FeaturestoreServiceTransport: The transport used by the client instance.
+ """
+ return self._client.transport
+
+ get_transport_class = functools.partial(
+ type(FeaturestoreServiceClient).get_transport_class,
+ type(FeaturestoreServiceClient),
+ )
+
+ def __init__(
+ self,
+ *,
+ credentials: ga_credentials.Credentials = None,
+ transport: Union[str, FeaturestoreServiceTransport] = "grpc_asyncio",
+ client_options: ClientOptions = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ ) -> None:
+ """Instantiates the featurestore service client.
+
+ Args:
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ transport (Union[str, ~.FeaturestoreServiceTransport]): The
+ transport to use. If set to None, a transport is chosen
+ automatically.
+ client_options (ClientOptions): Custom options for the client. It
+ won't take effect if a ``transport`` instance is provided.
+ (1) The ``api_endpoint`` property can be used to override the
+ default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+ environment variable can also be used to override the endpoint:
+ "always" (always use the default mTLS endpoint), "never" (always
+ use the default regular endpoint) and "auto" (auto switch to the
+ default mTLS endpoint if client certificate is present, this is
+ the default value). However, the ``api_endpoint`` property takes
+ precedence if provided.
+ (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+ is "true", then the ``client_cert_source`` property can be used
+ to provide client certificate for mutual TLS transport. If
+ not provided, the default SSL client certificate will be used if
+ present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+ set, no client certificate will be used.
+
+ Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+ creation failed for any reason.
+ """
+ self._client = FeaturestoreServiceClient(
+ credentials=credentials,
+ transport=transport,
+ client_options=client_options,
+ client_info=client_info,
+ )
+
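Reviewer note: a sketch of the ``client_options`` endpoint override described above; the regional host shown is illustrative of Vertex AI's regional endpoints:

.. code-block:: python

    from google.api_core.client_options import ClientOptions
    from google.cloud.aiplatform_v1 import FeaturestoreServiceAsyncClient

    client = FeaturestoreServiceAsyncClient(
        client_options=ClientOptions(
            api_endpoint="us-central1-aiplatform.googleapis.com"
        ),
    )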
+ async def create_featurestore(
+ self,
+ request: Union[featurestore_service.CreateFeaturestoreRequest, dict] = None,
+ *,
+ parent: str = None,
+ featurestore: gca_featurestore.Featurestore = None,
+ featurestore_id: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation_async.AsyncOperation:
+ r"""Creates a new Featurestore in a given project and
+ location.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_create_featurestore():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.CreateFeaturestoreRequest(
+ parent="parent_value",
+ featurestore_id="featurestore_id_value",
+ )
+
+ # Make the request
+ operation = client.create_featurestore(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = await operation.result()
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.CreateFeaturestoreRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.CreateFeaturestore][google.cloud.aiplatform.v1.FeaturestoreService.CreateFeaturestore].
+ parent (:class:`str`):
+ Required. The resource name of the Location to create
+ Featurestores. Format:
+                ``projects/{project}/locations/{location}``
+
+ This corresponds to the ``parent`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ featurestore (:class:`google.cloud.aiplatform_v1.types.Featurestore`):
+ Required. The Featurestore to create.
+ This corresponds to the ``featurestore`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ featurestore_id (:class:`str`):
+ Required. The ID to use for this Featurestore, which
+ will become the final component of the Featurestore's
+ resource name.
+
+ This value may be up to 60 characters, and valid
+ characters are ``[a-z0-9_]``. The first character cannot
+ be a number.
+
+ The value must be unique within the project and
+ location.
+
+ This corresponds to the ``featurestore_id`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+                The result type for the operation will be :class:`google.cloud.aiplatform_v1.types.Featurestore`. Vertex AI Feature Store provides a centralized repository for organizing,
+ storing, and serving ML features. The Featurestore is
+ a top-level container for your features and their
+ values.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([parent, featurestore, featurestore_id])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.CreateFeaturestoreRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if parent is not None:
+ request.parent = parent
+ if featurestore is not None:
+ request.featurestore = featurestore
+ if featurestore_id is not None:
+ request.featurestore_id = featurestore_id
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.create_featurestore,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ gca_featurestore.Featurestore,
+ metadata_type=featurestore_service.CreateFeaturestoreOperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def get_featurestore(
+ self,
+ request: Union[featurestore_service.GetFeaturestoreRequest, dict] = None,
+ *,
+ name: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> featurestore.Featurestore:
+ r"""Gets details of a single Featurestore.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_get_featurestore():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.GetFeaturestoreRequest(
+ name="name_value",
+ )
+
+ # Make the request
+ response = await client.get_featurestore(request=request)
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.GetFeaturestoreRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.GetFeaturestore][google.cloud.aiplatform.v1.FeaturestoreService.GetFeaturestore].
+ name (:class:`str`):
+ Required. The name of the
+ Featurestore resource.
+
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.aiplatform_v1.types.Featurestore:
+ Vertex AI Feature Store provides a
+ centralized repository for organizing,
+ storing, and serving ML features. The
+ Featurestore is a top-level container
+ for your features and their values.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.GetFeaturestoreRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if name is not None:
+ request.name = name
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.get_featurestore,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def list_featurestores(
+ self,
+ request: Union[featurestore_service.ListFeaturestoresRequest, dict] = None,
+ *,
+ parent: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> pagers.ListFeaturestoresAsyncPager:
+ r"""Lists Featurestores in a given project and location.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_list_featurestores():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.ListFeaturestoresRequest(
+ parent="parent_value",
+ )
+
+ # Make the request
+ page_result = client.list_featurestores(request=request)
+
+ # Handle the response
+ async for response in page_result:
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.ListFeaturestoresRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.ListFeaturestores][google.cloud.aiplatform.v1.FeaturestoreService.ListFeaturestores].
+ parent (:class:`str`):
+ Required. The resource name of the Location to list
+ Featurestores. Format:
+ ``projects/{project}/locations/{location}``
+
+ This corresponds to the ``parent`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.aiplatform_v1.services.featurestore_service.pagers.ListFeaturestoresAsyncPager:
+ Response message for
+ [FeaturestoreService.ListFeaturestores][google.cloud.aiplatform.v1.FeaturestoreService.ListFeaturestores].
+
+ Iterating over this object will yield results and
+ resolve additional pages automatically.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([parent])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.ListFeaturestoresRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if parent is not None:
+ request.parent = parent
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.list_featurestores,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # This method is paged; wrap the response in a pager, which provides
+ # an `__aiter__` convenience method.
+ response = pagers.ListFeaturestoresAsyncPager(
+ method=rpc,
+ request=request,
+ response=response,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def update_featurestore(
+ self,
+ request: Union[featurestore_service.UpdateFeaturestoreRequest, dict] = None,
+ *,
+ featurestore: gca_featurestore.Featurestore = None,
+ update_mask: field_mask_pb2.FieldMask = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation_async.AsyncOperation:
+ r"""Updates the parameters of a single Featurestore.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_update_featurestore():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.UpdateFeaturestoreRequest()
+
+ # Make the request
+ operation = await client.update_featurestore(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = await operation.result()
+
+ # Handle the response
+ print(response)
+
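+ A partial update is expressed with a ``FieldMask``; for example, to
+ overwrite only ``labels`` (a sketch using the standard protobuf type):
+
+ .. code-block:: python
+
+ from google.protobuf import field_mask_pb2
+
+ update_mask = field_mask_pb2.FieldMask(paths=["labels"])
+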
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.UpdateFeaturestoreRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.UpdateFeaturestore][google.cloud.aiplatform.v1.FeaturestoreService.UpdateFeaturestore].
+ featurestore (:class:`google.cloud.aiplatform_v1.types.Featurestore`):
+ Required. The Featurestore's ``name`` field is used to
+ identify the Featurestore to be updated. Format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}``
+
+ This corresponds to the ``featurestore`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`):
+ Field mask is used to specify the fields to be
+ overwritten in the Featurestore resource by the update.
+ The fields specified in the update_mask are relative to
+ the resource, not the full request. A field will be
+ overwritten if it is in the mask. If the user does not
+ provide a mask then only the non-empty fields present in
+ the request will be overwritten. Set the update_mask to
+ ``*`` to override all fields.
+
+ Updatable fields:
+
+ - ``labels``
+ - ``online_serving_config.fixed_node_count``
+ - ``online_serving_config.scaling``
+
+ This corresponds to the ``update_mask`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be :class:`google.cloud.aiplatform_v1.types.Featurestore`. Vertex AI Feature Store provides a centralized repository for organizing,
+ storing, and serving ML features. The Featurestore is
+ a top-level container for your features and their
+ values.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([featurestore, update_mask])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.UpdateFeaturestoreRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if featurestore is not None:
+ request.featurestore = featurestore
+ if update_mask is not None:
+ request.update_mask = update_mask
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.update_featurestore,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata(
+ (("featurestore.name", request.featurestore.name),)
+ ),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ gca_featurestore.Featurestore,
+ metadata_type=featurestore_service.UpdateFeaturestoreOperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def delete_featurestore(
+ self,
+ request: Union[featurestore_service.DeleteFeaturestoreRequest, dict] = None,
+ *,
+ name: str = None,
+ force: bool = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation_async.AsyncOperation:
+ r"""Deletes a single Featurestore. The Featurestore must not contain
+ any EntityTypes or ``force`` must be set to true for the request
+ to succeed.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_delete_featurestore():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.DeleteFeaturestoreRequest(
+ name="name_value",
+ )
+
+ # Make the request
+ operation = await client.delete_featurestore(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = await operation.result()
+
+ # Handle the response
+ print(response)
+
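+ Passing ``force=True`` through the flattened arguments also deletes
+ any child EntityTypes; for example (the resource path is a
+ placeholder):
+
+ .. code-block:: python
+
+ operation = await client.delete_featurestore(
+ name="projects/my-project/locations/us-central1/featurestores/my_featurestore",
+ force=True,
+ )
+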
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.DeleteFeaturestoreRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.DeleteFeaturestore][google.cloud.aiplatform.v1.FeaturestoreService.DeleteFeaturestore].
+ name (:class:`str`):
+ Required. The name of the Featurestore to be deleted.
+ Format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}``
+
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ force (:class:`bool`):
+ If set to true, any EntityTypes and
+ Features for this Featurestore will also
+ be deleted. (Otherwise, the request will
+ only work if the Featurestore has no
+ EntityTypes.)
+
+ This corresponds to the ``force`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty`, a generic empty message that you can re-use to avoid defining duplicated
+ empty messages in your APIs. A typical example is to
+ use it as the request or the response type of an API
+ method. For instance::
+
+ service Foo {
+ rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty);
+ }
+
+ The JSON representation for ``Empty`` is an empty
+ JSON object ``{}``.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name, force])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.DeleteFeaturestoreRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if name is not None:
+ request.name = name
+ if force is not None:
+ request.force = force
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.delete_featurestore,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ empty_pb2.Empty,
+ metadata_type=gca_operation.DeleteOperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def create_entity_type(
+ self,
+ request: Union[featurestore_service.CreateEntityTypeRequest, dict] = None,
+ *,
+ parent: str = None,
+ entity_type: gca_entity_type.EntityType = None,
+ entity_type_id: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation_async.AsyncOperation:
+ r"""Creates a new EntityType in a given Featurestore.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_create_entity_type():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.CreateEntityTypeRequest(
+ parent="parent_value",
+ entity_type_id="entity_type_id_value",
+ )
+
+ # Make the request
+ operation = await client.create_entity_type(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = await operation.result()
+
+ # Handle the response
+ print(response)
+
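+ The flattened form accepts the parent Featurestore and the new ID
+ directly; for example (values are placeholders):
+
+ .. code-block:: python
+
+ operation = await client.create_entity_type(
+ parent="projects/my-project/locations/us-central1/featurestores/my_featurestore",
+ entity_type_id="driver",
+ )
+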
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.CreateEntityTypeRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.CreateEntityType][google.cloud.aiplatform.v1.FeaturestoreService.CreateEntityType].
+ parent (:class:`str`):
+ Required. The resource name of the Featurestore to
+ create EntityTypes. Format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}``
+
+ This corresponds to the ``parent`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ entity_type (:class:`google.cloud.aiplatform_v1.types.EntityType`):
+ The EntityType to create.
+ This corresponds to the ``entity_type`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ entity_type_id (:class:`str`):
+ Required. The ID to use for the EntityType, which will
+ become the final component of the EntityType's resource
+ name.
+
+ This value may be up to 60 characters, and valid
+ characters are ``[a-z0-9_]``. The first character cannot
+ be a number.
+
+ The value must be unique within a featurestore.
+
+ This corresponds to the ``entity_type_id`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be :class:`google.cloud.aiplatform_v1.types.EntityType`. An entity type is a type of object in a system that needs to be modeled and
+ about which information needs to be stored. For example, driver is
+ an entity type, and driver0 is an instance of the
+ entity type driver.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([parent, entity_type, entity_type_id])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.CreateEntityTypeRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if parent is not None:
+ request.parent = parent
+ if entity_type is not None:
+ request.entity_type = entity_type
+ if entity_type_id is not None:
+ request.entity_type_id = entity_type_id
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.create_entity_type,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ gca_entity_type.EntityType,
+ metadata_type=featurestore_service.CreateEntityTypeOperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def get_entity_type(
+ self,
+ request: Union[featurestore_service.GetEntityTypeRequest, dict] = None,
+ *,
+ name: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> entity_type.EntityType:
+ r"""Gets details of a single EntityType.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_get_entity_type():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.GetEntityTypeRequest(
+ name="name_value",
+ )
+
+ # Make the request
+ response = await client.get_entity_type(request=request)
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.GetEntityTypeRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.GetEntityType][google.cloud.aiplatform.v1.FeaturestoreService.GetEntityType].
+ name (:class:`str`):
+ Required. The name of the EntityType resource. Format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}``
+
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.aiplatform_v1.types.EntityType:
+ An entity type is a type of object in
+ a system that needs to be modeled and
+ about which information needs to be
+ stored. For example, driver is an entity
+ type, and driver0 is an instance of the
+ entity type driver.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.GetEntityTypeRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if name is not None:
+ request.name = name
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.get_entity_type,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def list_entity_types(
+ self,
+ request: Union[featurestore_service.ListEntityTypesRequest, dict] = None,
+ *,
+ parent: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> pagers.ListEntityTypesAsyncPager:
+ r"""Lists EntityTypes in a given Featurestore.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_list_entity_types():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.ListEntityTypesRequest(
+ parent="parent_value",
+ )
+
+ # Make the request
+ page_result = await client.list_entity_types(request=request)
+
+ # Handle the response
+ async for response in page_result:
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.ListEntityTypesRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.ListEntityTypes][google.cloud.aiplatform.v1.FeaturestoreService.ListEntityTypes].
+ parent (:class:`str`):
+ Required. The resource name of the Featurestore to list
+ EntityTypes. Format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}``
+
+ This corresponds to the ``parent`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.aiplatform_v1.services.featurestore_service.pagers.ListEntityTypesAsyncPager:
+ Response message for
+ [FeaturestoreService.ListEntityTypes][google.cloud.aiplatform.v1.FeaturestoreService.ListEntityTypes].
+
+ Iterating over this object will yield results and
+ resolve additional pages automatically.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([parent])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.ListEntityTypesRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if parent is not None:
+ request.parent = parent
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.list_entity_types,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # This method is paged; wrap the response in a pager, which provides
+ # an `__aiter__` convenience method.
+ response = pagers.ListEntityTypesAsyncPager(
+ method=rpc,
+ request=request,
+ response=response,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def update_entity_type(
+ self,
+ request: Union[featurestore_service.UpdateEntityTypeRequest, dict] = None,
+ *,
+ entity_type: gca_entity_type.EntityType = None,
+ update_mask: field_mask_pb2.FieldMask = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> gca_entity_type.EntityType:
+ r"""Updates the parameters of a single EntityType.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_update_entity_type():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.UpdateEntityTypeRequest()
+
+ # Make the request
+ response = await client.update_entity_type(request=request)
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.UpdateEntityTypeRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.UpdateEntityType][google.cloud.aiplatform.v1.FeaturestoreService.UpdateEntityType].
+ entity_type (:class:`google.cloud.aiplatform_v1.types.EntityType`):
+ Required. The EntityType's ``name`` field is used to
+ identify the EntityType to be updated. Format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}``
+
+ This corresponds to the ``entity_type`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`):
+ Field mask is used to specify the fields to be
+ overwritten in the EntityType resource by the update.
+ The fields specified in the update_mask are relative to
+ the resource, not the full request. A field will be
+ overwritten if it is in the mask. If the user does not
+ provide a mask then only the non-empty fields present in
+ the request will be overwritten. Set the update_mask to
+ ``*`` to override all fields.
+
+ Updatable fields:
+
+ - ``description``
+ - ``labels``
+ - ``monitoring_config.snapshot_analysis.disabled``
+ - ``monitoring_config.snapshot_analysis.monitoring_interval_days``
+ - ``monitoring_config.snapshot_analysis.staleness_days``
+ - ``monitoring_config.import_features_analysis.state``
+ - ``monitoring_config.import_features_analysis.anomaly_detection_baseline``
+ - ``monitoring_config.numerical_threshold_config.value``
+ - ``monitoring_config.categorical_threshold_config.value``
+
+ This corresponds to the ``update_mask`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.aiplatform_v1.types.EntityType:
+ An entity type is a type of object in
+ a system that needs to be modeled and
+ about which information needs to be
+ stored. For example, driver is an entity
+ type, and driver0 is an instance of the
+ entity type driver.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([entity_type, update_mask])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.UpdateEntityTypeRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if entity_type is not None:
+ request.entity_type = entity_type
+ if update_mask is not None:
+ request.update_mask = update_mask
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.update_entity_type,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata(
+ (("entity_type.name", request.entity_type.name),)
+ ),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def delete_entity_type(
+ self,
+ request: Union[featurestore_service.DeleteEntityTypeRequest, dict] = None,
+ *,
+ name: str = None,
+ force: bool = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation_async.AsyncOperation:
+ r"""Deletes a single EntityType. The EntityType must not have any
+ Features or ``force`` must be set to true for the request to
+ succeed.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_delete_entity_type():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.DeleteEntityTypeRequest(
+ name="name_value",
+ )
+
+ # Make the request
+ operation = await client.delete_entity_type(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = await operation.result()
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.DeleteEntityTypeRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.DeleteEntityType][google.cloud.aiplatform.v1.FeaturestoreService.DeleteEntityType].
+ name (:class:`str`):
+ Required. The name of the EntityType to be deleted.
+ Format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}``
+
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ force (:class:`bool`):
+ If set to true, any Features for this
+ EntityType will also be deleted.
+ (Otherwise, the request will only work
+ if the EntityType has no Features.)
+
+ This corresponds to the ``force`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty`, a generic empty message that you can re-use to avoid defining duplicated
+ empty messages in your APIs. A typical example is to
+ use it as the request or the response type of an API
+ method. For instance::
+
+ service Foo {
+ rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty);
+ }
+
+ The JSON representation for ``Empty`` is an empty
+ JSON object ``{}``.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name, force])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.DeleteEntityTypeRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if name is not None:
+ request.name = name
+ if force is not None:
+ request.force = force
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.delete_entity_type,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ empty_pb2.Empty,
+ metadata_type=gca_operation.DeleteOperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def create_feature(
+ self,
+ request: Union[featurestore_service.CreateFeatureRequest, dict] = None,
+ *,
+ parent: str = None,
+ feature: gca_feature.Feature = None,
+ feature_id: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation_async.AsyncOperation:
+ r"""Creates a new Feature in a given EntityType.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_create_feature():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ feature = aiplatform_v1.Feature()
+ feature.value_type = "BYTES"
+
+ request = aiplatform_v1.CreateFeatureRequest(
+ parent="parent_value",
+ feature=feature,
+ feature_id="feature_id_value",
+ )
+
+ # Make the request
+ operation = await client.create_feature(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = await operation.result()
+
+ # Handle the response
+ print(response)
+
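+ ``value_type`` may also be set from the enum rather than its string
+ name; a sketch:
+
+ .. code-block:: python
+
+ feature = aiplatform_v1.Feature()
+ feature.value_type = aiplatform_v1.Feature.ValueType.DOUBLE
+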
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.CreateFeatureRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.CreateFeature][google.cloud.aiplatform.v1.FeaturestoreService.CreateFeature].
+ parent (:class:`str`):
+ Required. The resource name of the EntityType to create
+ a Feature. Format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}``
+
+ This corresponds to the ``parent`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ feature (:class:`google.cloud.aiplatform_v1.types.Feature`):
+ Required. The Feature to create.
+ This corresponds to the ``feature`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ feature_id (:class:`str`):
+ Required. The ID to use for the Feature, which will
+ become the final component of the Feature's resource
+ name.
+
+ This value may be up to 60 characters, and valid
+ characters are ``[a-z0-9_]``. The first character cannot
+ be a number.
+
+ The value must be unique within an EntityType.
+
+ This corresponds to the ``feature_id`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be :class:`google.cloud.aiplatform_v1.types.Feature`. Feature metadata is information that describes an attribute of an entity type.
+ For example, apple is an entity type, and color is a
+ feature that describes apple.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([parent, feature, feature_id])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.CreateFeatureRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if parent is not None:
+ request.parent = parent
+ if feature is not None:
+ request.feature = feature
+ if feature_id is not None:
+ request.feature_id = feature_id
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.create_feature,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ gca_feature.Feature,
+ metadata_type=featurestore_service.CreateFeatureOperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def batch_create_features(
+ self,
+ request: Union[featurestore_service.BatchCreateFeaturesRequest, dict] = None,
+ *,
+ parent: str = None,
+ requests: Sequence[featurestore_service.CreateFeatureRequest] = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation_async.AsyncOperation:
+ r"""Creates a batch of Features in a given EntityType.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_batch_create_features():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ requests = aiplatform_v1.CreateFeatureRequest()
+ requests.parent = "parent_value"
+ requests.feature.value_type = "BYTES"
+ requests.feature_id = "feature_id_value"
+
+ request = aiplatform_v1.BatchCreateFeaturesRequest(
+ parent="parent_value",
+ requests=[requests],  # ``requests`` is a repeated field
+ )
+
+ # Make the request
+ operation = await client.batch_create_features(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = await operation.result()
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.BatchCreateFeaturesRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.BatchCreateFeatures][google.cloud.aiplatform.v1.FeaturestoreService.BatchCreateFeatures].
+ parent (:class:`str`):
+ Required. The resource name of the EntityType to create
+ the batch of Features under. Format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}``
+
+ This corresponds to the ``parent`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ requests (:class:`Sequence[google.cloud.aiplatform_v1.types.CreateFeatureRequest]`):
+ Required. The request message specifying the Features to
+ create. All Features must be created under the same
+ parent EntityType. The ``parent`` field in each child
+ request message can be omitted. If ``parent`` is set in
+ a child request, then the value must match the
+ ``parent`` value in this request message.
+
+ This corresponds to the ``requests`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be
+ :class:`google.cloud.aiplatform_v1.types.BatchCreateFeaturesResponse`
+ Response message for
+ [FeaturestoreService.BatchCreateFeatures][google.cloud.aiplatform.v1.FeaturestoreService.BatchCreateFeatures].
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([parent, requests])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.BatchCreateFeaturesRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if parent is not None:
+ request.parent = parent
+ if requests:
+ request.requests.extend(requests)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.batch_create_features,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ featurestore_service.BatchCreateFeaturesResponse,
+ metadata_type=featurestore_service.BatchCreateFeaturesOperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def get_feature(
+ self,
+ request: Union[featurestore_service.GetFeatureRequest, dict] = None,
+ *,
+ name: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> feature.Feature:
+ r"""Gets details of a single Feature.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_get_feature():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.GetFeatureRequest(
+ name="name_value",
+ )
+
+ # Make the request
+ response = await client.get_feature(request=request)
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.GetFeatureRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.GetFeature][google.cloud.aiplatform.v1.FeaturestoreService.GetFeature].
+ name (:class:`str`):
+ Required. The name of the Feature resource. Format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}/features/{feature}``
+
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.aiplatform_v1.types.Feature:
+ Feature metadata is information that
+ describes an attribute of an entity
+ type. For example, apple is an entity
+ type, and color is a feature that
+ describes apple.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.GetFeatureRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if name is not None:
+ request.name = name
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.get_feature,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def list_features(
+ self,
+ request: Union[featurestore_service.ListFeaturesRequest, dict] = None,
+ *,
+ parent: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> pagers.ListFeaturesAsyncPager:
+ r"""Lists Features in a given EntityType.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_list_features():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.ListFeaturesRequest(
+ parent="parent_value",
+ )
+
+ # Make the request
+ page_result = await client.list_features(request=request)
+
+ # Handle the response
+ async for response in page_result:
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.ListFeaturesRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.ListFeatures][google.cloud.aiplatform.v1.FeaturestoreService.ListFeatures].
+ parent (:class:`str`):
+ Required. The resource name of the EntityType to list
+ Features. Format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}``
+
+ This corresponds to the ``parent`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.aiplatform_v1.services.featurestore_service.pagers.ListFeaturesAsyncPager:
+ Response message for
+ [FeaturestoreService.ListFeatures][google.cloud.aiplatform.v1.FeaturestoreService.ListFeatures].
+
+ Iterating over this object will yield results and
+ resolve additional pages automatically.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([parent])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.ListFeaturesRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if parent is not None:
+ request.parent = parent
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.list_features,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # This method is paged; wrap the response in a pager, which provides
+ # an `__aiter__` convenience method.
+ response = pagers.ListFeaturesAsyncPager(
+ method=rpc,
+ request=request,
+ response=response,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def update_feature(
+ self,
+ request: Union[featurestore_service.UpdateFeatureRequest, dict] = None,
+ *,
+ feature: gca_feature.Feature = None,
+ update_mask: field_mask_pb2.FieldMask = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> gca_feature.Feature:
+ r"""Updates the parameters of a single Feature.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_update_feature():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ feature = aiplatform_v1.Feature()
+ feature.value_type = "BYTES"
+
+ request = aiplatform_v1.UpdateFeatureRequest(
+ feature=feature,
+ )
+
+ # Make the request
+ response = await client.update_feature(request=request)
+
+ # Handle the response
+ print(response)
+
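+ As with the other update methods, a ``FieldMask`` limits which fields
+ are overwritten; for example:
+
+ .. code-block:: python
+
+ from google.protobuf import field_mask_pb2
+
+ response = await client.update_feature(
+ feature=feature,
+ update_mask=field_mask_pb2.FieldMask(paths=["description"]),
+ )
+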
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.UpdateFeatureRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.UpdateFeature][google.cloud.aiplatform.v1.FeaturestoreService.UpdateFeature].
+ feature (:class:`google.cloud.aiplatform_v1.types.Feature`):
+ Required. The Feature's ``name`` field is used to
+ identify the Feature to be updated. Format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}/features/{feature}``
+
+ This corresponds to the ``feature`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`):
+ Field mask is used to specify the fields to be
+ overwritten in the Features resource by the update. The
+ fields specified in the update_mask are relative to the
+ resource, not the full request. A field will be
+ overwritten if it is in the mask. If the user does not
+ provide a mask then only the non-empty fields present in
+ the request will be overwritten. Set the update_mask to
+ ``*`` to override all fields.
+
+ Updatable fields:
+
+ - ``description``
+ - ``labels``
+ - ``disable_monitoring``
+
+ This corresponds to the ``update_mask`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.aiplatform_v1.types.Feature:
+ Feature metadata is information that
+ describes an attribute of an entity
+ type. For example, apple is an entity
+ type, and color is a feature that
+ describes apple.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([feature, update_mask])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.UpdateFeatureRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if feature is not None:
+ request.feature = feature
+ if update_mask is not None:
+ request.update_mask = update_mask
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.update_feature,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata(
+ (("feature.name", request.feature.name),)
+ ),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def delete_feature(
+ self,
+ request: Union[featurestore_service.DeleteFeatureRequest, dict] = None,
+ *,
+ name: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation_async.AsyncOperation:
+ r"""Deletes a single Feature.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_delete_feature():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.DeleteFeatureRequest(
+ name="name_value",
+ )
+
+ # Make the request
+ operation = await client.delete_feature(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = await operation.result()
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.DeleteFeatureRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.DeleteFeature][google.cloud.aiplatform.v1.FeaturestoreService.DeleteFeature].
+ name (:class:`str`):
+ Required. The name of the Feature to be deleted.
+ Format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}/features/{feature}``
+
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty`, a generic empty message that you can re-use to avoid defining duplicated
+ empty messages in your APIs. A typical example is to
+ use it as the request or the response type of an API
+ method. For instance::
+
+ service Foo {
+ rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty);
+ }
+
+ The JSON representation for ``Empty`` is an empty
+ JSON object ``{}``.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.DeleteFeatureRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if name is not None:
+ request.name = name
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.delete_feature,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ empty_pb2.Empty,
+ metadata_type=gca_operation.DeleteOperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def import_feature_values(
+ self,
+ request: Union[featurestore_service.ImportFeatureValuesRequest, dict] = None,
+ *,
+ entity_type: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation_async.AsyncOperation:
+ r"""Imports Feature values into the Featurestore from a
+ source storage.
+ The progress of the import is tracked by the returned
+ operation. The imported features are guaranteed to be
+ visible to subsequent read operations after the
+ operation is marked as successfully done.
+ If an import operation fails, the Feature values
+ returned from reads and exports may be inconsistent. If
+ consistency is required, the caller must retry the same
+ import request again and wait till the new operation
+ returned is marked as successfully done.
+ There are also scenarios where the caller can cause
+ inconsistency.
+ - Source data for import contains multiple distinct
+ Feature values for the same entity ID and timestamp.
+ - Source is modified during an import. This includes
+ adding, updating, or removing source data and/or
+ metadata. Examples of updating metadata include but are
+ not limited to changing storage location, storage class,
+ or retention policy.
+ - Online serving cluster is under-provisioned.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_import_feature_values():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ avro_source = aiplatform_v1.AvroSource()
+ avro_source.gcs_source.uris = ['uris_value_1', 'uris_value_2']
+
+ feature_specs = aiplatform_v1.FeatureSpec()
+ feature_specs.id = "id_value"
+
+ request = aiplatform_v1.ImportFeatureValuesRequest(
+ avro_source=avro_source,
+ feature_time_field="feature_time_field_value",
+ entity_type="entity_type_value",
+ feature_specs=[feature_specs],  # ``feature_specs`` is a repeated field
+ )
+
+ # Make the request
+ operation = await client.import_feature_values(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = await operation.result()
+
+ # Handle the response
+ print(response)
+
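+ A BigQuery table can serve as the source instead of Avro; a sketch
+ (the table URI is a placeholder):
+
+ .. code-block:: python
+
+ bigquery_source = aiplatform_v1.BigQuerySource(
+ input_uri="bq://my-project.my_dataset.my_table",
+ )
+
+ request = aiplatform_v1.ImportFeatureValuesRequest(
+ bigquery_source=bigquery_source,
+ feature_time_field="feature_time_field_value",
+ entity_type="entity_type_value",
+ feature_specs=[feature_specs],
+ )
+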
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.ImportFeatureValuesRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.ImportFeatureValues][google.cloud.aiplatform.v1.FeaturestoreService.ImportFeatureValues].
+ entity_type (:class:`str`):
+ Required. The resource name of the EntityType grouping
+ the Features for which values are being imported.
+ Format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entityType}``
+
+ This corresponds to the ``entity_type`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be
+ :class:`google.cloud.aiplatform_v1.types.ImportFeatureValuesResponse`
+ Response message for
+ [FeaturestoreService.ImportFeatureValues][google.cloud.aiplatform.v1.FeaturestoreService.ImportFeatureValues].
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([entity_type])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.ImportFeatureValuesRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if entity_type is not None:
+ request.entity_type = entity_type
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.import_feature_values,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata(
+ (("entity_type", request.entity_type),)
+ ),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ featurestore_service.ImportFeatureValuesResponse,
+ metadata_type=featurestore_service.ImportFeatureValuesOperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def batch_read_feature_values(
+ self,
+ request: Union[featurestore_service.BatchReadFeatureValuesRequest, dict] = None,
+ *,
+ featurestore: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation_async.AsyncOperation:
+ r"""Batch reads Feature values from a Featurestore.
+ This API enables batch reading Feature values, where
+ each read instance in the batch may read Feature values
+ of entities from one or more EntityTypes. Point-in-time
+ correctness is guaranteed for Feature values of each
+ read instance as of each instance's read timestamp.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_batch_read_feature_values():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ csv_read_instances = aiplatform_v1.CsvSource()
+ csv_read_instances.gcs_source.uris = ['uris_value_1', 'uris_value_2']
+
+ destination = aiplatform_v1.FeatureValueDestination()
+ destination.bigquery_destination.output_uri = "output_uri_value"
+
+ entity_type_specs = aiplatform_v1.EntityTypeSpec()
+ entity_type_specs.entity_type_id = "entity_type_id_value"
+ entity_type_specs.feature_selector.id_matcher.ids = ['ids_value_1', 'ids_value_2']
+
+ request = aiplatform_v1.BatchReadFeatureValuesRequest(
+ csv_read_instances=csv_read_instances,
+ featurestore="featurestore_value",
+ destination=destination,
+ entity_type_specs=entity_type_specs,
+ )
+
+ # Make the request
+ operation = client.batch_read_feature_values(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = await operation.result()
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.BatchReadFeatureValuesRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.BatchReadFeatureValues][google.cloud.aiplatform.v1.FeaturestoreService.BatchReadFeatureValues].
+ featurestore (:class:`str`):
+ Required. The resource name of the Featurestore from
+ which to query Feature values. Format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}``
+
+ This corresponds to the ``featurestore`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be
+ :class:`google.cloud.aiplatform_v1.types.BatchReadFeatureValuesResponse`
+ Response message for
+ [FeaturestoreService.BatchReadFeatureValues][google.cloud.aiplatform.v1.FeaturestoreService.BatchReadFeatureValues].
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([featurestore])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.BatchReadFeatureValuesRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if featurestore is not None:
+ request.featurestore = featurestore
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.batch_read_feature_values,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata(
+ (("featurestore", request.featurestore),)
+ ),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ featurestore_service.BatchReadFeatureValuesResponse,
+ metadata_type=featurestore_service.BatchReadFeatureValuesOperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def export_feature_values(
+ self,
+ request: Union[featurestore_service.ExportFeatureValuesRequest, dict] = None,
+ *,
+ entity_type: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation_async.AsyncOperation:
+ r"""Exports Feature values from all the entities of a
+ target EntityType.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_export_feature_values():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ destination = aiplatform_v1.FeatureValueDestination()
+ destination.bigquery_destination.output_uri = "output_uri_value"
+
+ feature_selector = aiplatform_v1.FeatureSelector()
+ feature_selector.id_matcher.ids = ['ids_value_1', 'ids_value_2']
+
+ request = aiplatform_v1.ExportFeatureValuesRequest(
+ entity_type="entity_type_value",
+ destination=destination,
+ feature_selector=feature_selector,
+ )
+
+ # Make the request
+ operation = client.export_feature_values(request=request)
+
+ print("Waiting for operation to complete...")
+
+ response = await operation.result()
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.ExportFeatureValuesRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.ExportFeatureValues][google.cloud.aiplatform.v1.FeaturestoreService.ExportFeatureValues].
+ entity_type (:class:`str`):
+ Required. The resource name of the EntityType from which
+ to export Feature values. Format:
+ ``projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}``
+
+ This corresponds to the ``entity_type`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be
+ :class:`google.cloud.aiplatform_v1.types.ExportFeatureValuesResponse`
+ Response message for
+ [FeaturestoreService.ExportFeatureValues][google.cloud.aiplatform.v1.FeaturestoreService.ExportFeatureValues].
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([entity_type])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.ExportFeatureValuesRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if entity_type is not None:
+ request.entity_type = entity_type
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.export_feature_values,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata(
+ (("entity_type", request.entity_type),)
+ ),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ featurestore_service.ExportFeatureValuesResponse,
+ metadata_type=featurestore_service.ExportFeatureValuesOperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def search_features(
+ self,
+ request: Union[featurestore_service.SearchFeaturesRequest, dict] = None,
+ *,
+ location: str = None,
+ query: str = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> pagers.SearchFeaturesAsyncPager:
+ r"""Searches Features matching a query in a given
+ project.
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ async def sample_search_features():
+ # Create a client
+ client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = aiplatform_v1.SearchFeaturesRequest(
+ location="location_value",
+ )
+
+ # Make the request
+ page_result = client.search_features(request=request)
+
+ # Handle the response
+ async for response in page_result:
+ print(response)
+
+ Args:
+ request (Union[google.cloud.aiplatform_v1.types.SearchFeaturesRequest, dict]):
+ The request object. Request message for
+ [FeaturestoreService.SearchFeatures][google.cloud.aiplatform.v1.FeaturestoreService.SearchFeatures].
+ location (:class:`str`):
+ Required. The resource name of the Location to search
+ Features. Format:
+ ``projects/{project}/locations/{location}``
+
+ This corresponds to the ``location`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ query (:class:`str`):
+ Query string that is a conjunction of field-restricted
+ queries and/or field-restricted filters.
+ Field-restricted queries and filters can be combined
+ using ``AND`` to form a conjunction.
+
+ A field query is in the form FIELD:QUERY. This
+ implicitly checks if QUERY exists as a substring within
+ Feature's FIELD. The QUERY and the FIELD are converted
+ to a sequence of words (i.e. tokens) for comparison.
+ This is done by:
+
+ - Removing leading/trailing whitespace and tokenizing
+ the search value. Characters that are not one of
+ alphanumeric ``[a-zA-Z0-9]``, underscore ``_``, or
+ asterisk ``*`` are treated as delimiters for tokens.
+ ``*`` is treated as a wildcard that matches
+ characters within a token.
+ - Ignoring case.
+ - Prepending an asterisk to the first and appending an
+ asterisk to the last token in QUERY.
+
+ A QUERY must be either a singular token or a phrase. A
+ phrase is one or multiple words enclosed in double
+ quotation marks ("). With phrases, the order of the
+                words is important. Words in the phrase must match in
+                order and consecutively.
+
+ Supported FIELDs for field-restricted queries:
+
+ - ``feature_id``
+ - ``description``
+ - ``entity_type_id``
+
+ Examples:
+
+ - ``feature_id: foo`` --> Matches a Feature with ID
+                   containing the substring ``foo`` (e.g. ``foo``,
+ ``foofeature``, ``barfoo``).
+ - ``feature_id: foo*feature`` --> Matches a Feature
+                   with ID containing the substring ``foo*feature`` (e.g.
+ ``foobarfeature``).
+ - ``feature_id: foo AND description: bar`` --> Matches
+ a Feature with ID containing the substring ``foo``
+ and description containing the substring ``bar``.
+
+ Besides field queries, the following exact-match filters
+ are supported. The exact-match filters do not support
+ wildcards. Unlike field-restricted queries, exact-match
+ filters are case-sensitive.
+
+ - ``feature_id``: Supports = comparisons.
+ - ``description``: Supports = comparisons. Multi-token
+ filters should be enclosed in quotes.
+ - ``entity_type_id``: Supports = comparisons.
+ - ``value_type``: Supports = and != comparisons.
+ - ``labels``: Supports key-value equality as well as
+ key presence.
+ - ``featurestore_id``: Supports = comparisons.
+
+ Examples:
+
+ - ``description = "foo bar"`` --> Any Feature with
+ description exactly equal to ``foo bar``
+ - ``value_type = DOUBLE`` --> Features whose type is
+ DOUBLE.
+ - ``labels.active = yes AND labels.env = prod`` -->
+ Features having both (active: yes) and (env: prod)
+ labels.
+ - ``labels.env: *`` --> Any Feature which has a label
+ with ``env`` as the key.
+
+ This corresponds to the ``query`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.aiplatform_v1.services.featurestore_service.pagers.SearchFeaturesAsyncPager:
+ Response message for
+ [FeaturestoreService.SearchFeatures][google.cloud.aiplatform.v1.FeaturestoreService.SearchFeatures].
+
+ Iterating over this object will yield results and
+ resolve additional pages automatically.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([location, query])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = featurestore_service.SearchFeaturesRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if location is not None:
+ request.location = location
+ if query is not None:
+ request.query = query
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.search_features,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("location", request.location),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # This method is paged; wrap the response in a pager, which provides
+ # an `__aiter__` convenience method.
+ response = pagers.SearchFeaturesAsyncPager(
+ method=rpc,
+ request=request,
+ response=response,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def list_operations(
+ self,
+ request: operations_pb2.ListOperationsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.ListOperationsResponse:
+ r"""Lists operations that match the specified filter in the request.
+
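+        A minimal usage sketch; the resource name below is a placeholder:
+
+        .. code-block:: python
+
+            from google.cloud import aiplatform_v1
+            from google.longrunning import operations_pb2
+
+            async def sample_list_operations():
+                client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+                request = operations_pb2.ListOperationsRequest(
+                    name="projects/my-project/locations/us-central1",
+                )
+
+                response = await client.list_operations(request=request)
+                for operation in response.operations:
+                    print(operation.name)
+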
+ Args:
+ request (:class:`~.operations_pb2.ListOperationsRequest`):
+ The request object. Request message for
+ `ListOperations` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.ListOperationsResponse:
+ Response message for ``ListOperations`` method.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.ListOperationsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.list_operations,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def get_operation(
+ self,
+ request: operations_pb2.GetOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.Operation:
+ r"""Gets the latest state of a long-running operation.
+
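+        A minimal usage sketch; the operation name is a placeholder. A dict
+        is also accepted in place of the proto request:
+
+        .. code-block:: python
+
+            from google.cloud import aiplatform_v1
+
+            async def sample_get_operation():
+                client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+                operation = await client.get_operation(
+                    request={"name": "projects/my-project/locations/us-central1/operations/123"}
+                )
+                print(operation.done)
+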
+ Args:
+ request (:class:`~.operations_pb2.GetOperationRequest`):
+ The request object. Request message for
+ `GetOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.Operation:
+ An ``Operation`` object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.GetOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.get_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def delete_operation(
+ self,
+ request: operations_pb2.DeleteOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> None:
+ r"""Deletes a long-running operation.
+
+ This method indicates that the client is no longer interested
+ in the operation result. It does not cancel the operation.
+ If the server doesn't support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
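+        A minimal usage sketch; the operation name is a placeholder and the
+        call returns ``None`` on success:
+
+        .. code-block:: python
+
+            from google.cloud import aiplatform_v1
+            from google.longrunning import operations_pb2
+
+            async def sample_delete_operation():
+                client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+                await client.delete_operation(
+                    request=operations_pb2.DeleteOperationRequest(
+                        name="projects/my-project/locations/us-central1/operations/123",
+                    )
+                )
+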
+ Args:
+ request (:class:`~.operations_pb2.DeleteOperationRequest`):
+ The request object. Request message for
+ `DeleteOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ None
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.DeleteOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.delete_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ async def cancel_operation(
+ self,
+ request: operations_pb2.CancelOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> None:
+ r"""Starts asynchronous cancellation on a long-running operation.
+
+ The server makes a best effort to cancel the operation, but success
+ is not guaranteed. If the server doesn't support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
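+        A minimal usage sketch; the operation name is a placeholder:
+
+        .. code-block:: python
+
+            from google.cloud import aiplatform_v1
+
+            async def sample_cancel_operation():
+                client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+                # A dict request is coerced to a CancelOperationRequest.
+                await client.cancel_operation(
+                    request={"name": "projects/my-project/locations/us-central1/operations/123"}
+                )
+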
+ Args:
+ request (:class:`~.operations_pb2.CancelOperationRequest`):
+ The request object. Request message for
+ `CancelOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ None
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.CancelOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.cancel_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ async def wait_operation(
+ self,
+ request: operations_pb2.WaitOperationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operations_pb2.Operation:
+ r"""Waits until the specified long-running operation is done or reaches at most
+ a specified timeout, returning the latest state.
+
+ If the operation is already done, the latest state is immediately returned.
+ If the timeout specified is greater than the default HTTP/RPC timeout, the HTTP/RPC
+ timeout is used. If the server does not support this method, it returns
+ `google.rpc.Code.UNIMPLEMENTED`.
+
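+        A minimal usage sketch; the operation name is a placeholder and the
+        server-side wait here is capped at 30 seconds:
+
+        .. code-block:: python
+
+            from google.cloud import aiplatform_v1
+            from google.longrunning import operations_pb2
+            from google.protobuf import duration_pb2
+
+            async def sample_wait_operation():
+                client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+                operation = await client.wait_operation(
+                    request=operations_pb2.WaitOperationRequest(
+                        name="projects/my-project/locations/us-central1/operations/123",
+                        timeout=duration_pb2.Duration(seconds=30),
+                    )
+                )
+                print(operation.done)
+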
+ Args:
+ request (:class:`~.operations_pb2.WaitOperationRequest`):
+ The request object. Request message for
+ `WaitOperation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.operations_pb2.Operation:
+ An ``Operation`` object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = operations_pb2.WaitOperationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.wait_operation,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def set_iam_policy(
+ self,
+ request: iam_policy_pb2.SetIamPolicyRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> policy_pb2.Policy:
+ r"""Sets the IAM access control policy on the specified function.
+
+ Replaces any existing policy.
+
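+        A minimal usage sketch; the resource name, role, and member below are
+        placeholders:
+
+        .. code-block:: python
+
+            from google.cloud import aiplatform_v1
+            from google.iam.v1 import iam_policy_pb2, policy_pb2
+
+            async def sample_set_iam_policy():
+                client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+                policy = policy_pb2.Policy(
+                    bindings=[
+                        policy_pb2.Binding(
+                            role="roles/aiplatform.user",
+                            members=["user:someone@example.com"],
+                        )
+                    ]
+                )
+                response = await client.set_iam_policy(
+                    request=iam_policy_pb2.SetIamPolicyRequest(
+                        resource="projects/my-project/locations/us-central1/featurestores/my-store",
+                        policy=policy,
+                    )
+                )
+                print(response.bindings)
+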
+ Args:
+ request (:class:`~.iam_policy_pb2.SetIamPolicyRequest`):
+ The request object. Request message for `SetIamPolicy`
+ method.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.policy_pb2.Policy:
+ Defines an Identity and Access Management (IAM) policy.
+ It is used to specify access control policies for Cloud
+ Platform resources.
+ A ``Policy`` is a collection of ``bindings``. A
+ ``binding`` binds one or more ``members`` to a single
+ ``role``. Members can be user accounts, service
+ accounts, Google groups, and domains (such as G Suite).
+ A ``role`` is a named list of permissions (defined by
+ IAM or configured by users). A ``binding`` can
+ optionally specify a ``condition``, which is a logic
+ expression that further constrains the role binding
+ based on attributes about the request and/or target
+ resource.
+ **JSON Example**
+ ::
+ {
+ "bindings": [
+ {
+ "role": "roles/resourcemanager.organizationAdmin",
+ "members": [
+ "user:mike@example.com",
+ "group:admins@example.com",
+ "domain:google.com",
+ "serviceAccount:my-project-id@appspot.gserviceaccount.com"
+ ]
+ },
+ {
+ "role": "roles/resourcemanager.organizationViewer",
+ "members": ["user:eve@example.com"],
+ "condition": {
+ "title": "expirable access",
+ "description": "Does not grant access after Sep 2020",
+ "expression": "request.time <
+ timestamp('2020-10-01T00:00:00.000Z')",
+ }
+ }
+ ]
+ }
+ **YAML Example**
+ ::
+ bindings:
+ - members:
+ - user:mike@example.com
+ - group:admins@example.com
+ - domain:google.com
+ - serviceAccount:my-project-id@appspot.gserviceaccount.com
+ role: roles/resourcemanager.organizationAdmin
+ - members:
+ - user:eve@example.com
+ role: roles/resourcemanager.organizationViewer
+ condition:
+ title: expirable access
+ description: Does not grant access after Sep 2020
+ expression: request.time < timestamp('2020-10-01T00:00:00.000Z')
+ For a description of IAM and its features, see the `IAM
+ developer's
+                guide <https://cloud.google.com/iam/docs>`__.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.SetIamPolicyRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.set_iam_policy,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def get_iam_policy(
+ self,
+ request: iam_policy_pb2.GetIamPolicyRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> policy_pb2.Policy:
+ r"""Gets the IAM access control policy for a function.
+
+ Returns an empty policy if the function exists and does not have a
+ policy set.
+
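+        A minimal usage sketch; the resource name is a placeholder:
+
+        .. code-block:: python
+
+            from google.cloud import aiplatform_v1
+
+            async def sample_get_iam_policy():
+                client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+                policy = await client.get_iam_policy(
+                    request={"resource": "projects/my-project/locations/us-central1/featurestores/my-store"}
+                )
+                for binding in policy.bindings:
+                    print(binding.role, binding.members)
+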
+ Args:
+ request (:class:`~.iam_policy_pb2.GetIamPolicyRequest`):
+ The request object. Request message for `GetIamPolicy`
+ method.
+ retry (google.api_core.retry.Retry): Designation of what errors, if
+ any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.policy_pb2.Policy:
+ Defines an Identity and Access Management (IAM) policy.
+ It is used to specify access control policies for Cloud
+ Platform resources.
+ A ``Policy`` is a collection of ``bindings``. A
+ ``binding`` binds one or more ``members`` to a single
+ ``role``. Members can be user accounts, service
+ accounts, Google groups, and domains (such as G Suite).
+ A ``role`` is a named list of permissions (defined by
+ IAM or configured by users). A ``binding`` can
+ optionally specify a ``condition``, which is a logic
+ expression that further constrains the role binding
+ based on attributes about the request and/or target
+ resource.
+ **JSON Example**
+ ::
+ {
+ "bindings": [
+ {
+ "role": "roles/resourcemanager.organizationAdmin",
+ "members": [
+ "user:mike@example.com",
+ "group:admins@example.com",
+ "domain:google.com",
+ "serviceAccount:my-project-id@appspot.gserviceaccount.com"
+ ]
+ },
+ {
+ "role": "roles/resourcemanager.organizationViewer",
+ "members": ["user:eve@example.com"],
+ "condition": {
+ "title": "expirable access",
+ "description": "Does not grant access after Sep 2020",
+ "expression": "request.time <
+ timestamp('2020-10-01T00:00:00.000Z')",
+ }
+ }
+ ]
+ }
+ **YAML Example**
+ ::
+ bindings:
+ - members:
+ - user:mike@example.com
+ - group:admins@example.com
+ - domain:google.com
+ - serviceAccount:my-project-id@appspot.gserviceaccount.com
+ role: roles/resourcemanager.organizationAdmin
+ - members:
+ - user:eve@example.com
+ role: roles/resourcemanager.organizationViewer
+ condition:
+ title: expirable access
+ description: Does not grant access after Sep 2020
+ expression: request.time < timestamp('2020-10-01T00:00:00.000Z')
+ For a description of IAM and its features, see the `IAM
+ developer's
+                guide <https://cloud.google.com/iam/docs>`__.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.GetIamPolicyRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.get_iam_policy,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def test_iam_permissions(
+ self,
+ request: iam_policy_pb2.TestIamPermissionsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> iam_policy_pb2.TestIamPermissionsResponse:
+ r"""Tests the specified IAM permissions against the IAM access control
+ policy for a function.
+
+ If the function does not exist, this will return an empty set
+ of permissions, not a NOT_FOUND error.
+
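+        A minimal usage sketch; the resource name and permission string are
+        placeholders:
+
+        .. code-block:: python
+
+            from google.cloud import aiplatform_v1
+            from google.iam.v1 import iam_policy_pb2
+
+            async def sample_test_iam_permissions():
+                client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+                response = await client.test_iam_permissions(
+                    request=iam_policy_pb2.TestIamPermissionsRequest(
+                        resource="projects/my-project/locations/us-central1/featurestores/my-store",
+                        permissions=["aiplatform.featurestores.get"],
+                    )
+                )
+                # Only the permissions the caller actually holds are echoed back.
+                print(response.permissions)
+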
+ Args:
+ request (:class:`~.iam_policy_pb2.TestIamPermissionsRequest`):
+ The request object. Request message for
+ `TestIamPermissions` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+ ~.iam_policy_pb2.TestIamPermissionsResponse:
+ Response message for ``TestIamPermissions`` method.
+ """
+ # Create or coerce a protobuf request object.
+
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = iam_policy_pb2.TestIamPermissionsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.test_iam_permissions,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("resource", request.resource),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def get_location(
+ self,
+ request: locations_pb2.GetLocationRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> locations_pb2.Location:
+ r"""Gets information about a location.
+
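+        A minimal usage sketch; the location name is a placeholder:
+
+        .. code-block:: python
+
+            from google.cloud import aiplatform_v1
+
+            async def sample_get_location():
+                client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+                location = await client.get_location(
+                    request={"name": "projects/my-project/locations/us-central1"}
+                )
+                print(location.display_name)
+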
+ Args:
+            request (:class:`~.locations_pb2.GetLocationRequest`):
+ The request object. Request message for
+ `GetLocation` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+            ~.locations_pb2.Location:
+ Location object.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = locations_pb2.GetLocationRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.get_location,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def list_locations(
+ self,
+ request: locations_pb2.ListLocationsRequest = None,
+ *,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> locations_pb2.ListLocationsResponse:
+ r"""Lists information about the supported locations for this service.
+
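+        A minimal usage sketch; the project name is a placeholder:
+
+        .. code-block:: python
+
+            from google.cloud import aiplatform_v1
+
+            async def sample_list_locations():
+                client = aiplatform_v1.FeaturestoreServiceAsyncClient()
+
+                response = await client.list_locations(
+                    request={"name": "projects/my-project"}
+                )
+                for location in response.locations:
+                    print(location.location_id)
+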
+ Args:
+            request (:class:`~.locations_pb2.ListLocationsRequest`):
+ The request object. Request message for
+ `ListLocations` method.
+ retry (google.api_core.retry.Retry): Designation of what errors,
+ if any, should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ Returns:
+            ~.locations_pb2.ListLocationsResponse:
+ Response message for ``ListLocations`` method.
+ """
+ # Create or coerce a protobuf request object.
+ # The request isn't a proto-plus wrapped type,
+ # so it must be constructed via keyword expansion.
+ if isinstance(request, dict):
+ request = locations_pb2.ListLocationsRequest(**request)
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method.wrap_method(
+ self._client._transport.list_locations,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ await self.transport.close()
+
+
+try:
+ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+ gapic_version=pkg_resources.get_distribution(
+ "google-cloud-aiplatform",
+ ).version,
+ )
+except pkg_resources.DistributionNotFound:
+ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+
+__all__ = ("FeaturestoreServiceAsyncClient",)
diff --git a/google/cloud/aiplatform_v1/services/featurestore_service/client.py b/google/cloud/aiplatform_v1/services/featurestore_service/client.py
new file mode 100644
index 0000000000..e75b3f4f1e
--- /dev/null
+++ b/google/cloud/aiplatform_v1/services/featurestore_service/client.py
@@ -0,0 +1,3738 @@
+# -*- coding: utf-8 -*-
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from collections import OrderedDict
+import os
+import re
+from typing import Dict, Mapping, Optional, Sequence, Tuple, Type, Union
+import pkg_resources
+
+from google.api_core import client_options as client_options_lib
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core import retry as retries
+from google.auth import credentials as ga_credentials # type: ignore
+from google.auth.transport import mtls # type: ignore
+from google.auth.transport.grpc import SslCredentials # type: ignore
+from google.auth.exceptions import MutualTLSChannelError # type: ignore
+from google.oauth2 import service_account # type: ignore
+
+try:
+ OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
+except AttributeError: # pragma: NO COVER
+ OptionalRetry = Union[retries.Retry, object] # type: ignore
+
+from google.api_core import operation as gac_operation # type: ignore
+from google.api_core import operation_async # type: ignore
+from google.cloud.aiplatform_v1.services.featurestore_service import pagers
+from google.cloud.aiplatform_v1.types import encryption_spec
+from google.cloud.aiplatform_v1.types import entity_type
+from google.cloud.aiplatform_v1.types import entity_type as gca_entity_type
+from google.cloud.aiplatform_v1.types import feature
+from google.cloud.aiplatform_v1.types import feature as gca_feature
+from google.cloud.aiplatform_v1.types import featurestore
+from google.cloud.aiplatform_v1.types import featurestore as gca_featurestore
+from google.cloud.aiplatform_v1.types import featurestore_monitoring
+from google.cloud.aiplatform_v1.types import featurestore_service
+from google.cloud.aiplatform_v1.types import operation as gca_operation
+from google.cloud.location import locations_pb2 # type: ignore
+from google.iam.v1 import iam_policy_pb2 # type: ignore
+from google.iam.v1 import policy_pb2 # type: ignore
+from google.longrunning import operations_pb2
+from google.protobuf import empty_pb2 # type: ignore
+from google.protobuf import field_mask_pb2 # type: ignore
+from google.protobuf import timestamp_pb2 # type: ignore
+from .transports.base import FeaturestoreServiceTransport, DEFAULT_CLIENT_INFO
+from .transports.grpc import FeaturestoreServiceGrpcTransport
+from .transports.grpc_asyncio import FeaturestoreServiceGrpcAsyncIOTransport
+
+
+class FeaturestoreServiceClientMeta(type):
+ """Metaclass for the FeaturestoreService client.
+
+ This provides class-level methods for building and retrieving
+ support objects (e.g. transport) without polluting the client instance
+ objects.
+ """
+
+ _transport_registry = (
+ OrderedDict()
+ ) # type: Dict[str, Type[FeaturestoreServiceTransport]]
+ _transport_registry["grpc"] = FeaturestoreServiceGrpcTransport
+ _transport_registry["grpc_asyncio"] = FeaturestoreServiceGrpcAsyncIOTransport
+
+ def get_transport_class(
+ cls,
+ label: str = None,
+ ) -> Type[FeaturestoreServiceTransport]:
+ """Returns an appropriate transport class.
+
+ Args:
+ label: The name of the desired transport. If none is
+ provided, then the first transport in the registry is used.
+
+ Returns:
+ The transport class to use.
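+
+        For example, ``get_transport_class("grpc_asyncio")`` returns the
+        ``FeaturestoreServiceGrpcAsyncIOTransport`` class registered above,
+        while ``get_transport_class()`` falls back to the first entry,
+        ``FeaturestoreServiceGrpcTransport``.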
+ """
+ # If a specific transport is requested, return that one.
+ if label:
+ return cls._transport_registry[label]
+
+ # No transport is requested; return the default (that is, the first one
+ # in the dictionary).
+ return next(iter(cls._transport_registry.values()))
+
+
+class FeaturestoreServiceClient(metaclass=FeaturestoreServiceClientMeta):
+ """The service that handles CRUD and List for resources for
+ Featurestore.
+ """
+
+ @staticmethod
+ def _get_default_mtls_endpoint(api_endpoint):
+ """Converts api endpoint to mTLS endpoint.
+
+ Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+ "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+ Args:
+ api_endpoint (Optional[str]): the api endpoint to convert.
+ Returns:
+ str: converted mTLS api endpoint.
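+
+        For example (hypothetical endpoints)::
+
+            _get_default_mtls_endpoint("aiplatform.googleapis.com")
+            # -> "aiplatform.mtls.googleapis.com"
+            _get_default_mtls_endpoint("aiplatform.sandbox.googleapis.com")
+            # -> "aiplatform.mtls.sandbox.googleapis.com"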
+ """
+ if not api_endpoint:
+ return api_endpoint
+
+ mtls_endpoint_re = re.compile(
+ r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?"
+ )
+
+ m = mtls_endpoint_re.match(api_endpoint)
+ name, mtls, sandbox, googledomain = m.groups()
+ if mtls or not googledomain:
+ return api_endpoint
+
+ if sandbox:
+ return api_endpoint.replace(
+ "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
+ )
+
+ return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+
+ DEFAULT_ENDPOINT = "aiplatform.googleapis.com"
+ DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
+ DEFAULT_ENDPOINT
+ )
+
+ @classmethod
+ def from_service_account_info(cls, info: dict, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials
+ info.
+
+ Args:
+ info (dict): The service account private key info.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ FeaturestoreServiceClient: The constructed client.
+ """
+ credentials = service_account.Credentials.from_service_account_info(info)
+ kwargs["credentials"] = credentials
+ return cls(*args, **kwargs)
+
+ @classmethod
+ def from_service_account_file(cls, filename: str, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials
+ file.
+
+ Args:
+ filename (str): The path to the service account private key json
+ file.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ FeaturestoreServiceClient: The constructed client.
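+
+        Example (the key path is a placeholder)::
+
+            client = FeaturestoreServiceClient.from_service_account_file(
+                "/path/to/service-account.json"
+            )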
+ """
+ credentials = service_account.Credentials.from_service_account_file(filename)
+ kwargs["credentials"] = credentials
+ return cls(*args, **kwargs)
+
+ from_service_account_json = from_service_account_file
+
+ @property
+ def transport(self) -> FeaturestoreServiceTransport:
+ """Returns the transport used by the client instance.
+
+ Returns:
+ FeaturestoreServiceTransport: The transport used by the client
+ instance.
+ """
+ return self._transport
+
+ @staticmethod
+ def entity_type_path(
+ project: str,
+ location: str,
+ featurestore: str,
+ entity_type: str,
+ ) -> str:
+ """Returns a fully-qualified entity_type string."""
+ return "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}".format(
+ project=project,
+ location=location,
+ featurestore=featurestore,
+ entity_type=entity_type,
+ )
+
+ @staticmethod
+ def parse_entity_type_path(path: str) -> Dict[str, str]:
+ """Parses a entity_type path into its component segments."""
+ m = re.match(
+ r"^projects/(?P.+?)/locations/(?P.+?)/featurestores/(?P.+?)/entityTypes/(?P.+?)$",
+ path,
+ )
+ return m.groupdict() if m else {}
+
+ @staticmethod
+ def feature_path(
+ project: str,
+ location: str,
+ featurestore: str,
+ entity_type: str,
+ feature: str,
+ ) -> str:
+ """Returns a fully-qualified feature string."""
+ return "projects/{project}/locations/{location}/featurestores/{featurestore}/entityTypes/{entity_type}/features/{feature}".format(
+ project=project,
+ location=location,
+ featurestore=featurestore,
+ entity_type=entity_type,
+ feature=feature,
+ )
+
+ @staticmethod
+ def parse_feature_path(path: str) -> Dict[str, str]:
+ """Parses a feature path into its component segments."""
+ m = re.match(
+ r"^projects/(?P.+?)/locations/(?P.+?)/featurestores/(?P