diff --git a/.github/.OwlBot.lock.yaml b/.github/.OwlBot.lock.yaml deleted file mode 100644 index 508ba98efe..0000000000 --- a/.github/.OwlBot.lock.yaml +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -docker: - image: gcr.io/cloud-devrel-public-resources/owlbot-python:latest - digest: sha256:25de45b58e52021d3a24a6273964371a97a4efeefe6ad3845a64e697c63b6447 -# created: 2025-04-14T14:34:43.260858345Z diff --git a/.github/.OwlBot.yaml b/.github/.OwlBot.yaml deleted file mode 100644 index b720d256ad..0000000000 --- a/.github/.OwlBot.yaml +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -docker: - image: gcr.io/cloud-devrel-public-resources/owlbot-python:latest - -deep-remove-regex: - - /owl-bot-staging - -deep-preserve-regex: - - /owl-bot-staging/firestore/v1beta1 - -deep-copy-regex: - - source: /google/firestore/(v.*)/.*-py/(.*) - dest: /owl-bot-staging/firestore/$1/$2 - - source: /google/firestore/admin/(v.*)/.*-py/(.*) - dest: /owl-bot-staging/firestore_admin/$1/$2 - - source: /google/firestore/bundle/(.*-py)/(.*) - dest: /owl-bot-staging/firestore_bundle/$1/$2 - -begin-after-commit-hash: 107ed1217b5e87048263f52cd3911d5f851aca7e - diff --git a/.github/auto-approve.yml b/.github/auto-approve.yml deleted file mode 100644 index 311ebbb853..0000000000 --- a/.github/auto-approve.yml +++ /dev/null @@ -1,3 +0,0 @@ -# https://github.com/googleapis/repo-automation-bots/tree/main/packages/auto-approve -processes: - - "OwlBotTemplateChanges" diff --git a/.github/release-please.yml b/.github/release-please.yml deleted file mode 100644 index fe749ff6b1..0000000000 --- a/.github/release-please.yml +++ /dev/null @@ -1,12 +0,0 @@ -releaseType: python -handleGHRelease: true -manifest: true -# NOTE: this section is generated by synthtool.languages.python -# See https://github.com/googleapis/synthtool/blob/master/synthtool/languages/python.py -branches: -- branch: v1 - handleGHRelease: true - releaseType: python -- branch: v0 - handleGHRelease: true - releaseType: python diff --git a/.github/release-trigger.yml b/.github/release-trigger.yml deleted file mode 100644 index 95896588a9..0000000000 --- a/.github/release-trigger.yml +++ /dev/null @@ -1,2 +0,0 @@ -enabled: true -multiScmName: python-firestore diff --git a/.github/sync-repo-settings.yaml b/.github/sync-repo-settings.yaml deleted file mode 100644 index 9920db74d5..0000000000 --- a/.github/sync-repo-settings.yaml +++ /dev/null @@ -1,47 +0,0 
@@ -# Rules for main branch protection -branchProtectionRules: -# Identifies the protection rule pattern. Name of the branch to be protected. -# Defaults to `main` -- pattern: main - # Can admins overwrite branch protection. - # Defaults to `true` - isAdminEnforced: true - # Number of approving reviews required to update matching branches. - # Defaults to `1` - requiredApprovingReviewCount: 1 - # Are reviews from code owners required to update matching branches. - # Defaults to `false` - requiresCodeOwnerReviews: true - # Require up to date branches - requiresStrictStatusChecks: true - # List of required status check contexts that must pass for commits to be accepted to matching branches. - requiredStatusCheckContexts: - - 'Kokoro' - - 'Kokoro system-3.7' - - 'cla/google' - - 'OwlBot Post Processor' - - 'docs' - - 'docfx' - - 'lint' - - 'unit (3.7)' - - 'unit (3.8)' - - 'unit (3.9)' - - 'unit (3.10)' - - 'unit (3.11)' - - 'unit (3.12)' - - 'cover' - - 'run-systests' -# List of explicit permissions to add (additive only) -permissionRules: - # Team slug to add to repository permissions - - team: yoshi-admins - # Access level required, one of push|pull|admin|maintain|triage - permission: admin - # Team slug to add to repository permissions - - team: yoshi-python-admins - # Access level required, one of push|pull|admin|maintain|triage - permission: admin - # Team slug to add to repository permissions - - team: yoshi-python - # Access level required, one of push|pull|admin|maintain|triage - permission: push diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 4866193af2..3ed755f000 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -12,7 +12,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v5 with: - python-version: "3.8" + python-version: "3.14" - name: Install nox run: | python -m pip install --upgrade setuptools pip wheel diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index 772186478f..4997affc75 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -12,7 +12,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v5 with: - python-version: "3.8" + python-version: "3.14" - name: Install nox run: | python -m pip install --upgrade setuptools pip wheel diff --git a/.github/workflows/system_emulated.yml b/.github/workflows/system_emulated.yml index 0f3a69224b..62a879072e 100644 --- a/.github/workflows/system_emulated.yml +++ b/.github/workflows/system_emulated.yml @@ -17,7 +17,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v5 with: - python-version: '3.7' + python-version: '3.14' # firestore emulator requires java 21+ - name: Setup Java diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml index c66b757ced..cc6fe2b2fd 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unittest.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-22.04 strategy: matrix: - python: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] + python: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12', '3.13', '3.14'] steps: - name: Checkout uses: actions/checkout@v4 @@ -45,7 +45,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v5 with: - python-version: "3.8" + python-version: "3.14" - name: Install coverage run: | python -m pip install --upgrade setuptools pip wheel diff --git a/.kokoro/presubmit/system-3.7.cfg b/.kokoro/presubmit/system.cfg similarity index 81% rename from .kokoro/presubmit/system-3.7.cfg rename to .kokoro/presubmit/system.cfg index 
461537b3fb..73904141ba 100644 --- a/.kokoro/presubmit/system-3.7.cfg +++ b/.kokoro/presubmit/system.cfg @@ -3,5 +3,5 @@ # Only run this nox session. env_vars: { key: "NOX_SESSION" - value: "system-3.7" -} \ No newline at end of file + value: "system-3.14" +} diff --git a/.kokoro/samples/python3.14/common.cfg b/.kokoro/samples/python3.14/common.cfg new file mode 100644 index 0000000000..4e07d3590b --- /dev/null +++ b/.kokoro/samples/python3.14/common.cfg @@ -0,0 +1,40 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +# Build logs will be here +action { + define_artifacts { + regex: "**/*sponge_log.xml" + } +} + +# Specify which tests to run +env_vars: { + key: "RUN_TESTS_SESSION" + value: "py-3.14" +} + +# Declare build specific Cloud project. +env_vars: { + key: "BUILD_SPECIFIC_GCLOUD_PROJECT" + value: "python-docs-samples-tests-314" +} + +env_vars: { + key: "TRAMPOLINE_BUILD_FILE" + value: "github/python-firestore/.kokoro/test-samples.sh" +} + +# Configure the docker image for kokoro-trampoline. +env_vars: { + key: "TRAMPOLINE_IMAGE" + value: "gcr.io/cloud-devrel-kokoro-resources/python-samples-testing-docker" +} + +# Download secrets for samples +gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/python-docs-samples" + +# Download trampoline resources. +gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/trampoline" + +# Use the trampoline script to run in docker. +build_file: "python-firestore/.kokoro/trampoline_v2.sh" diff --git a/.kokoro/samples/python3.14/continuous.cfg b/.kokoro/samples/python3.14/continuous.cfg new file mode 100644 index 0000000000..a1c8d9759c --- /dev/null +++ b/.kokoro/samples/python3.14/continuous.cfg @@ -0,0 +1,6 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +env_vars: { + key: "INSTALL_LIBRARY_FROM_SOURCE" + value: "True" +} \ No newline at end of file diff --git a/.kokoro/samples/python3.14/periodic-head.cfg b/.kokoro/samples/python3.14/periodic-head.cfg new file mode 100644 index 0000000000..21998d0902 --- /dev/null +++ b/.kokoro/samples/python3.14/periodic-head.cfg @@ -0,0 +1,11 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +env_vars: { + key: "INSTALL_LIBRARY_FROM_SOURCE" + value: "True" +} + +env_vars: { + key: "TRAMPOLINE_BUILD_FILE" + value: "github/python-firestore/.kokoro/test-samples-against-head.sh" +} diff --git a/.kokoro/samples/python3.14/periodic.cfg b/.kokoro/samples/python3.14/periodic.cfg new file mode 100644 index 0000000000..71cd1e597e --- /dev/null +++ b/.kokoro/samples/python3.14/periodic.cfg @@ -0,0 +1,6 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +env_vars: { + key: "INSTALL_LIBRARY_FROM_SOURCE" + value: "False" +} diff --git a/.kokoro/samples/python3.14/presubmit.cfg b/.kokoro/samples/python3.14/presubmit.cfg new file mode 100644 index 0000000000..a1c8d9759c --- /dev/null +++ b/.kokoro/samples/python3.14/presubmit.cfg @@ -0,0 +1,6 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +env_vars: { + key: "INSTALL_LIBRARY_FROM_SOURCE" + value: "True" +} \ No newline at end of file diff --git a/.librarian/generator-input/.repo-metadata.json b/.librarian/generator-input/.repo-metadata.json new file mode 100644 index 0000000000..670bbc0e42 --- /dev/null +++ b/.librarian/generator-input/.repo-metadata.json @@ -0,0 +1,18 @@ +{ + "name": "firestore", + "name_pretty": "Cloud Firestore", + "product_documentation": "https://cloud.google.com/firestore", + "client_documentation": "https://cloud.google.com/python/docs/reference/firestore/latest", + "issue_tracker": 
"https://issuetracker.google.com/savedsearches/5337669", + "release_level": "stable", + "language": "python", + "library_type": "GAPIC_COMBO", + "repo": "googleapis/python-firestore", + "distribution_name": "google-cloud-firestore", + "api_id": "firestore.googleapis.com", + "requires_billing": true, + "default_version": "v1", + "codeowner_team": "@googleapis/api-firestore @googleapis/api-firestore-partners", + "api_shortname": "firestore", + "api_description": "is a fully-managed NoSQL document database for mobile, web, and server development from Firebase and Google Cloud Platform. It's backed by a multi-region replicated database that ensures once data is committed, it's durable even in the face of unexpected disasters. Not only that, but despite being a distributed database, it's also strongly consistent and offers seamless integration with other Firebase and Google Cloud Platform products, including Google Cloud Functions." +} diff --git a/owlbot.py b/.librarian/generator-input/librarian.py similarity index 55% rename from owlbot.py rename to .librarian/generator-input/librarian.py index f08048fef7..ec92a93451 100644 --- a/owlbot.py +++ b/.librarian/generator-input/librarian.py @@ -28,50 +28,10 @@ firestore_default_version = "v1" firestore_admin_default_version = "v1" -# This is a customized version of the s.get_staging_dirs() function from synthtool to -# cater for copying 3 different folders from googleapis-gen -# which are firestore, firestore/admin and firestore/bundle. -# Source https://github.com/googleapis/synthtool/blob/master/synthtool/transforms.py#L280 -def get_staging_dirs( - default_version: Optional[str] = None, sub_directory: Optional[str] = None -) -> List[Path]: - """Returns the list of directories, one per version, copied from - https://github.com/googleapis/googleapis-gen. Will return in lexical sorting - order with the exception of the default_version which will be last (if specified). - - Args: - default_version (str): the default version of the API. The directory for this version - will be the last item in the returned list if specified. - sub_directory (str): if a `sub_directory` is provided, only the directories within the - specified `sub_directory` will be returned. - - Returns: the empty list if no file were copied. - """ - - staging = Path("owl-bot-staging") - - if sub_directory: - staging /= sub_directory - - if staging.is_dir(): - # Collect the subdirectories of the staging directory. - versions = [v.name for v in staging.iterdir() if v.is_dir()] - # Reorder the versions so the default version always comes last. 
- versions = [v for v in versions if v != default_version] - versions.sort() - if default_version is not None: - versions += [default_version] - dirs = [staging / v for v in versions] - for dir in dirs: - s._tracked_paths.add(dir) - return dirs - else: - return [] - -def update_fixup_scripts(library): +def update_fixup_scripts(path): # Add message for missing 'libcst' dependency s.replace( - library / "scripts/fixup*.py", + library / "scripts" / path, """import libcst as cst""", """try: import libcst as cst @@ -82,19 +42,21 @@ def update_fixup_scripts(library): """, ) -for library in get_staging_dirs(default_version=firestore_default_version, sub_directory="firestore"): - s.move(library / f"google/cloud/firestore_{library.name}", excludes=[f"__init__.py", "**/gapic_version.py", "noxfile.py"]) +for library in s.get_staging_dirs(default_version=firestore_default_version): + s.move(library / f"google/cloud/firestore_{library.name}", excludes=[f"__init__.py", "noxfile.py"]) s.move(library / f"tests/", f"tests") - update_fixup_scripts(library) - s.move(library / "scripts") + fixup_script_path = "fixup_firestore_v1_keywords.py" + update_fixup_scripts(fixup_script_path) + s.move(library / "scripts" / fixup_script_path) -for library in get_staging_dirs(default_version=firestore_admin_default_version, sub_directory="firestore_admin"): - s.move(library / f"google/cloud/firestore_admin_{library.name}", excludes=[f"__init__.py", "**/gapic_version.py", "noxfile.py"]) +for library in s.get_staging_dirs(default_version=firestore_admin_default_version): + s.move(library / f"google/cloud/firestore_admin_{library.name}", excludes=[f"__init__.py", "noxfile.py"]) s.move(library / f"tests", f"tests") - update_fixup_scripts(library) - s.move(library / "scripts") + fixup_script_path = "fixup_firestore_admin_v1_keywords.py" + update_fixup_scripts(fixup_script_path) + s.move(library / "scripts" / fixup_script_path) -for library in get_staging_dirs(sub_directory="firestore_bundle"): +for library in s.get_staging_dirs(): s.replace( library / "google/cloud/bundle/types/bundle.py", "from google.firestore.v1 import document_pb2 # type: ignore\n" @@ -127,7 +89,7 @@ def update_fixup_scripts(library): s.move( library / f"google/cloud/bundle", f"google/cloud/firestore_bundle", - excludes=["**/gapic_version.py", "noxfile.py"], + excludes=["noxfile.py"], ) s.move(library / f"tests", f"tests") @@ -138,27 +100,19 @@ def update_fixup_scripts(library): # ---------------------------------------------------------------------------- templated_files = common.py_library( samples=False, # set to True only if there are samples - system_test_python_versions=["3.7"], unit_test_external_dependencies=["aiounittest", "six", "freezegun"], system_test_external_dependencies=["pytest-asyncio", "six"], microgenerator=True, cov_level=100, split_system_tests=True, + default_python_version="3.14", + system_test_python_versions=["3.14"], + unit_test_python_versions=["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13", "3.14"], ) s.move(templated_files, - excludes=[".github/release-please.yml", "renovate.json"]) + excludes=[".github/**", ".kokoro/**", "renovate.json"]) python.py_samples(skip_readmes=True) s.shell.run(["nox", "-s", "blacken"], hide_output=False) - -s.replace( - ".kokoro/build.sh", - "# Setup service account credentials.", - """\ -# Setup firestore account credentials -export FIRESTORE_APPLICATION_CREDENTIALS=${KOKORO_GFILE_DIR}/firebase-credentials.json - -# Setup service account credentials.""", -) diff --git 
a/.librarian/generator-input/noxfile.py b/.librarian/generator-input/noxfile.py new file mode 100644 index 0000000000..4fb209cbc4 --- /dev/null +++ b/.librarian/generator-input/noxfile.py @@ -0,0 +1,584 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generated by synthtool. DO NOT EDIT! + +from __future__ import absolute_import + +import os +import pathlib +import re +import shutil +from typing import Dict, List +import warnings + +import nox + +FLAKE8_VERSION = "flake8==6.1.0" +PYTYPE_VERSION = "pytype==2020.7.24" +BLACK_VERSION = "black[jupyter]==23.7.0" +ISORT_VERSION = "isort==5.11.0" +LINT_PATHS = ["docs", "google", "tests", "noxfile.py", "setup.py"] + +DEFAULT_PYTHON_VERSION = "3.14" + +UNIT_TEST_PYTHON_VERSIONS: List[str] = [ + "3.7", + "3.8", + "3.9", + "3.10", + "3.11", + "3.12", + "3.13", + "3.14", +] +UNIT_TEST_STANDARD_DEPENDENCIES = [ + "mock", + "asyncmock", + "pytest", + "pytest-cov", + "pytest-asyncio==0.21.2", +] +UNIT_TEST_EXTERNAL_DEPENDENCIES: List[str] = [ + "aiounittest", + "six", + "freezegun", +] +UNIT_TEST_LOCAL_DEPENDENCIES: List[str] = [] +UNIT_TEST_DEPENDENCIES: List[str] = [] +UNIT_TEST_EXTRAS: List[str] = [] +UNIT_TEST_EXTRAS_BY_PYTHON: Dict[str, List[str]] = {} + +SYSTEM_TEST_PYTHON_VERSIONS: List[str] = ["3.14"] +SYSTEM_TEST_STANDARD_DEPENDENCIES: List[str] = [ + "mock", + "pytest", + "google-cloud-testutils", +] +SYSTEM_TEST_EXTERNAL_DEPENDENCIES: List[str] = [ + "pytest-asyncio==0.21.2", + "six", +] +SYSTEM_TEST_LOCAL_DEPENDENCIES: List[str] = [] +SYSTEM_TEST_DEPENDENCIES: List[str] = [] +SYSTEM_TEST_EXTRAS: List[str] = [] +SYSTEM_TEST_EXTRAS_BY_PYTHON: Dict[str, List[str]] = {} + +CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() + +nox.options.sessions = [ + "unit-3.9", + "unit-3.10", + "unit-3.11", + "unit-3.12", + "unit-3.13", + "unit-3.14", + "system_emulated", + "system", + "mypy", + "cover", + "lint", + "lint_setup_py", + "blacken", + "docs", + "docfx", + "format", +] + +# Error if a python version is missing +nox.options.error_on_missing_interpreters = True + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def lint(session): + """Run linters. + + Returns a failure if the linters find linting errors or sufficiently + serious code quality issues. + """ + session.install(FLAKE8_VERSION, BLACK_VERSION) + session.run( + "black", + "--check", + *LINT_PATHS, + ) + session.run("flake8", "google", "tests") + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def blacken(session): + """Run black. Format code to uniform standard.""" + session.install(BLACK_VERSION) + session.run( + "black", + *LINT_PATHS, + ) + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def format(session): + """ + Run isort to sort imports. Then run black + to format code to uniform standard. + """ + session.install(BLACK_VERSION, ISORT_VERSION) + # Use the --fss option to sort imports using strict alphabetical order. 
+ # See https://pycqa.github.io/isort/docs/configuration/options.html#force-sort-within-sections + session.run( + "isort", + "--fss", + *LINT_PATHS, + ) + session.run( + "black", + *LINT_PATHS, + ) + + +@nox.session(python="3.7") +def pytype(session): + """Verify type hints are pytype compatible.""" + session.install(PYTYPE_VERSION) + session.run( + "pytype", + ) + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def mypy(session): + """Verify type hints are mypy compatible.""" + session.install("-e", ".") + session.install("mypy", "types-setuptools", "types-protobuf") + session.run( + "mypy", + "-p", + "google.cloud.firestore_v1", + "--no-incremental", + "--check-untyped-defs", + "--exclude", + "services", + ) + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def lint_setup_py(session): + """Verify that setup.py is valid (including RST check).""" + session.install("setuptools", "docutils", "pygments") + session.run("python", "setup.py", "check", "--restructuredtext", "--strict") + + +def install_unittest_dependencies(session, *constraints): + standard_deps = UNIT_TEST_STANDARD_DEPENDENCIES + UNIT_TEST_DEPENDENCIES + session.install(*standard_deps, *constraints) + + if UNIT_TEST_EXTERNAL_DEPENDENCIES: + warnings.warn( + "'unit_test_external_dependencies' is deprecated. Instead, please " + "use 'unit_test_dependencies' or 'unit_test_local_dependencies'.", + DeprecationWarning, + ) + session.install(*UNIT_TEST_EXTERNAL_DEPENDENCIES, *constraints) + + if UNIT_TEST_LOCAL_DEPENDENCIES: + session.install(*UNIT_TEST_LOCAL_DEPENDENCIES, *constraints) + + if UNIT_TEST_EXTRAS_BY_PYTHON: + extras = UNIT_TEST_EXTRAS_BY_PYTHON.get(session.python, []) + elif UNIT_TEST_EXTRAS: + extras = UNIT_TEST_EXTRAS + else: + extras = [] + + if extras: + session.install("-e", f".[{','.join(extras)}]", *constraints) + else: + session.install("-e", ".", *constraints) + + +@nox.session(python=UNIT_TEST_PYTHON_VERSIONS) +@nox.parametrize( + "protobuf_implementation", + ["python", "upb", "cpp"], +) +def unit(session, protobuf_implementation): + # Install all test dependencies, then install this package in-place. + + py_version = tuple([int(v) for v in session.python.split(".")]) + if protobuf_implementation == "cpp" and py_version >= (3, 11): + session.skip("cpp implementation is not supported in python 3.11+") + + constraints_path = str( + CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" + ) + install_unittest_dependencies(session, "-c", constraints_path) + + # TODO(https://github.com/googleapis/synthtool/issues/1976): + # Remove the 'cpp' implementation once support for Protobuf 3.x is dropped. + # The 'cpp' implementation requires Protobuf<4. + if protobuf_implementation == "cpp": + session.install("protobuf<4") + + # Run py.test against the unit tests. + session.run( + "py.test", + "--quiet", + f"--junitxml=unit_{session.python}_sponge_log.xml", + "--cov=google", + "--cov=tests/unit", + "--cov-append", + "--cov-config=.coveragerc", + "--cov-report=", + "--cov-fail-under=0", + os.path.join("tests", "unit"), + *session.posargs, + env={ + "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": protobuf_implementation, + }, + ) + + +def install_systemtest_dependencies(session, *constraints): + # Use pre-release gRPC for system tests. + # Exclude version 1.52.0rc1 which has a known issue. 
+ # See https://github.com/grpc/grpc/issues/32163 + session.install("--pre", "grpcio!=1.52.0rc1") + + session.install(*SYSTEM_TEST_STANDARD_DEPENDENCIES, *constraints) + + if SYSTEM_TEST_EXTERNAL_DEPENDENCIES: + session.install(*SYSTEM_TEST_EXTERNAL_DEPENDENCIES, *constraints) + + if SYSTEM_TEST_LOCAL_DEPENDENCIES: + session.install("-e", *SYSTEM_TEST_LOCAL_DEPENDENCIES, *constraints) + + if SYSTEM_TEST_DEPENDENCIES: + session.install("-e", *SYSTEM_TEST_DEPENDENCIES, *constraints) + + if SYSTEM_TEST_EXTRAS_BY_PYTHON: + extras = SYSTEM_TEST_EXTRAS_BY_PYTHON.get(session.python, []) + elif SYSTEM_TEST_EXTRAS: + extras = SYSTEM_TEST_EXTRAS + else: + extras = [] + + if extras: + session.install("-e", f".[{','.join(extras)}]", *constraints) + else: + session.install("-e", ".", *constraints) + + +@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) +def system_emulated(session): + import subprocess + import signal + + try: + # https://github.com/googleapis/python-firestore/issues/472 + # Kokoro image doesn't have java installed, don't attempt to run emulator. + subprocess.call(["java", "--version"]) + except OSError: + session.skip("java not found but required for emulator support") + + try: + subprocess.call(["gcloud", "--version"]) + except OSError: + session.skip("gcloud not found but required for emulator support") + + # Currently, CI/CD doesn't have beta component of gcloud. + subprocess.call( + [ + "gcloud", + "components", + "install", + "beta", + "cloud-firestore-emulator", + ] + ) + + hostport = "localhost:8789" + session.env["FIRESTORE_EMULATOR_HOST"] = hostport + + p = subprocess.Popen( + [ + "gcloud", + "--quiet", + "beta", + "emulators", + "firestore", + "start", + "--host-port", + hostport, + ] + ) + + try: + system(session) + finally: + # Stop Emulator + os.killpg(os.getpgid(p.pid), signal.SIGKILL) + + +@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) +def system(session): + """Run the system test suite.""" + constraints_path = str( + CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" + ) + system_test_path = os.path.join("tests", "system.py") + system_test_folder_path = os.path.join("tests", "system") + + # Check the value of `RUN_SYSTEM_TESTS` env var. It defaults to true. + if os.environ.get("RUN_SYSTEM_TESTS", "true") == "false": + session.skip("RUN_SYSTEM_TESTS is set to false, skipping") + # Install pyopenssl for mTLS testing. + if os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true": + session.install("pyopenssl") + + system_test_exists = os.path.exists(system_test_path) + system_test_folder_exists = os.path.exists(system_test_folder_path) + # Sanity check: only run tests if found. + if not system_test_exists and not system_test_folder_exists: + session.skip("System tests were not found") + + install_systemtest_dependencies(session, "-c", constraints_path) + + # Run py.test against the system tests. + if system_test_exists: + session.run( + "py.test", + "--verbose", + f"--junitxml=system_{session.python}_sponge_log.xml", + system_test_path, + *session.posargs, + ) + if system_test_folder_exists: + session.run( + "py.test", + "--verbose", + f"--junitxml=system_{session.python}_sponge_log.xml", + system_test_folder_path, + *session.posargs, + ) + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def cover(session): + """Run the final coverage report. + + This outputs the coverage report aggregating coverage from the unit + test runs (not system test runs), and then erases coverage data. 
+ """ + session.install("coverage", "pytest-cov") + session.run( + "coverage", + "report", + "--show-missing", + "--fail-under=100", + "--omit=tests/*", + ) + + session.run("coverage", "erase") + + +@nox.session(python="3.10") +def docs(session): + """Build the docs for this library.""" + + session.install("-e", ".") + session.install( + # We need to pin to specific versions of the `sphinxcontrib-*` packages + # which still support sphinx 4.x. + # See https://github.com/googleapis/sphinx-docfx-yaml/issues/344 + # and https://github.com/googleapis/sphinx-docfx-yaml/issues/345. + "sphinxcontrib-applehelp==1.0.4", + "sphinxcontrib-devhelp==1.0.2", + "sphinxcontrib-htmlhelp==2.0.1", + "sphinxcontrib-qthelp==1.0.3", + "sphinxcontrib-serializinghtml==1.1.5", + "sphinx==4.5.0", + "alabaster", + "recommonmark", + ) + + shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) + session.run( + "sphinx-build", + "-W", # warnings as errors + "-T", # show full traceback on exception + "-N", # no colors + "-b", + "html", + "-d", + os.path.join("docs", "_build", "doctrees", ""), + os.path.join("docs", ""), + os.path.join("docs", "_build", "html", ""), + ) + + +@nox.session(python="3.10") +def docfx(session): + """Build the docfx yaml files for this library.""" + + session.install("-e", ".") + session.install( + # We need to pin to specific versions of the `sphinxcontrib-*` packages + # which still support sphinx 4.x. + # See https://github.com/googleapis/sphinx-docfx-yaml/issues/344 + # and https://github.com/googleapis/sphinx-docfx-yaml/issues/345. + "sphinxcontrib-applehelp==1.0.4", + "sphinxcontrib-devhelp==1.0.2", + "sphinxcontrib-htmlhelp==2.0.1", + "sphinxcontrib-qthelp==1.0.3", + "sphinxcontrib-serializinghtml==1.1.5", + "gcp-sphinx-docfx-yaml", + "alabaster", + "recommonmark", + ) + + shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) + session.run( + "sphinx-build", + "-T", # show full traceback on exception + "-N", # no colors + "-D", + ( + "extensions=sphinx.ext.autodoc," + "sphinx.ext.autosummary," + "docfx_yaml.extension," + "sphinx.ext.intersphinx," + "sphinx.ext.coverage," + "sphinx.ext.napoleon," + "sphinx.ext.todo," + "sphinx.ext.viewcode," + "recommonmark" + ), + "-b", + "html", + "-d", + os.path.join("docs", "_build", "doctrees", ""), + os.path.join("docs", ""), + os.path.join("docs", "_build", "html", ""), + ) + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +@nox.parametrize( + "protobuf_implementation", + ["python", "upb", "cpp"], +) +def prerelease_deps(session, protobuf_implementation): + """Run all tests with prerelease versions of dependencies installed.""" + + py_version = tuple([int(v) for v in session.python.split(".")]) + if protobuf_implementation == "cpp" and py_version >= (3, 11): + session.skip("cpp implementation is not supported in python 3.11+") + + # Install all dependencies + session.install("-e", ".[all, tests, tracing]") + unit_deps_all = UNIT_TEST_STANDARD_DEPENDENCIES + UNIT_TEST_EXTERNAL_DEPENDENCIES + session.install(*unit_deps_all) + system_deps_all = ( + SYSTEM_TEST_STANDARD_DEPENDENCIES + SYSTEM_TEST_EXTERNAL_DEPENDENCIES + ) + session.install(*system_deps_all) + + # Because we test minimum dependency versions on the minimum Python + # version, the first version we test with in the unit tests sessions has a + # constraints file containing all dependencies and extras. 
+ with open( + CURRENT_DIRECTORY + / "testing" + / f"constraints-{UNIT_TEST_PYTHON_VERSIONS[0]}.txt", + encoding="utf-8", + ) as constraints_file: + constraints_text = constraints_file.read() + + # Ignore leading whitespace and comment lines. + constraints_deps = [ + match.group(1) + for match in re.finditer( + r"^\s*(\S+)(?===\S+)", constraints_text, flags=re.MULTILINE + ) + ] + + session.install(*constraints_deps) + + prerel_deps = [ + "protobuf", + # dependency of grpc + "six", + "grpc-google-iam-v1", + "googleapis-common-protos", + "grpcio", + "grpcio-status", + "google-api-core", + "google-auth", + "proto-plus", + "google-cloud-testutils", + # dependencies of google-cloud-testutils + "click", + ] + + for dep in prerel_deps: + session.install("--pre", "--no-deps", "--upgrade", dep) + + # Remaining dependencies + other_deps = [ + "requests", + ] + session.install(*other_deps) + + # Print out prerelease package versions + session.run( + "python", "-c", "import google.protobuf; print(google.protobuf.__version__)" + ) + session.run("python", "-c", "import grpc; print(grpc.__version__)") + session.run("python", "-c", "import google.auth; print(google.auth.__version__)") + + session.run( + "py.test", + "tests/unit", + env={ + "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": protobuf_implementation, + }, + ) + + system_test_path = os.path.join("tests", "system.py") + system_test_folder_path = os.path.join("tests", "system") + + # Only run system tests if found. + if os.path.exists(system_test_path): + session.run( + "py.test", + "--verbose", + f"--junitxml=system_{session.python}_sponge_log.xml", + system_test_path, + *session.posargs, + env={ + "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": protobuf_implementation, + }, + ) + if os.path.exists(system_test_folder_path): + session.run( + "py.test", + "--verbose", + f"--junitxml=system_{session.python}_sponge_log.xml", + system_test_folder_path, + *session.posargs, + env={ + "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": protobuf_implementation, + }, + ) diff --git a/.librarian/generator-input/setup.py b/.librarian/generator-input/setup.py new file mode 100644 index 0000000000..28d6faf511 --- /dev/null +++ b/.librarian/generator-input/setup.py @@ -0,0 +1,95 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import os + +import setuptools + +# Package metadata.
+ +name = "google-cloud-firestore" +description = "Google Cloud Firestore API client library" + +package_root = os.path.abspath(os.path.dirname(__file__)) + +version = {} +with open(os.path.join(package_root, "google/cloud/firestore/gapic_version.py")) as fp: + exec(fp.read(), version) +version = version["__version__"] +release_status = "Development Status :: 5 - Production/Stable" +dependencies = [ + "google-api-core[grpc] >= 1.34.0, <3.0.0,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,!=2.8.*,!=2.9.*,!=2.10.*", + # Exclude incompatible versions of `google-auth` + # See https://github.com/googleapis/google-cloud-python/issues/12364 + "google-auth >= 2.14.1, <3.0.0,!=2.24.0,!=2.25.0", + "google-cloud-core >= 1.4.1, <3.0.0", + "proto-plus >= 1.22.0, <2.0.0", + "proto-plus >= 1.22.2, <2.0.0; python_version>='3.11'", + "proto-plus >= 1.25.0, <2.0.0; python_version>='3.13'", + "protobuf>=3.20.2,<7.0.0dev,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", +] +extras = {} + + +# Setup boilerplate below this line. + +package_root = os.path.abspath(os.path.dirname(__file__)) +readme_filename = os.path.join(package_root, "README.rst") +with io.open(readme_filename, encoding="utf-8") as readme_file: + readme = readme_file.read() + +# Only include packages under the 'google' namespace. Do not include tests, +# benchmarks, etc. +packages = [ + package + for package in setuptools.find_namespace_packages() + if package.startswith("google") +] + +setuptools.setup( + name=name, + version=version, + description=description, + long_description=readme, + author="Google LLC", + author_email="googleapis-packages@google.com", + license="Apache 2.0", + url="https://github.com/googleapis/python-firestore", + classifiers=[ + release_status, + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Operating System :: OS Independent", + "Topic :: Internet", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + platforms="Posix; MacOS X; Windows", + packages=packages, + install_requires=dependencies, + extras_require=extras, + python_requires=">=3.7", + include_package_data=True, + zip_safe=False, +) diff --git a/.librarian/state.yaml b/.librarian/state.yaml new file mode 100644 index 0000000000..7715edb272 --- /dev/null +++ b/.librarian/state.yaml @@ -0,0 +1,49 @@ +image: us-central1-docker.pkg.dev/cloud-sdk-librarian-prod/images-prod/python-librarian-generator@sha256:b8058df4c45e9a6e07f6b4d65b458d0d059241dd34c814f151c8bf6b89211209 +libraries: + - id: google-cloud-firestore + version: 2.22.0 + last_generated_commit: 1a9d00bed77e6db82ff67764ffe14e3b5209f5cd + apis: + - path: google/firestore/v1 + service_config: firestore_v1.yaml + - path: google/firestore/admin/v1 + service_config: firestore_v1.yaml + - path: google/firestore/bundle + service_config: "" + source_roots: + - . 
+ preserve_regex: [] + remove_regex: + - ^google/cloud/firestore_v1/services + - ^google/cloud/firestore_v1/types + - ^google/cloud/firestore_v1/gapic + - ^google/cloud/firestore_v1/py.typed + - ^google/cloud/firestore_admin_v1/services + - ^google/cloud/firestore_admin_v1/types + - ^google/cloud/firestore_admin_v1/gapic + - ^google/cloud/firestore_admin_v1/py.typed + - ^google/cloud/firestore_bundle/services + - ^google/cloud/firestore_bundle/types + - ^google/cloud/firestore_bundle/__init__.py + - ^google/cloud/firestore_bundle/gapic + - ^google/cloud/firestore_bundle/py.typed + - ^tests/unit/gapic + - ^tests/__init__.py + - ^tests/unit/__init__.py + - ^.pre-commit-config.yaml + - ^.repo-metadata.json + - ^.trampolinerc + - ^.coveragerc + - ^SECURITY.md + - ^noxfile.py + - ^owlbot.py + - ^samples/AUTHORING_GUIDE.md + - ^samples/CONTRIBUTING.md + - ^samples/generated_samples + - ^scripts/fixup_firestore_v1_keywords.py + - ^scripts/fixup_firestore_admin_v1_keywords.py + - ^setup.py + - ^README.rst + - ^docs/README.rst + - ^docs/summary_overview.md + tag_format: v{version} diff --git a/CHANGELOG.md b/CHANGELOG.md index 893a012978..ee59f43a86 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,22 @@ [1]: https://pypi.org/project/google-cloud-firestore/#history +## [2.22.0](https://github.com/googleapis/python-firestore/compare/v2.21.0...v2.22.0) (2025-12-16) + + +### Features + +* support mTLS certificates when available (#1140) ([403afb08109c8271eddd97d6172136271cc0a8a9](https://github.com/googleapis/python-firestore/commit/403afb08109c8271eddd97d6172136271cc0a8a9)) +* Add support for Python 3.14 (#1110) ([52b2055d01ab5d2c34e00f8861e29990f89cd3d8](https://github.com/googleapis/python-firestore/commit/52b2055d01ab5d2c34e00f8861e29990f89cd3d8)) +* Expose tags field in Database and RestoreDatabaseRequest public protos (#1074) ([49836391dc712bd482781a26ccd3c8a8408c473b](https://github.com/googleapis/python-firestore/commit/49836391dc712bd482781a26ccd3c8a8408c473b)) +* Added read_time as a parameter to various calls (synchronous/base classes) (#1050) ([d8e3af1f9dbdfaf5df0d993a0a7e28883472c621](https://github.com/googleapis/python-firestore/commit/d8e3af1f9dbdfaf5df0d993a0a7e28883472c621)) + + +### Bug Fixes + +* improve typing (#1136) ([d1c730d9eef867d16d7818a75f7d58439a942c1d](https://github.com/googleapis/python-firestore/commit/d1c730d9eef867d16d7818a75f7d58439a942c1d)) +* update the async transactional types (#1066) ([210a14a4b758d70aad05940665ed2a2a21ae2a8b](https://github.com/googleapis/python-firestore/commit/210a14a4b758d70aad05940665ed2a2a21ae2a8b)) + ## [2.21.0](https://github.com/googleapis/python-firestore/compare/v2.20.2...v2.21.0) (2025-05-23) diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 1d0c00be3e..b592940062 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -22,7 +22,7 @@ In order to add a feature: documentation. - The feature must work fully on the following CPython versions: - 3.7, 3.8, 3.9, 3.10, 3.11, 3.12 and 3.13 on both UNIX and Windows. + 3.7, 3.8, 3.9, 3.10, 3.11, 3.12, 3.13 and 3.14 on both UNIX and Windows. - The feature must not add unnecessary dependencies (where "unnecessary" is of course subjective, but new dependencies should @@ -72,7 +72,7 @@ We use `nox <https://nox.readthedocs.io/en/latest/>`__ to instrument our tests. - To run a single unit test:: - $ nox -s unit-3.13 -- -k <name of test> + $ nox -s unit-3.14 -- -k <name of test> .. note:: @@ -238,6 +238,7 @@ We support: - `Python 3.11`_ - `Python 3.12`_ - `Python 3.13`_ +- `Python 3.14`_ .. _Python 3.7: https://docs.python.org/3.7/ ..
_Python 3.8: https://docs.python.org/3.8/ @@ -246,6 +247,7 @@ We support: .. _Python 3.11: https://docs.python.org/3.11/ .. _Python 3.12: https://docs.python.org/3.12/ .. _Python 3.13: https://docs.python.org/3.13/ +.. _Python 3.14: https://docs.python.org/3.14/ Supported versions can be found in our ``noxfile.py`` `config`_. diff --git a/README.rst b/README.rst index e349bf7831..71250f4f72 100644 --- a/README.rst +++ b/README.rst @@ -61,7 +61,7 @@ Supported Python Versions Our client libraries are compatible with all current `active`_ and `maintenance`_ versions of Python. -Python >= 3.7 +Python >= 3.7, including 3.14 .. _active: https://devguide.python.org/devcycle/#in-development-main-branch .. _maintenance: https://devguide.python.org/devcycle/#maintenance-branches diff --git a/docs/README.rst b/docs/README.rst deleted file mode 120000 index 89a0106941..0000000000 --- a/docs/README.rst +++ /dev/null @@ -1 +0,0 @@ -../README.rst \ No newline at end of file diff --git a/docs/README.rst b/docs/README.rst new file mode 100644 index 0000000000..71250f4f72 --- /dev/null +++ b/docs/README.rst @@ -0,0 +1,197 @@ +Python Client for Cloud Firestore API +===================================== + +|stable| |pypi| |versions| + +`Cloud Firestore API`_: is a fully-managed NoSQL document database for mobile, web, and server development from Firebase and Google Cloud Platform. It's backed by a multi-region replicated database that ensures once data is committed, it's durable even in the face of unexpected disasters. Not only that, but despite being a distributed database, it's also strongly consistent and offers seamless integration with other Firebase and Google Cloud Platform products, including Google Cloud Functions. + +- `Client Library Documentation`_ +- `Product Documentation`_ + +.. |stable| image:: https://img.shields.io/badge/support-stable-gold.svg + :target: https://github.com/googleapis/google-cloud-python/blob/main/README.rst#stability-levels +.. |pypi| image:: https://img.shields.io/pypi/v/google-cloud-firestore.svg + :target: https://pypi.org/project/google-cloud-firestore/ +.. |versions| image:: https://img.shields.io/pypi/pyversions/google-cloud-firestore.svg + :target: https://pypi.org/project/google-cloud-firestore/ +.. _Cloud Firestore API: https://cloud.google.com/firestore +.. _Client Library Documentation: https://cloud.google.com/python/docs/reference/firestore/latest/summary_overview +.. _Product Documentation: https://cloud.google.com/firestore + +Quick Start +----------- + +In order to use this library, you first need to go through the following steps: + +1. `Select or create a Cloud Platform project.`_ +2. `Enable billing for your project.`_ +3. `Enable the Cloud Firestore API.`_ +4. `Set up Authentication.`_ + +.. _Select or create a Cloud Platform project.: https://console.cloud.google.com/project +.. _Enable billing for your project.: https://cloud.google.com/billing/docs/how-to/modify-project#enable_billing_for_a_project +.. _Enable the Cloud Firestore API.: https://cloud.google.com/firestore +.. _Set up Authentication.: https://googleapis.dev/python/google-api-core/latest/auth.html + +Installation +~~~~~~~~~~~~ + +Install this library in a virtual environment using `venv`_. `venv`_ is a tool that +creates isolated Python environments. These isolated environments can have separate +versions of Python packages, which allows you to isolate one project's dependencies +from the dependencies of other projects. 
+ +With `venv`_, it's possible to install this library without needing system +install permissions, and without clashing with the installed system +dependencies. + +.. _`venv`: https://docs.python.org/3/library/venv.html + + +Code samples and snippets +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Code samples and snippets live in the `samples/`_ folder. + +.. _samples/: https://github.com/googleapis/python-firestore/tree/main/samples + + +Supported Python Versions +^^^^^^^^^^^^^^^^^^^^^^^^^ +Our client libraries are compatible with all current `active`_ and `maintenance`_ versions of +Python. + +Python >= 3.7, including 3.14 + +.. _active: https://devguide.python.org/devcycle/#in-development-main-branch +.. _maintenance: https://devguide.python.org/devcycle/#maintenance-branches + +Unsupported Python Versions +^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Python <= 3.6 + +If you are using an `end-of-life`_ +version of Python, we recommend that you update as soon as possible to an actively supported version. + +.. _end-of-life: https://devguide.python.org/devcycle/#end-of-life-branches + +Mac/Linux +^^^^^^^^^ + +.. code-block:: console + + python3 -m venv <your-env> + source <your-env>/bin/activate + pip install google-cloud-firestore + + +Windows +^^^^^^^ + +.. code-block:: console + + py -m venv <your-env> + .\<your-env>\Scripts\activate + pip install google-cloud-firestore + +Next Steps +~~~~~~~~~~ + +- Read the `Client Library Documentation`_ for Cloud Firestore API + to see other available methods on the client. +- Read the `Cloud Firestore API Product documentation`_ to learn + more about the product and see How-to Guides. +- View this `README`_ to see the full list of Cloud + APIs that we cover. + +.. _Cloud Firestore API Product documentation: https://cloud.google.com/firestore +.. _README: https://github.com/googleapis/google-cloud-python/blob/main/README.rst + +Logging +------- + +This library uses the standard Python :code:`logging` functionality to log some RPC events that could be of interest for debugging and monitoring purposes. +Note the following: + +#. Logs may contain sensitive information. Take care to **restrict access to the logs** if they are saved, whether it be on local storage or on Google Cloud Logging. +#. Google may refine the occurrence, level, and content of various log messages in this library without flagging such changes as breaking. **Do not depend on immutability of the logging events**. +#. By default, the logging events from this library are not handled. You must **explicitly configure log handling** using one of the mechanisms below. + +Simple, environment-based configuration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To enable logging for this library without any changes in your code, set the :code:`GOOGLE_SDK_PYTHON_LOGGING_SCOPE` environment variable to a valid Google +logging scope. This configures handling of logging events (at level :code:`logging.DEBUG` or higher) from this library in a default manner, emitting the logged +messages in a structured format. It does not currently allow customizing the logging levels captured nor the handlers, formatters, etc. used for any logging +event. + +A logging scope is a period-separated namespace that begins with :code:`google`, identifying the Python module or package to log. + +- Valid logging scopes: :code:`google`, :code:`google.cloud.asset.v1`, :code:`google.api`, :code:`google.auth`, etc. +- Invalid logging scopes: :code:`foo`, :code:`123`, etc. + +**NOTE**: If the logging scope is invalid, the library does not set up any logging handlers.
+ +Environment-Based Examples +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +- Enabling the default handler for all Google-based loggers + +.. code-block:: console + + export GOOGLE_SDK_PYTHON_LOGGING_SCOPE=google + +- Enabling the default handler for a specific Google module (for a client library called :code:`library_v1`): + +.. code-block:: console + + export GOOGLE_SDK_PYTHON_LOGGING_SCOPE=google.cloud.library_v1 + + +Advanced, code-based configuration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can also configure a valid logging scope using Python's standard `logging` mechanism. + +Code-Based Examples +^^^^^^^^^^^^^^^^^^^ + +- Configuring a handler for all Google-based loggers + +.. code-block:: python + + import logging + + from google.cloud import library_v1 + + base_logger = logging.getLogger("google") + base_logger.addHandler(logging.StreamHandler()) + base_logger.setLevel(logging.DEBUG) + +- Configuring a handler for a specific Google module (for a client library called :code:`library_v1`): + +.. code-block:: python + + import logging + + from google.cloud import library_v1 + + base_logger = logging.getLogger("google.cloud.library_v1") + base_logger.addHandler(logging.StreamHandler()) + base_logger.setLevel(logging.DEBUG) + +Logging details +~~~~~~~~~~~~~~~ + +#. Regardless of which of the mechanisms above you use to configure logging for this library, by default logging events are not propagated up to the root + logger from the `google`-level logger. If you need the events to be propagated to the root logger, you must explicitly set + :code:`logging.getLogger("google").propagate = True` in your code. +#. You can mix the different logging configurations above for different Google modules. For example, you may want to use a code-based logging configuration for + one library, but decide you need to also set up environment-based logging configuration for another library. + + #. If you attempt to use both code-based and environment-based configuration for the same module, the environment-based configuration will be ineffectual + if the code-based configuration gets applied first. + +#. The Google-specific logging configurations (default handlers for environment-based configuration; not propagating logging events to the root logger) get + executed the first time *any* client library is instantiated in your application, and only if the affected loggers have not been previously configured. + (This is the reason for 2.i. above.) diff --git a/google/cloud/firestore/gapic_version.py b/google/cloud/firestore/gapic_version.py index e546bae053..03d6d0200b 100644 --- a/google/cloud/firestore/gapic_version.py +++ b/google/cloud/firestore/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License.
# -__version__ = "2.21.0" # {x-release-please-version} +__version__ = "2.22.0" # {x-release-please-version} diff --git a/google/cloud/firestore_admin_v1/gapic_metadata.json b/google/cloud/firestore_admin_v1/gapic_metadata.json index e2c91bdb59..b8d4cb298c 100644 --- a/google/cloud/firestore_admin_v1/gapic_metadata.json +++ b/google/cloud/firestore_admin_v1/gapic_metadata.json @@ -15,6 +15,11 @@ "bulk_delete_documents" ] }, + "CloneDatabase": { + "methods": [ + "clone_database" + ] + }, "CreateBackupSchedule": { "methods": [ "create_backup_schedule" @@ -175,6 +180,11 @@ "bulk_delete_documents" ] }, + "CloneDatabase": { + "methods": [ + "clone_database" + ] + }, "CreateBackupSchedule": { "methods": [ "create_backup_schedule" @@ -335,6 +345,11 @@ "bulk_delete_documents" ] }, + "CloneDatabase": { + "methods": [ + "clone_database" + ] + }, "CreateBackupSchedule": { "methods": [ "create_backup_schedule" diff --git a/google/cloud/firestore_admin_v1/gapic_version.py b/google/cloud/firestore_admin_v1/gapic_version.py index e546bae053..ced4e0faf0 100644 --- a/google/cloud/firestore_admin_v1/gapic_version.py +++ b/google/cloud/firestore_admin_v1/gapic_version.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "2.21.0" # {x-release-please-version} +__version__ = "2.22.0" # {x-release-please-version} diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/async_client.py b/google/cloud/firestore_admin_v1/services/firestore_admin/async_client.py index 56531fa29a..a2800e34ea 100644 --- a/google/cloud/firestore_admin_v1/services/firestore_admin/async_client.py +++ b/google/cloud/firestore_admin_v1/services/firestore_admin/async_client.py @@ -4111,6 +4111,143 @@ async def sample_delete_backup_schedule(): metadata=metadata, ) + async def clone_database( + self, + request: Optional[Union[firestore_admin.CloneDatabaseRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates a new database by cloning an existing one. + + The new database must be in the same cloud region or + multi-region location as the existing database. This behaves + similar to + [FirestoreAdmin.CreateDatabase][google.firestore.admin.v1.FirestoreAdmin.CreateDatabase] + except instead of creating a new empty database, a new database + is created with the database type, index configuration, and + documents from an existing database. + + The [long-running operation][google.longrunning.Operation] can + be used to track the progress of the clone, with the Operation's + [metadata][google.longrunning.Operation.metadata] field type + being the + [CloneDatabaseMetadata][google.firestore.admin.v1.CloneDatabaseMetadata]. + The [response][google.longrunning.Operation.response] type is + the [Database][google.firestore.admin.v1.Database] if the clone + was successful. The new database is not readable or writeable + until the LRO has completed. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. 
+ # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import firestore_admin_v1 + + async def sample_clone_database(): + # Create a client + client = firestore_admin_v1.FirestoreAdminAsyncClient() + + # Initialize request argument(s) + pitr_snapshot = firestore_admin_v1.PitrSnapshot() + pitr_snapshot.database = "database_value" + + request = firestore_admin_v1.CloneDatabaseRequest( + parent="parent_value", + database_id="database_id_value", + pitr_snapshot=pitr_snapshot, + ) + + # Make the request + operation = client.clone_database(request=request) + + print("Waiting for operation to complete...") + + response = (await operation).result() + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.cloud.firestore_admin_v1.types.CloneDatabaseRequest, dict]]): + The request object. The request message for + [FirestoreAdmin.CloneDatabase][google.firestore.admin.v1.FirestoreAdmin.CloneDatabase]. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.firestore_admin_v1.types.Database` + A Cloud Firestore Database. + + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, firestore_admin.CloneDatabaseRequest): + request = firestore_admin.CloneDatabaseRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.clone_database + ] + + header_params = {} + + routing_param_regex = re.compile("^projects/(?P<project_id>[^/]+)(?:/.*)?$") + regex_match = routing_param_regex.match(request.pitr_snapshot.database) + if regex_match and regex_match.group("project_id"): + header_params["project_id"] = regex_match.group("project_id") + + routing_param_regex = re.compile( + "^projects/[^/]+/databases/(?P<database_id>[^/]+)(?:/.*)?$" + ) + regex_match = routing_param_regex.match(request.pitr_snapshot.database) + if regex_match and regex_match.group("database_id"): + header_params["database_id"] = regex_match.group("database_id") + + if header_params: + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(header_params), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + database.Database, + metadata_type=gfa_operation.CloneDatabaseMetadata, + ) + + # Done; return the response.
+ return response + async def list_operations( self, request: Optional[operations_pb2.ListOperationsRequest] = None, diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/client.py b/google/cloud/firestore_admin_v1/services/firestore_admin/client.py index d05b82787d..b55c157cf2 100644 --- a/google/cloud/firestore_admin_v1/services/firestore_admin/client.py +++ b/google/cloud/firestore_admin_v1/services/firestore_admin/client.py @@ -198,6 +198,34 @@ def _get_default_mtls_endpoint(api_endpoint): _DEFAULT_ENDPOINT_TEMPLATE = "firestore.{UNIVERSE_DOMAIN}" _DEFAULT_UNIVERSE = "googleapis.com" + @staticmethod + def _use_client_cert_effective(): + """Returns whether client certificate should be used for mTLS if the + google-auth version supports should_use_client_cert automatic mTLS enablement. + + Alternatively, read from the GOOGLE_API_USE_CLIENT_CERTIFICATE env var. + + Returns: + bool: whether client certificate should be used for mTLS + Raises: + ValueError: (If using a version of google-auth without should_use_client_cert and + GOOGLE_API_USE_CLIENT_CERTIFICATE is set to an unexpected value.) + """ + # check if google-auth version supports should_use_client_cert for automatic mTLS enablement + if hasattr(mtls, "should_use_client_cert"): # pragma: NO COVER + return mtls.should_use_client_cert() + else: # pragma: NO COVER + # if unsupported, fallback to reading from env var + use_client_cert_str = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + if use_client_cert_str not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be" + " either `true` or `false`" + ) + return use_client_cert_str == "true" + @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): """Creates an instance of this client using the provided credentials @@ -555,12 +583,8 @@ def get_mtls_endpoint_and_cert_source( ) if client_options is None: client_options = client_options_lib.ClientOptions() - use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_client_cert = FirestoreAdminClient._use_client_cert_effective() use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_client_cert not in ("true", "false"): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) if use_mtls_endpoint not in ("auto", "never", "always"): raise MutualTLSChannelError( "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" @@ -568,7 +592,7 @@ def get_mtls_endpoint_and_cert_source( # Figure out the client cert source to use. client_cert_source = None - if use_client_cert == "true": + if use_client_cert: if client_options.client_cert_source: client_cert_source = client_options.client_cert_source elif mtls.has_default_client_cert_source(): @@ -600,20 +624,14 @@ def _read_environment_variables(): google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT is not any of ["auto", "never", "always"]. 
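The `_use_client_cert_effective` helper above prefers google-auth's `mtls.should_use_client_cert()` when it exists and only parses `GOOGLE_API_USE_CLIENT_CERTIFICATE` itself as a fallback. A minimal sketch of the fallback behavior, assuming a google-auth version without `should_use_client_cert` (with newer google-auth, that library's decision takes precedence and no `ValueError` is raised here):

.. code-block:: python

    import os

    from google.cloud.firestore_admin_v1 import FirestoreAdminClient

    # Unset means the fallback defaults to "false", i.e. no client certificate.
    os.environ.pop("GOOGLE_API_USE_CLIENT_CERTIFICATE", None)
    assert FirestoreAdminClient._use_client_cert_effective() is False

    # Values other than "true"/"false" (case-insensitive) are rejected.
    os.environ["GOOGLE_API_USE_CLIENT_CERTIFICATE"] = "yes"
    try:
        FirestoreAdminClient._use_client_cert_effective()
    except ValueError as exc:
        print(exc)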
""" - use_client_cert = os.getenv( - "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" - ).lower() + use_client_cert = FirestoreAdminClient._use_client_cert_effective() use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower() universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN") - if use_client_cert not in ("true", "false"): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) if use_mtls_endpoint not in ("auto", "never", "always"): raise MutualTLSChannelError( "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" ) - return use_client_cert == "true", use_mtls_endpoint, universe_domain_env + return use_client_cert, use_mtls_endpoint, universe_domain_env @staticmethod def _get_client_cert_source(provided_cert_source, use_cert_flag): @@ -4591,6 +4609,141 @@ def sample_delete_backup_schedule(): metadata=metadata, ) + def clone_database( + self, + request: Optional[Union[firestore_admin.CloneDatabaseRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> gac_operation.Operation: + r"""Creates a new database by cloning an existing one. + + The new database must be in the same cloud region or + multi-region location as the existing database. This behaves + similar to + [FirestoreAdmin.CreateDatabase][google.firestore.admin.v1.FirestoreAdmin.CreateDatabase] + except instead of creating a new empty database, a new database + is created with the database type, index configuration, and + documents from an existing database. + + The [long-running operation][google.longrunning.Operation] can + be used to track the progress of the clone, with the Operation's + [metadata][google.longrunning.Operation.metadata] field type + being the + [CloneDatabaseMetadata][google.firestore.admin.v1.CloneDatabaseMetadata]. + The [response][google.longrunning.Operation.response] type is + the [Database][google.firestore.admin.v1.Database] if the clone + was successful. The new database is not readable or writeable + until the LRO has completed. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import firestore_admin_v1 + + def sample_clone_database(): + # Create a client + client = firestore_admin_v1.FirestoreAdminClient() + + # Initialize request argument(s) + pitr_snapshot = firestore_admin_v1.PitrSnapshot() + pitr_snapshot.database = "database_value" + + request = firestore_admin_v1.CloneDatabaseRequest( + parent="parent_value", + database_id="database_id_value", + pitr_snapshot=pitr_snapshot, + ) + + # Make the request + operation = client.clone_database(request=request) + + print("Waiting for operation to complete...") + + response = operation.result() + + # Handle the response + print(response) + + Args: + request (Union[google.cloud.firestore_admin_v1.types.CloneDatabaseRequest, dict]): + The request object. The request message for + [FirestoreAdmin.CloneDatabase][google.firestore.admin.v1.FirestoreAdmin.CloneDatabase]. 
+ retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.firestore_admin_v1.types.Database` + A Cloud Firestore Database. + + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, firestore_admin.CloneDatabaseRequest): + request = firestore_admin.CloneDatabaseRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.clone_database] + + header_params = {} + + routing_param_regex = re.compile("^projects/(?P<project_id>[^/]+)(?:/.*)?$") + regex_match = routing_param_regex.match(request.pitr_snapshot.database) + if regex_match and regex_match.group("project_id"): + header_params["project_id"] = regex_match.group("project_id") + + routing_param_regex = re.compile( + "^projects/[^/]+/databases/(?P<database_id>[^/]+)(?:/.*)?$" + ) + regex_match = routing_param_regex.match(request.pitr_snapshot.database) + if regex_match and regex_match.group("database_id"): + header_params["database_id"] = regex_match.group("database_id") + + if header_params: + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(header_params), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + database.Database, + metadata_type=gfa_operation.CloneDatabaseMetadata, + ) + + # Done; return the response. + return response + def __enter__(self) -> "FirestoreAdminClient": return self diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/base.py b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/base.py index f290fcbfe1..7d582d9b5a 100644 --- a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/base.py +++ b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/base.py @@ -81,9 +81,10 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can + credentials_file (Optional[str]): Deprecated. A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. + This argument is mutually exclusive with credentials. This argument will be + removed in the next major version of this library. scopes (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota.
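Since both surfaces return a standard `google.api_core` operation future, callers can inspect the typed `CloneDatabaseMetadata` while the clone runs. A minimal sketch, assuming a source database at `projects/my-project/databases/(default)` (resource names are placeholders):

.. code-block:: python

    from google.cloud import firestore_admin_v1

    client = firestore_admin_v1.FirestoreAdminClient()

    operation = client.clone_database(
        request=firestore_admin_v1.CloneDatabaseRequest(
            parent="projects/my-project",
            database_id="cloned-db",
            pitr_snapshot=firestore_admin_v1.PitrSnapshot(
                database="projects/my-project/databases/(default)",
            ),
        )
    )

    # The LRO metadata deserializes as CloneDatabaseMetadata.
    print(operation.metadata.operation_state)

    # Blocks until the clone finishes; the result is the new Database,
    # which is not readable or writeable before this point.
    database = operation.result()
    print(database.name)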
@@ -357,6 +358,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + self.clone_database: gapic_v1.method.wrap_method( + self.clone_database, + default_timeout=120.0, + client_info=client_info, + ), self.cancel_operation: gapic_v1.method.wrap_method( self.cancel_operation, default_timeout=None, @@ -688,6 +694,15 @@ def delete_backup_schedule( ]: raise NotImplementedError() + @property + def clone_database( + self, + ) -> Callable[ + [firestore_admin.CloneDatabaseRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + @property def list_operations( self, diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc.py b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc.py index c6e7824c23..f6531a1906 100644 --- a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc.py +++ b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc.py @@ -192,9 +192,10 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if a ``channel`` instance is provided. - credentials_file (Optional[str]): A file with credentials that can + credentials_file (Optional[str]): Deprecated. A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. This argument is ignored if a ``channel`` instance is provided. + This argument will be removed in the next major version of this library. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if a ``channel`` instance is provided. channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): @@ -328,9 +329,10 @@ def create_channel( credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can + credentials_file (Optional[str]): Deprecated. A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. + This argument is mutually exclusive with credentials. This argument will be + removed in the next major version of this library. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. @@ -1279,6 +1281,50 @@ def delete_backup_schedule( ) return self._stubs["delete_backup_schedule"] + @property + def clone_database( + self, + ) -> Callable[[firestore_admin.CloneDatabaseRequest], operations_pb2.Operation]: + r"""Return a callable for the clone database method over gRPC. + + Creates a new database by cloning an existing one. + + The new database must be in the same cloud region or + multi-region location as the existing database. This behaves + similar to + [FirestoreAdmin.CreateDatabase][google.firestore.admin.v1.FirestoreAdmin.CreateDatabase] + except instead of creating a new empty database, a new database + is created with the database type, index configuration, and + documents from an existing database. 
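The registration above gives `clone_database` a 120-second default timeout at the transport layer; per-call arguments still take precedence. A sketch of overriding both timeout and retry for a single call (request values are placeholders):

.. code-block:: python

    from google.api_core import retry as retries
    from google.cloud import firestore_admin_v1

    client = firestore_admin_v1.FirestoreAdminClient()

    operation = client.clone_database(
        request=firestore_admin_v1.CloneDatabaseRequest(
            parent="projects/my-project",
            database_id="cloned-db",
            pitr_snapshot=firestore_admin_v1.PitrSnapshot(
                database="projects/my-project/databases/(default)",
            ),
        ),
        timeout=300.0,  # replaces the wrapped method's 120.0 default
        retry=retries.Retry(timeout=600.0),
    )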
+ + The [long-running operation][google.longrunning.Operation] can + be used to track the progress of the clone, with the Operation's + [metadata][google.longrunning.Operation.metadata] field type + being the + [CloneDatabaseMetadata][google.firestore.admin.v1.CloneDatabaseMetadata]. + The [response][google.longrunning.Operation.response] type is + the [Database][google.firestore.admin.v1.Database] if the clone + was successful. The new database is not readable or writeable + until the LRO has completed. + + Returns: + Callable[[~.CloneDatabaseRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "clone_database" not in self._stubs: + self._stubs["clone_database"] = self._logged_channel.unary_unary( + "/google.firestore.admin.v1.FirestoreAdmin/CloneDatabase", + request_serializer=firestore_admin.CloneDatabaseRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["clone_database"] + def close(self): self._logged_channel.close() diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc_asyncio.py b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc_asyncio.py index 9dd9d61556..117707853c 100644 --- a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc_asyncio.py +++ b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc_asyncio.py @@ -189,8 +189,9 @@ def create_channel( credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. + credentials_file (Optional[str]): Deprecated. A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. This argument will be + removed in the next major version of this library. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. @@ -241,9 +242,10 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if a ``channel`` instance is provided. - credentials_file (Optional[str]): A file with credentials that can + credentials_file (Optional[str]): Deprecated. A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. This argument is ignored if a ``channel`` instance is provided. + This argument will be removed in the next major version of this library. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. @@ -1331,6 +1333,52 @@ def delete_backup_schedule( ) return self._stubs["delete_backup_schedule"] + @property + def clone_database( + self, + ) -> Callable[ + [firestore_admin.CloneDatabaseRequest], Awaitable[operations_pb2.Operation] + ]: + r"""Return a callable for the clone database method over gRPC. + + Creates a new database by cloning an existing one. 
+ + The new database must be in the same cloud region or + multi-region location as the existing database. This behaves + similar to + [FirestoreAdmin.CreateDatabase][google.firestore.admin.v1.FirestoreAdmin.CreateDatabase] + except instead of creating a new empty database, a new database + is created with the database type, index configuration, and + documents from an existing database. + + The [long-running operation][google.longrunning.Operation] can + be used to track the progress of the clone, with the Operation's + [metadata][google.longrunning.Operation.metadata] field type + being the + [CloneDatabaseMetadata][google.firestore.admin.v1.CloneDatabaseMetadata]. + The [response][google.longrunning.Operation.response] type is + the [Database][google.firestore.admin.v1.Database] if the clone + was successful. The new database is not readable or writeable + until the LRO has completed. + + Returns: + Callable[[~.CloneDatabaseRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "clone_database" not in self._stubs: + self._stubs["clone_database"] = self._logged_channel.unary_unary( + "/google.firestore.admin.v1.FirestoreAdmin/CloneDatabase", + request_serializer=firestore_admin.CloneDatabaseRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["clone_database"] + def _prep_wrapped_messages(self, client_info): """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" self._wrapped_methods = { @@ -1544,6 +1592,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + self.clone_database: self._wrap_method( + self.clone_database, + default_timeout=120.0, + client_info=client_info, + ), self.cancel_operation: self._wrap_method( self.cancel_operation, default_timeout=None, diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest.py b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest.py index c96be2e329..41e819c875 100644 --- a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest.py +++ b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest.py @@ -97,6 +97,14 @@ def post_bulk_delete_documents(self, response): logging.log(f"Received response: {response}") return response + def pre_clone_database(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_clone_database(self, response): + logging.log(f"Received response: {response}") + return response + def pre_create_backup_schedule(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata @@ -376,6 +384,54 @@ def post_bulk_delete_documents_with_metadata( """ return response, metadata + def pre_clone_database( + self, + request: firestore_admin.CloneDatabaseRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore_admin.CloneDatabaseRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for clone_database + + Override in a subclass to manipulate the request or metadata + before they are sent to the FirestoreAdmin server. 
+ """ + return request, metadata + + def post_clone_database( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for clone_database + + DEPRECATED. Please use the `post_clone_database_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response + after it is returned by the FirestoreAdmin server but before + it is returned to user code. This `post_clone_database` interceptor runs + before the `post_clone_database_with_metadata` interceptor. + """ + return response + + def post_clone_database_with_metadata( + self, + response: operations_pb2.Operation, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[operations_pb2.Operation, Sequence[Tuple[str, Union[str, bytes]]]]: + """Post-rpc interceptor for clone_database + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the FirestoreAdmin server but before it is returned to user code. + + We recommend only using this `post_clone_database_with_metadata` + interceptor in new development instead of the `post_clone_database` interceptor. + When both interceptors are used, this `post_clone_database_with_metadata` interceptor runs after the + `post_clone_database` interceptor. The (possibly modified) response returned by + `post_clone_database` will be passed to + `post_clone_database_with_metadata`. + """ + return response, metadata + def pre_create_backup_schedule( self, request: firestore_admin.CreateBackupScheduleRequest, @@ -1857,9 +1913,10 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can + credentials_file (Optional[str]): Deprecated. A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if ``channel`` is provided. This argument will be + removed in the next major version of this library. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client @@ -2115,6 +2172,158 @@ def __call__( ) return resp + class _CloneDatabase( + _BaseFirestoreAdminRestTransport._BaseCloneDatabase, FirestoreAdminRestStub + ): + def __hash__(self): + return hash("FirestoreAdminRestTransport.CloneDatabase") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: firestore_admin.CloneDatabaseRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Call the clone database method over HTTP. + + Args: + request (~.firestore_admin.CloneDatabaseRequest): + The request object. The request message for + [FirestoreAdmin.CloneDatabase][google.firestore.admin.v1.FirestoreAdmin.CloneDatabase]. 
+ retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.operations_pb2.Operation: + This resource represents a + long-running operation that is the + result of a network API call. + + """ + + http_options = ( + _BaseFirestoreAdminRestTransport._BaseCloneDatabase._get_http_options() + ) + + request, metadata = self._interceptor.pre_clone_database(request, metadata) + transcoded_request = _BaseFirestoreAdminRestTransport._BaseCloneDatabase._get_transcoded_request( + http_options, request + ) + + body = _BaseFirestoreAdminRestTransport._BaseCloneDatabase._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseFirestoreAdminRestTransport._BaseCloneDatabase._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore.admin_v1.FirestoreAdminClient.CloneDatabase", + extra={ + "serviceName": "google.firestore.admin.v1.FirestoreAdmin", + "rpcName": "CloneDatabase", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = FirestoreAdminRestTransport._CloneDatabase._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. 
+ if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_clone_database(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_clone_database_with_metadata( + resp, response_metadata + ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore.admin_v1.FirestoreAdminClient.clone_database", + extra={ + "serviceName": "google.firestore.admin.v1.FirestoreAdmin", + "rpcName": "CloneDatabase", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + class _CreateBackupSchedule( _BaseFirestoreAdminRestTransport._BaseCreateBackupSchedule, FirestoreAdminRestStub, @@ -6507,6 +6716,14 @@ def bulk_delete_documents( # In C++ this would require a dynamic_cast return self._BulkDeleteDocuments(self._session, self._host, self._interceptor) # type: ignore + @property + def clone_database( + self, + ) -> Callable[[firestore_admin.CloneDatabaseRequest], operations_pb2.Operation]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CloneDatabase(self._session, self._host, self._interceptor) # type: ignore + @property def create_backup_schedule( self, diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest_base.py b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest_base.py index 19a0c9856f..56b6ce93f8 100644 --- a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest_base.py +++ b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest_base.py @@ -156,6 +156,63 @@ def _get_query_params_json(transcoded_request): query_params["$alt"] = "json;enum-encoding=int" return query_params + class _BaseCloneDatabase: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{parent=projects/*}/databases:clone", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore_admin.CloneDatabaseRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + 
_BaseFirestoreAdminRestTransport._BaseCloneDatabase._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + class _BaseCreateBackupSchedule: def __hash__(self): # pragma: NO COVER return NotImplementedError("__hash__ must be implemented.") diff --git a/google/cloud/firestore_admin_v1/types/__init__.py b/google/cloud/firestore_admin_v1/types/__init__.py index 249147d52a..c76372e5d5 100644 --- a/google/cloud/firestore_admin_v1/types/__init__.py +++ b/google/cloud/firestore_admin_v1/types/__init__.py @@ -25,6 +25,7 @@ from .firestore_admin import ( BulkDeleteDocumentsRequest, BulkDeleteDocumentsResponse, + CloneDatabaseRequest, CreateBackupScheduleRequest, CreateDatabaseMetadata, CreateDatabaseRequest, @@ -73,6 +74,7 @@ ) from .operation import ( BulkDeleteDocumentsMetadata, + CloneDatabaseMetadata, ExportDocumentsMetadata, ExportDocumentsResponse, FieldOperationMetadata, @@ -87,6 +89,9 @@ DailyRecurrence, WeeklyRecurrence, ) +from .snapshot import ( + PitrSnapshot, +) from .user_creds import ( UserCreds, ) @@ -97,6 +102,7 @@ "Field", "BulkDeleteDocumentsRequest", "BulkDeleteDocumentsResponse", + "CloneDatabaseRequest", "CreateBackupScheduleRequest", "CreateDatabaseMetadata", "CreateDatabaseRequest", @@ -139,6 +145,7 @@ "Index", "LocationMetadata", "BulkDeleteDocumentsMetadata", + "CloneDatabaseMetadata", "ExportDocumentsMetadata", "ExportDocumentsResponse", "FieldOperationMetadata", @@ -150,5 +157,6 @@ "BackupSchedule", "DailyRecurrence", "WeeklyRecurrence", + "PitrSnapshot", "UserCreds", ) diff --git a/google/cloud/firestore_admin_v1/types/database.py b/google/cloud/firestore_admin_v1/types/database.py index 4f985a6515..f46bede62b 100644 --- a/google/cloud/firestore_admin_v1/types/database.py +++ b/google/cloud/firestore_admin_v1/types/database.py @@ -119,6 +119,13 @@ class Database(proto.Message): source_info (google.cloud.firestore_admin_v1.types.Database.SourceInfo): Output only. Information about the provenance of this database. + tags (MutableMapping[str, str]): + Optional. Input only. Immutable. Tag + keys/values directly bound to this resource. For + example: + + "123/environment": "production", + "123/costCenter": "marketing". free_tier (bool): Output only. Background: Free tier is the ability of a Firestore database to use a small @@ -206,9 +213,9 @@ class PointInTimeRecoveryEnablement(proto.Enum): Reads are supported on selected versions of the data from within the past 7 days: - - Reads against any timestamp within the past hour - - Reads against 1-minute snapshots beyond 1 hour and within - 7 days + - Reads against any timestamp within the past hour + - Reads against 1-minute snapshots beyond 1 hour and within + 7 days ``version_retention_period`` and ``earliest_version_time`` can be used to determine the supported versions. 
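Given the `_BaseCloneDatabase._get_http_options` mapping above, the new RPC is also reachable over HTTP/JSON: with the REST transport selected, `clone_database` issues `POST /v1/{parent=projects/*}/databases:clone`. A sketch:

.. code-block:: python

    from google.cloud import firestore_admin_v1

    # All RPCs, including CloneDatabase, now use the HTTP bindings from
    # rest_base.py instead of gRPC.
    client = firestore_admin_v1.FirestoreAdminClient(transport="rest")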
@@ -524,6 +531,11 @@ class CustomerManagedEncryptionOptions(proto.Message): number=26, message=SourceInfo, ) + tags: MutableMapping[str, str] = proto.MapField( + proto.STRING, + proto.STRING, + number=29, + ) free_tier: bool = proto.Field( proto.BOOL, number=30, diff --git a/google/cloud/firestore_admin_v1/types/firestore_admin.py b/google/cloud/firestore_admin_v1/types/firestore_admin.py index 77d78cb355..9ede35cacf 100644 --- a/google/cloud/firestore_admin_v1/types/firestore_admin.py +++ b/google/cloud/firestore_admin_v1/types/firestore_admin.py @@ -24,6 +24,7 @@ from google.cloud.firestore_admin_v1.types import field as gfa_field from google.cloud.firestore_admin_v1.types import index as gfa_index from google.cloud.firestore_admin_v1.types import schedule +from google.cloud.firestore_admin_v1.types import snapshot from google.cloud.firestore_admin_v1.types import user_creds as gfa_user_creds from google.protobuf import field_mask_pb2 # type: ignore from google.protobuf import timestamp_pb2 # type: ignore @@ -73,6 +74,7 @@ "ListBackupsResponse", "DeleteBackupRequest", "RestoreDatabaseRequest", + "CloneDatabaseRequest", }, ) @@ -951,7 +953,7 @@ class ListBackupsRequest(proto.Message): [Backup][google.firestore.admin.v1.Backup] are eligible for filtering: - - ``database_uid`` (supports ``=`` only) + - ``database_uid`` (supports ``=`` only) """ parent: str = proto.Field( @@ -1047,6 +1049,12 @@ class RestoreDatabaseRequest(proto.Message): If this field is not specified, the restored database will use the same encryption configuration as the backup, namely [use_source_encryption][google.firestore.admin.v1.Database.EncryptionConfig.use_source_encryption]. + tags (MutableMapping[str, str]): + Optional. Immutable. Tags to be bound to the restored + database. + + The tags should be provided in the format of + ``tagKeys/{tag_key_id} -> tagValues/{tag_value_id}``. """ parent: str = proto.Field( @@ -1066,6 +1074,77 @@ class RestoreDatabaseRequest(proto.Message): number=9, message=gfa_database.Database.EncryptionConfig, ) + tags: MutableMapping[str, str] = proto.MapField( + proto.STRING, + proto.STRING, + number=10, + ) + + +class CloneDatabaseRequest(proto.Message): + r"""The request message for + [FirestoreAdmin.CloneDatabase][google.firestore.admin.v1.FirestoreAdmin.CloneDatabase]. + + Attributes: + parent (str): + Required. The project to clone the database in. Format is + ``projects/{project_id}``. + database_id (str): + Required. The ID to use for the database, which will become + the final component of the database's resource name. This + database ID must not be associated with an existing + database. + + This value should be 4-63 characters. Valid characters are + /[a-z][0-9]-/ with first character a letter and the last a + letter or a number. Must not be UUID-like + /[0-9a-f]{8}(-[0-9a-f]{4}){3}-[0-9a-f]{12}/. + + "(default)" database ID is also valid. + pitr_snapshot (google.cloud.firestore_admin_v1.types.PitrSnapshot): + Required. Specification of the PITR data to + clone from. The source database must exist. + + The cloned database will be created in the same + location as the source database. + encryption_config (google.cloud.firestore_admin_v1.types.Database.EncryptionConfig): + Optional. Encryption configuration for the cloned database. + + If this field is not specified, the cloned database will use + the same encryption configuration as the source database, + namely + [use_source_encryption][google.firestore.admin.v1.Database.EncryptionConfig.use_source_encryption]. 
+ tags (MutableMapping[str, str]): + Optional. Immutable. Tags to be bound to the cloned + database. + + The tags should be provided in the format of + ``tagKeys/{tag_key_id} -> tagValues/{tag_value_id}``. + """ + + parent: str = proto.Field( + proto.STRING, + number=1, + ) + database_id: str = proto.Field( + proto.STRING, + number=2, + ) + pitr_snapshot: snapshot.PitrSnapshot = proto.Field( + proto.MESSAGE, + number=6, + message=snapshot.PitrSnapshot, + ) + encryption_config: gfa_database.Database.EncryptionConfig = proto.Field( + proto.MESSAGE, + number=4, + message=gfa_database.Database.EncryptionConfig, + ) + tags: MutableMapping[str, str] = proto.MapField( + proto.STRING, + proto.STRING, + number=5, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/firestore_admin_v1/types/operation.py b/google/cloud/firestore_admin_v1/types/operation.py index c58f242733..c504556933 100644 --- a/google/cloud/firestore_admin_v1/types/operation.py +++ b/google/cloud/firestore_admin_v1/types/operation.py @@ -20,6 +20,7 @@ import proto # type: ignore from google.cloud.firestore_admin_v1.types import index as gfa_index +from google.cloud.firestore_admin_v1.types import snapshot from google.protobuf import timestamp_pb2 # type: ignore @@ -34,6 +35,7 @@ "BulkDeleteDocumentsMetadata", "ExportDocumentsResponse", "RestoreDatabaseMetadata", + "CloneDatabaseMetadata", "Progress", }, ) @@ -558,6 +560,60 @@ class RestoreDatabaseMetadata(proto.Message): ) +class CloneDatabaseMetadata(proto.Message): + r"""Metadata for the [long-running + operation][google.longrunning.Operation] from the + [CloneDatabase][google.firestore.admin.v1.CloneDatabase] request. + + Attributes: + start_time (google.protobuf.timestamp_pb2.Timestamp): + The time the clone was started. + end_time (google.protobuf.timestamp_pb2.Timestamp): + The time the clone finished, unset for + ongoing clones. + operation_state (google.cloud.firestore_admin_v1.types.OperationState): + The operation state of the clone. + database (str): + The name of the database being cloned to. + pitr_snapshot (google.cloud.firestore_admin_v1.types.PitrSnapshot): + The snapshot from which this database was + cloned. + progress_percentage (google.cloud.firestore_admin_v1.types.Progress): + How far along the clone is as an estimated + percentage of remaining time. + """ + + start_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=1, + message=timestamp_pb2.Timestamp, + ) + end_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=2, + message=timestamp_pb2.Timestamp, + ) + operation_state: "OperationState" = proto.Field( + proto.ENUM, + number=3, + enum="OperationState", + ) + database: str = proto.Field( + proto.STRING, + number=4, + ) + pitr_snapshot: snapshot.PitrSnapshot = proto.Field( + proto.MESSAGE, + number=7, + message=snapshot.PitrSnapshot, + ) + progress_percentage: "Progress" = proto.Field( + proto.MESSAGE, + number=6, + message="Progress", + ) + + class Progress(proto.Message): r"""Describes the progress of the operation. 
Unit of work is generic and must be interpreted based on where diff --git a/google/cloud/firestore_admin_v1/types/snapshot.py b/google/cloud/firestore_admin_v1/types/snapshot.py new file mode 100644 index 0000000000..e56a125f59 --- /dev/null +++ b/google/cloud/firestore_admin_v1/types/snapshot.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.protobuf import timestamp_pb2 # type: ignore + + +__protobuf__ = proto.module( + package="google.firestore.admin.v1", + manifest={ + "PitrSnapshot", + }, +) + + +class PitrSnapshot(proto.Message): + r"""A consistent snapshot of a database at a specific point in + time. A PITR (Point-in-time recovery) snapshot with previous + versions of a database's data is available for every minute up + to the associated database's data retention period. If the PITR + feature is enabled, the retention period is 7 days; otherwise, + it is one hour. + + Attributes: + database (str): + Required. The name of the database that this was a snapshot + of. Format: ``projects/{project}/databases/{database}``. + database_uid (bytes): + Output only. Public UUID of the database the + snapshot was associated with. + snapshot_time (google.protobuf.timestamp_pb2.Timestamp): + Required. Snapshot time of the database. + """ + + database: str = proto.Field( + proto.STRING, + number=1, + ) + database_uid: bytes = proto.Field( + proto.BYTES, + number=2, + ) + snapshot_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=3, + message=timestamp_pb2.Timestamp, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/firestore_bundle/__init__.py b/google/cloud/firestore_bundle/__init__.py index 1b6469437b..30faafe58a 100644 --- a/google/cloud/firestore_bundle/__init__.py +++ b/google/cloud/firestore_bundle/__init__.py @@ -15,8 +15,18 @@ # from google.cloud.firestore_bundle import gapic_version as package_version +import google.api_core as api_core +import sys + __version__ = package_version.__version__ +if sys.version_info >= (3, 8): # pragma: NO COVER + from importlib import metadata +else: # pragma: NO COVER + # TODO(https://github.com/googleapis/python-api-core/issues/835): Remove + # this code path once we drop support for Python 3.7 + import importlib_metadata as metadata + from .types.bundle import BundledDocumentMetadata from .types.bundle import BundledQuery @@ -26,6 +36,100 @@ from .bundle import FirestoreBundle +if hasattr(api_core, "check_python_version") and hasattr( + api_core, "check_dependency_versions" +): # pragma: NO COVER + api_core.check_python_version("google.cloud.bundle") # type: ignore + api_core.check_dependency_versions("google.cloud.bundle") # type: ignore +else: # pragma: NO COVER + # An older version of api_core is installed which does not define the + # functions above. We do equivalent checks manually. 
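Tying `CloneDatabaseRequest`, `PitrSnapshot`, and the new `tags` field together, a construction sketch (resource names and tag IDs are placeholders; proto-plus marshals a Python `datetime` into the `Timestamp` field):

.. code-block:: python

    import datetime

    from google.cloud import firestore_admin_v1

    # PITR snapshots beyond the last hour are per-minute, so zero out
    # seconds and microseconds.
    snapshot_time = datetime.datetime.now(datetime.timezone.utc).replace(
        second=0, microsecond=0
    ) - datetime.timedelta(days=1)

    request = firestore_admin_v1.CloneDatabaseRequest(
        parent="projects/my-project",
        database_id="clone-db",
        pitr_snapshot=firestore_admin_v1.PitrSnapshot(
            database="projects/my-project/databases/source-db",
            snapshot_time=snapshot_time,
        ),
        tags={"tagKeys/123": "tagValues/456"},
    )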
+ try: + import warnings + import sys + + _py_version_str = sys.version.split()[0] + _package_label = "google.cloud.bundle" + if sys.version_info < (3, 9): + warnings.warn( + "You are using a non-supported Python version " + + f"({_py_version_str}). Google will not post any further " + + f"updates to {_package_label} supporting this Python version. " + + "Please upgrade to the latest Python version, or at " + + f"least to Python 3.9, and then update {_package_label}.", + FutureWarning, + ) + if sys.version_info[:2] == (3, 9): + warnings.warn( + f"You are using a Python version ({_py_version_str}) " + + f"which Google will stop supporting in {_package_label} in " + + "January 2026. Please " + + "upgrade to the latest Python version, or at " + + "least to Python 3.10, before then, and " + + f"then update {_package_label}.", + FutureWarning, + ) + + def parse_version_to_tuple(version_string: str): + """Safely converts a semantic version string to a comparable tuple of integers. + Example: "4.25.8" -> (4, 25, 8) + Ignores non-numeric parts and handles common version formats. + Args: + version_string: Version string in the format "x.y.z" or "x.y.z" + Returns: + Tuple of integers for the parsed version string. + """ + parts = [] + for part in version_string.split("."): + try: + parts.append(int(part)) + except ValueError: + # If it's a non-numeric part (e.g., '1.0.0b1' -> 'b1'), stop here. + # This is a simplification compared to 'packaging.parse_version', but sufficient + # for comparing strictly numeric semantic versions. + break + return tuple(parts) + + def _get_version(dependency_name): + try: + version_string: str = metadata.version(dependency_name) + parsed_version = parse_version_to_tuple(version_string) + return (parsed_version, version_string) + except Exception: + # Catch exceptions from metadata.version() (e.g., PackageNotFoundError) + # or errors during parse_version_to_tuple + return (None, "--") + + _dependency_package = "google.protobuf" + _next_supported_version = "4.25.8" + _next_supported_version_tuple = (4, 25, 8) + _recommendation = " (we recommend 6.x)" + (_version_used, _version_used_string) = _get_version(_dependency_package) + if _version_used and _version_used < _next_supported_version_tuple: + warnings.warn( + f"Package {_package_label} depends on " + + f"{_dependency_package}, currently installed at version " + + f"{_version_used_string}. Future updates to " + + f"{_package_label} will require {_dependency_package} at " + + f"version {_next_supported_version} or higher{_recommendation}." + + " Please ensure " + + "that either (a) your Python environment doesn't pin the " + + f"version of {_dependency_package}, so that updates to " + + f"{_package_label} can require the higher version, or " + + "(b) you manually update your Python environment to use at " + + f"least version {_next_supported_version} of " + + f"{_dependency_package}.", + FutureWarning, + ) + except Exception: + warnings.warn( + "Could not determine the version of Python " + + "currently being used. 
To continue receiving " + + f"updates for {_package_label}, ensure you are " + + "using a supported version of Python; see " + + "https://devguide.python.org/versions/" + ) + __all__ = ( "BundleElement", "BundleMetadata", diff --git a/google/cloud/firestore_bundle/bundle.py b/google/cloud/firestore_bundle/bundle.py index 0f9aaed976..e985a1e065 100644 --- a/google/cloud/firestore_bundle/bundle.py +++ b/google/cloud/firestore_bundle/bundle.py @@ -344,9 +344,10 @@ def build(self) -> str: BundleElement(document_metadata=bundled_document.metadata) ) document_count += 1 + bundle_pb = bundled_document.snapshot._to_protobuf() buffer += self._compile_bundle_element( BundleElement( - document=bundled_document.snapshot._to_protobuf()._pb, + document=bundle_pb._pb if bundle_pb else None, ) ) diff --git a/google/cloud/firestore_bundle/gapic_version.py b/google/cloud/firestore_bundle/gapic_version.py index e546bae053..ced4e0faf0 100644 --- a/google/cloud/firestore_bundle/gapic_version.py +++ b/google/cloud/firestore_bundle/gapic_version.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "2.21.0" # {x-release-please-version} +__version__ = "2.22.0" # {x-release-please-version} diff --git a/google/cloud/firestore_v1/aggregation.py b/google/cloud/firestore_v1/aggregation.py index ec0fbc1894..69c4dc6bd7 100644 --- a/google/cloud/firestore_v1/aggregation.py +++ b/google/cloud/firestore_v1/aggregation.py @@ -39,6 +39,8 @@ from google.cloud.firestore_v1.query_profile import ExplainMetrics from google.cloud.firestore_v1.query_profile import ExplainOptions + import datetime + class AggregationQuery(BaseAggregationQuery): """Represents an aggregation query to the Firestore API.""" @@ -56,6 +58,7 @@ def get( timeout: float | None = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> QueryResultsList[AggregationResult]: """Runs the aggregation query. @@ -64,8 +67,7 @@ def get( messages. Args: - transaction - (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. read-after-write is not @@ -74,10 +76,13 @@ def get( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. - explain_options - (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): + explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC.
Returns: QueryResultsList[AggregationResult]: The aggregation query results. @@ -90,6 +95,7 @@ def get( retry=retry, timeout=timeout, explain_options=explain_options, + read_time=read_time, ) result_list = list(result) @@ -100,13 +106,16 @@ def get( return QueryResultsList(result_list, explain_options, explain_metrics) - def _get_stream_iterator(self, transaction, retry, timeout, explain_options=None): + def _get_stream_iterator( + self, transaction, retry, timeout, explain_options=None, read_time=None + ): """Helper method for :meth:`stream`.""" request, kwargs = self._prep_stream( transaction, retry, timeout, explain_options, + read_time, ) return self._client._firestore_api.run_aggregation_query( @@ -132,6 +141,7 @@ def _make_stream( retry: Union[retries.Retry, None, object] = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> Generator[List[AggregationResult], Any, Optional[ExplainMetrics]]: """Internal method for stream(). Runs the aggregation query. @@ -143,18 +153,20 @@ def _make_stream( this method cannot be used (i.e. read-after-write is not allowed). Args: - transaction - (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. retry (Optional[google.api_core.retry.Retry]): Designation of what errors, if any, should be retried. Defaults to a system-specified policy. timeout (Optional[float]): The timeout for this request. Defaults to a system-specified value. - explain_options - (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): + explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Yields: List[AggregationResult]: @@ -172,6 +184,7 @@ def _make_stream( retry, timeout, explain_options, + read_time, ) while True: try: @@ -182,6 +195,8 @@ def _make_stream( transaction, retry, timeout, + explain_options, + read_time, ) continue else: @@ -206,6 +221,7 @@ def stream( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> StreamGenerator[List[AggregationResult]]: """Runs the aggregation query. @@ -217,18 +233,20 @@ def stream( this method cannot be used (i.e. read-after-write is not allowed). Args: - transaction - (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. retry (Optional[google.api_core.retry.Retry]): Designation of what errors, if any, should be retried. Defaults to a system-specified policy. timeout (Optinal[float]): The timeout for this request. Defaults to a system-specified value. 
- explain_options - (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): + explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Returns: `StreamGenerator[List[AggregationResult]]`: @@ -239,5 +257,6 @@ def stream( retry=retry, timeout=timeout, explain_options=explain_options, + read_time=read_time, ) return StreamGenerator(inner_generator, explain_options) diff --git a/google/cloud/firestore_v1/async_aggregation.py b/google/cloud/firestore_v1/async_aggregation.py index 3f3a1b9f43..5825a06d81 100644 --- a/google/cloud/firestore_v1/async_aggregation.py +++ b/google/cloud/firestore_v1/async_aggregation.py @@ -37,6 +37,7 @@ from google.cloud.firestore_v1.base_aggregation import AggregationResult from google.cloud.firestore_v1.query_profile import ExplainMetrics, ExplainOptions import google.cloud.firestore_v1.types.query_profile as query_profile_pb + import datetime class AsyncAggregationQuery(BaseAggregationQuery): @@ -55,14 +56,14 @@ async def get( timeout: float | None = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> QueryResultsList[List[AggregationResult]]: """Runs the aggregation query. This sends a ``RunAggregationQuery`` RPC and returns a list of aggregation results in the stream of ``RunAggregationQueryResponse`` messages. Args: - transaction - (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. read-after-write is not @@ -71,10 +72,13 @@ async def get( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. - explain_options - (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): + explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Returns: QueryResultsList[List[AggregationResult]]: The aggregation query results. 
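The keyword-only `read_time` threads from `get`/`stream` through `_prep_stream`, so aggregations can run against a historical snapshot of the data. A short sketch against a hypothetical `users` collection:

.. code-block:: python

    import datetime

    from google.cloud import firestore

    client = firestore.Client()
    count_query = client.collection("users").count(alias="total")

    # Any timestamp within the past hour is accepted; naive datetimes
    # are treated as UTC.
    ten_minutes_ago = datetime.datetime.now(
        datetime.timezone.utc
    ) - datetime.timedelta(minutes=10)

    results = count_query.get(read_time=ten_minutes_ago)
    print(results[0][0].alias, results[0][0].value)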
@@ -87,6 +91,7 @@ async def get( retry=retry, timeout=timeout, explain_options=explain_options, + read_time=read_time, ) try: result = [aggregation async for aggregation in stream_result] @@ -106,6 +111,7 @@ async def _make_stream( retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> AsyncGenerator[List[AggregationResult] | query_profile_pb.ExplainMetrics, Any]: """Internal method for stream(). Runs the aggregation query. @@ -126,10 +132,13 @@ async def _make_stream( system-specified policy. timeout (Optional[float]): The timeout for this request. Defaults to a system-specified value. - explain_options - (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): + explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Yields: List[AggregationResult] | query_profile_pb.ExplainMetrics: @@ -143,6 +152,7 @@ async def _make_stream( retry, timeout, explain_options, + read_time, ) response_iterator = await self._client._firestore_api.run_aggregation_query( @@ -167,6 +177,7 @@ def stream( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> AsyncStreamGenerator[List[AggregationResult]]: """Runs the aggregation query. @@ -186,10 +197,13 @@ def stream( system-specified policy. timeout (Optional[float]): The timeout for this request. Defaults to a system-specified value. - explain_options - (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): + explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. 
Returns: `AsyncStreamGenerator[List[AggregationResult]]`: @@ -201,5 +215,6 @@ def stream( retry=retry, timeout=timeout, explain_options=explain_options, + read_time=read_time, ) return AsyncStreamGenerator(inner_generator, explain_options) diff --git a/google/cloud/firestore_v1/async_batch.py b/google/cloud/firestore_v1/async_batch.py index 689753fe9f..f74ccacea9 100644 --- a/google/cloud/firestore_v1/async_batch.py +++ b/google/cloud/firestore_v1/async_batch.py @@ -19,6 +19,7 @@ from google.api_core import retry_async as retries from google.cloud.firestore_v1.base_batch import BaseWriteBatch +from google.cloud.firestore_v1.types.write import WriteResult class AsyncWriteBatch(BaseWriteBatch): @@ -40,7 +41,7 @@ async def commit( self, retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, - ) -> list: + ) -> list[WriteResult]: """Commit the changes accumulated in this batch. Args: diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 275bcb9b61..fd016dfe7e 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -25,7 +25,15 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Iterable, + List, + Optional, + Union, +) from google.api_core import gapic_v1 from google.api_core import retry_async as retries @@ -40,6 +48,7 @@ from google.cloud.firestore_v1.async_transaction import AsyncTransaction from google.cloud.firestore_v1.base_client import _parse_batch_get # type: ignore from google.cloud.firestore_v1.base_client import _CLIENT_INFO, BaseClient, _path_helper +from google.cloud.firestore_v1.base_transaction import MAX_ATTEMPTS from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.services.firestore import ( async_client as firestore_client, @@ -48,8 +57,10 @@ grpc_asyncio as firestore_grpc_transport, ) -if TYPE_CHECKING: - from google.cloud.firestore_v1.bulk_writer import BulkWriter # pragma: NO COVER +if TYPE_CHECKING: # pragma: NO COVER + import datetime + + from google.cloud.firestore_v1.bulk_writer import BulkWriter class AsyncClient(BaseClient): @@ -227,6 +238,8 @@ async def get_all( transaction: AsyncTransaction | None = None, retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ) -> AsyncGenerator[DocumentSnapshot, Any]: """Retrieve a batch of documents. @@ -261,13 +274,17 @@ async def get_all( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Yields: .DocumentSnapshot: The next document snapshot that fulfills the query, or :data:`None` if the document does not exist. 
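A minimal sketch of the async surface, assuming two illustrative document paths:

import datetime
from google.cloud import firestore

async def stale_batch_read():
    client = firestore.AsyncClient()
    refs = [client.document("users/alice"), client.document("users/bob")]
    at = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(minutes=5)
    # Each snapshot reflects the document state at `at`, not at call time.
    async for snapshot in client.get_all(refs, read_time=at):
        print(snapshot.id, snapshot.exists)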
""" request, reference_map, kwargs = self._prep_get_all( - references, field_paths, transaction, retry, timeout + references, field_paths, transaction, retry, timeout, read_time ) response_iterator = await self._firestore_api.batch_get_documents( @@ -283,6 +300,8 @@ async def collections( self, retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ) -> AsyncGenerator[AsyncCollectionReference, Any]: """List top-level collections of the client's database. @@ -291,12 +310,16 @@ async def collections( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Returns: Sequence[:class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`]: iterator of subcollections of the current document. """ - request, kwargs = self._prep_collections(retry, timeout) + request, kwargs = self._prep_collections(retry, timeout, read_time) iterator = await self._firestore_api.list_collection_ids( request=request, metadata=self._rpc_metadata, @@ -396,7 +419,9 @@ def batch(self) -> AsyncWriteBatch: """ return AsyncWriteBatch(self) - def transaction(self, **kwargs) -> AsyncTransaction: + def transaction( + self, max_attempts: int = MAX_ATTEMPTS, read_only: bool = False + ) -> AsyncTransaction: """Get a transaction that uses this client. See :class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction` for @@ -412,4 +437,4 @@ def transaction(self, **kwargs) -> AsyncTransaction: :class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`: A transaction attached to this client. 
""" - return AsyncTransaction(self, **kwargs) + return AsyncTransaction(self, max_attempts=max_attempts, read_only=read_only) diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index 8c832b8f4c..561111163a 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -15,14 +15,13 @@ """Classes for representing collections for the Google Cloud Firestore API.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Tuple +from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Tuple, cast from google.api_core import gapic_v1 from google.api_core import retry_async as retries from google.cloud.firestore_v1 import ( async_aggregation, - async_document, async_query, async_vector_query, transaction, @@ -31,9 +30,10 @@ BaseCollectionReference, _item_to_document_ref, ) -from google.cloud.firestore_v1.document import DocumentReference if TYPE_CHECKING: # pragma: NO COVER + import datetime + from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.query_profile import ExplainOptions @@ -140,9 +140,7 @@ async def add( write_result = await document_ref.create(document_data, **kwargs) return write_result.update_time, document_ref - def document( - self, document_id: str | None = None - ) -> async_document.AsyncDocumentReference: + def document(self, document_id: str | None = None) -> AsyncDocumentReference: """Create a sub-document underneath the current collection. Args: @@ -155,14 +153,17 @@ def document( :class:`~google.cloud.firestore_v1.document.async_document.AsyncDocumentReference`: The child document. """ - return super(AsyncCollectionReference, self).document(document_id) + doc = super(AsyncCollectionReference, self).document(document_id) + return cast("AsyncDocumentReference", doc) async def list_documents( self, page_size: int | None = None, retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, - ) -> AsyncGenerator[DocumentReference, None]: + *, + read_time: datetime.datetime | None = None, + ) -> AsyncGenerator[AsyncDocumentReference, None]: """List all subdocuments of the current collection. Args: @@ -173,6 +174,10 @@ async def list_documents( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. 
Returns: Sequence[:class:`~google.cloud.firestore_v1.collection.DocumentReference`]: @@ -180,7 +185,9 @@ async def list_documents( collection does not exist at the time of `snapshot`, the iterator will be empty """ - request, kwargs = self._prep_list_documents(page_size, retry, timeout) + request, kwargs = self._prep_list_documents( + page_size, retry, timeout, read_time + ) iterator = await self._client._firestore_api.list_documents( request=request, @@ -197,6 +204,7 @@ async def get( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> QueryResultsList[DocumentSnapshot]: """Read the documents in this collection. @@ -216,6 +224,10 @@ async def get( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. read-after-write is not allowed). @@ -227,6 +239,8 @@ async def get( query, kwargs = self._prep_get_or_stream(retry, timeout) if explain_options is not None: kwargs["explain_options"] = explain_options + if read_time is not None: + kwargs["read_time"] = read_time return await query.get(transaction=transaction, **kwargs) @@ -237,6 +251,7 @@ def stream( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> AsyncStreamGenerator[DocumentSnapshot]: """Read the documents in this collection. @@ -268,6 +283,10 @@ def stream( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Returns: `AsyncStreamGenerator[DocumentSnapshot]`: A generator of the query @@ -276,5 +295,7 @@ def stream( query, kwargs = self._prep_get_or_stream(retry, timeout) if explain_options: kwargs["explain_options"] = explain_options + if read_time is not None: + kwargs["read_time"] = read_time return query.stream(transaction=transaction, **kwargs) diff --git a/google/cloud/firestore_v1/async_document.py b/google/cloud/firestore_v1/async_document.py index 78c71b33fc..c3ebfbe0cc 100644 --- a/google/cloud/firestore_v1/async_document.py +++ b/google/cloud/firestore_v1/async_document.py @@ -329,6 +329,8 @@ async def get( transaction=None, retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ) -> DocumentSnapshot: """Retrieve a snapshot of the current document. @@ -351,6 +353,10 @@ async def get( should be retried. Defaults to a system-specified policy. 
timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Returns: :class:`~google.cloud.firestore_v1.base_document.DocumentSnapshot`: @@ -362,7 +368,9 @@ async def get( """ from google.cloud.firestore_v1.base_client import _parse_batch_get - request, kwargs = self._prep_batch_get(field_paths, transaction, retry, timeout) + request, kwargs = self._prep_batch_get( + field_paths, transaction, retry, timeout, read_time + ) response_iter = await self._client._firestore_api.batch_get_documents( request=request, @@ -397,6 +405,8 @@ async def collections( page_size: int | None = None, retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ) -> AsyncGenerator: """List subcollections of the current document. @@ -408,6 +418,10 @@ async def collections( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Returns: Sequence[:class:`~google.cloud.firestore_v1.async_collection.AsyncCollectionReference`]: @@ -415,7 +429,7 @@ async def collections( document does not exist at the time of `snapshot`, the iterator will be empty """ - request, kwargs = self._prep_collections(page_size, retry, timeout) + request, kwargs = self._prep_collections(page_size, retry, timeout, read_time) iterator = await self._client._firestore_api.list_collection_ids( request=request, diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index d4fd45fa46..de6c3c1cf8 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -20,7 +20,16 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, Any, AsyncGenerator, List, Optional, Type +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + List, + Optional, + Type, + Union, + Sequence, +) from google.api_core import gapic_v1 from google.api_core import retry_async as retries @@ -40,6 +49,8 @@ from google.cloud.firestore_v1.query_results import QueryResultsList if TYPE_CHECKING: # pragma: NO COVER + import datetime + # Types needed only for Type Hints from google.cloud.firestore_v1.async_transaction import AsyncTransaction from google.cloud.firestore_v1.base_document import DocumentSnapshot @@ -182,6 +193,7 @@ async def get( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> QueryResultsList[DocumentSnapshot]: """Read the documents in the collection that match this query. @@ -201,6 +213,10 @@ async def get( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. 
When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. For the most accurate results, use UTC timezone. If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. read-after-write is not @@ -230,6 +246,7 @@ async def get( retry=retry, timeout=timeout, explain_options=explain_options, + read_time=read_time, ) try: result_list = [d async for d in result] @@ -248,7 +265,7 @@ async def get( def find_nearest( self, vector_field: str, - query_vector: Vector, + query_vector: Union[Vector, Sequence[float]], limit: int, distance_measure: DistanceMeasure, *, @@ -261,7 +278,7 @@ def find_nearest( Args: vector_field (str): An indexed vector field to search upon. Only documents which contain vectors whose dimensionality match the query_vector can be returned. - query_vector (Vector): The query vector that we are searching on. Must be a vector of no more + query_vector (Vector | Sequence[float]): The query vector that we are searching on. Must be a vector of no more than 2048 dimensions. limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000. distance_measure (:class:`DistanceMeasure`): The Distance Measure to use. @@ -336,6 +353,7 @@ async def _make_stream( retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> AsyncGenerator[DocumentSnapshot | query_profile_pb.ExplainMetrics, Any]: """Internal method for stream(). Read the documents in the collection that match this query. @@ -368,6 +386,10 @@ async def _make_stream( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. For the most accurate results, use UTC timezone. Yields: [:class:`~google.cloud.firestore_v1.base_document.DocumentSnapshot` \ @@ -381,6 +403,7 @@ async def _make_stream( retry, timeout, explain_options, + read_time, ) response_iterator = await self._client._firestore_api.run_query( @@ -412,6 +435,7 @@ def stream( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> AsyncStreamGenerator[DocumentSnapshot]: """Read the documents in the collection that match this query. @@ -443,6 +467,10 @@ def stream( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. 
For the most accurate results, use UTC timezone. Returns: `AsyncStreamGenerator[DocumentSnapshot]`: @@ -453,6 +481,7 @@ def stream( retry=retry, timeout=timeout, explain_options=explain_options, + read_time=read_time, ) return AsyncStreamGenerator(inner_generator, explain_options) @@ -514,6 +543,8 @@ async def get_partitions( partition_count, retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, + *, + read_time: Optional[datetime.datetime] = None, ) -> AsyncGenerator[QueryPartition, None]: """Partition a query for parallelization. @@ -529,8 +560,15 @@ async def get_partitions( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. For the most accurate results, use UTC timezone. """ - request, kwargs = self._prep_get_partitions(partition_count, retry, timeout) + request, kwargs = self._prep_get_partitions( + partition_count, retry, timeout, read_time + ) + pager = await self._client._firestore_api.partition_query( request=request, metadata=self._client._rpc_metadata, diff --git a/google/cloud/firestore_v1/async_transaction.py b/google/cloud/firestore_v1/async_transaction.py index 038710929b..0dfa82e011 100644 --- a/google/cloud/firestore_v1/async_transaction.py +++ b/google/cloud/firestore_v1/async_transaction.py @@ -15,7 +15,16 @@ """Helpers for applying Google Cloud Firestore changes in a transaction.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Coroutine, Optional +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Awaitable, + Callable, + Generic, + Optional, +) +from typing_extensions import Concatenate, ParamSpec, TypeVar from google.api_core import exceptions, gapic_v1 from google.api_core import retry_async as retries @@ -36,11 +45,17 @@ # Types needed only for Type Hints if TYPE_CHECKING: # pragma: NO COVER + import datetime + from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator from google.cloud.firestore_v1.base_document import DocumentSnapshot from google.cloud.firestore_v1.query_profile import ExplainOptions +T = TypeVar("T") +P = ParamSpec("P") + + class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction): """Accumulate read-and-write operations to be sent in a transaction. @@ -154,6 +169,8 @@ async def get_all( references: list, retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ) -> AsyncGenerator[DocumentSnapshot, Any]: """Retrieves multiple documents from Firestore. @@ -164,12 +181,18 @@ async def get_all( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. 
Yields: .DocumentSnapshot: The next document snapshot that fulfills the query, or :data:`None` if the document does not exist. """ kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + if read_time is not None: + kwargs["read_time"] = read_time return await self._client.get_all(references, transaction=self, **kwargs) async def get( @@ -179,6 +202,7 @@ async def get( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> AsyncGenerator[DocumentSnapshot, Any] | AsyncStreamGenerator[DocumentSnapshot]: """ Retrieve a document or a query result from the database. @@ -195,6 +219,10 @@ async def get( Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. Can only be used when running a query, not a document reference. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Yields: DocumentSnapshot: The next document snapshot that fulfills the query, @@ -206,6 +234,8 @@ reference. """ kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + if read_time is not None: + kwargs["read_time"] = read_time if isinstance(ref_or_query, AsyncDocumentReference): if explain_options is not None: raise ValueError( @@ -225,7 +255,7 @@ ) -class _AsyncTransactional(_BaseTransactional): +class _AsyncTransactional(_BaseTransactional, Generic[T, P]): """Provide a callable object to use as a transactional decorator. This is surfaced via @@ -236,12 +266,14 @@ class _AsyncTransactional(_BaseTransactional): A coroutine that should be run (and retried) in a transaction. """ - def __init__(self, to_wrap) -> None: + def __init__( + self, to_wrap: Callable[Concatenate[AsyncTransaction, P], Awaitable[T]] + ) -> None: super(_AsyncTransactional, self).__init__(to_wrap) async def _pre_commit( - self, transaction: AsyncTransaction, *args, **kwargs - ) -> Coroutine: + self, transaction: AsyncTransaction, *args: P.args, **kwargs: P.kwargs + ) -> T: """Begin transaction and call the wrapped coroutine. Args: @@ -254,7 +286,7 @@ along to the wrapped coroutine. Returns: - Any: result of the wrapped coroutine. + T: result of the wrapped coroutine. Raises: Exception: Any failure caused by ``to_wrap``. @@ -269,12 +301,14 @@ self.retry_id = self.current_id return await self.to_wrap(transaction, *args, **kwargs) - async def __call__(self, transaction, *args, **kwargs): + async def __call__( + self, transaction: AsyncTransaction, *args: P.args, **kwargs: P.kwargs + ) -> T: """Execute the wrapped callable within a transaction. Args: transaction - (:class:`~google.cloud.firestore_v1.transaction.Transaction`): + (:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`): A transaction to execute the callable within. args (Tuple[Any, ...]): The extra positional arguments to pass along to the wrapped callable. @@ -282,7 +316,7 @@ async def __call__(self, transaction, *args, **kwargs): along to the wrapped callable. Returns: - Any: The result of the wrapped callable. + T: The result of the wrapped callable.
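A sketch of what the Concatenate/ParamSpec plumbing buys callers (document path and field names are illustrative):

from google.cloud import firestore

client = firestore.AsyncClient()

@firestore.async_transactional
async def add_credits(transaction, amount: int) -> int:
    ref = client.document("users/alice")
    snapshot = await ref.get(transaction=transaction)
    balance = (snapshot.to_dict() or {}).get("credits", 0) + amount
    transaction.update(ref, {"credits": balance})
    return balance

# With the old Callable[[AsyncTransaction], Any] hint the extra parameter and
# the int result were erased; now a checker sees both:
# result = await add_credits(client.transaction(), 5)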
Raises: ValueError: If the transaction does not succeed in @@ -296,7 +330,7 @@ async def __call__(self, transaction, *args, **kwargs): try: for attempt in range(transaction._max_attempts): - result = await self._pre_commit(transaction, *args, **kwargs) + result: T = await self._pre_commit(transaction, *args, **kwargs) try: await transaction._commit() return result @@ -321,17 +355,17 @@ async def __call__(self, transaction, *args, **kwargs): def async_transactional( - to_wrap: Callable[[AsyncTransaction], Any] -) -> _AsyncTransactional: + to_wrap: Callable[Concatenate[AsyncTransaction, P], Awaitable[T]] +) -> Callable[Concatenate[AsyncTransaction, P], Awaitable[T]]: """Decorate a callable so that it runs in a transaction. Args: to_wrap - (Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Any]): + (Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Awaitable[Any]]): A callable that should be run (and retried) in a transaction. Returns: - Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Any]: + Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Awaitable[Any]]: the wrapped callable. """ return _AsyncTransactional(to_wrap) diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py index 34a3baad81..c5e6a7b7f6 100644 --- a/google/cloud/firestore_v1/base_aggregation.py +++ b/google/cloud/firestore_v1/base_aggregation.py @@ -21,6 +21,7 @@ from __future__ import annotations import abc + from abc import ABC from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union @@ -43,6 +44,8 @@ StreamGenerator, ) + import datetime + class AggregationResult(object): """ @@ -80,23 +83,26 @@ def __init__(self, alias: str | None = None): def _to_protobuf(self): """Convert this instance to the protobuf representation""" aggregation_pb = StructuredAggregationQuery.Aggregation() - aggregation_pb.alias = self.alias + if self.alias: + aggregation_pb.alias = self.alias aggregation_pb.count = StructuredAggregationQuery.Aggregation.Count() return aggregation_pb class SumAggregation(BaseAggregation): def __init__(self, field_ref: str | FieldPath, alias: str | None = None): - if isinstance(field_ref, FieldPath): - # convert field path to string - field_ref = field_ref.to_api_repr() - self.field_ref = field_ref + # convert field path to string if needed + field_str = ( + field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref + ) + self.field_ref: str = field_str super(SumAggregation, self).__init__(alias=alias) def _to_protobuf(self): """Convert this instance to the protobuf representation""" aggregation_pb = StructuredAggregationQuery.Aggregation() - aggregation_pb.alias = self.alias + if self.alias: + aggregation_pb.alias = self.alias aggregation_pb.sum = StructuredAggregationQuery.Aggregation.Sum() aggregation_pb.sum.field.field_path = self.field_ref return aggregation_pb @@ -104,16 +110,18 @@ def _to_protobuf(self): class AvgAggregation(BaseAggregation): def __init__(self, field_ref: str | FieldPath, alias: str | None = None): - if isinstance(field_ref, FieldPath): - # convert field path to string - field_ref = field_ref.to_api_repr() - self.field_ref = field_ref + # convert field path to string if needed + field_str = ( + field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref + ) + self.field_ref: str = field_str super(AvgAggregation, self).__init__(alias=alias) def _to_protobuf(self): """Convert 
this instance to the protobuf representation""" aggregation_pb = StructuredAggregationQuery.Aggregation() - aggregation_pb.alias = self.alias + if self.alias: + aggregation_pb.alias = self.alias aggregation_pb.avg = StructuredAggregationQuery.Aggregation.Avg() aggregation_pb.avg.field.field_path = self.field_ref return aggregation_pb @@ -205,6 +213,7 @@ def _prep_stream( retry: Union[retries.Retry, retries.AsyncRetry, None, object] = None, timeout: float | None = None, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> Tuple[dict, dict]: parent_path, expected_prefix = self._collection_ref._parent_info() request = { @@ -214,6 +223,8 @@ def _prep_stream( } if explain_options: request["explain_options"] = explain_options._to_dict() + if read_time is not None: + request["read_time"] = read_time kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) return request, kwargs @@ -228,6 +239,7 @@ def get( timeout: float | None = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> ( QueryResultsList[AggregationResult] | Coroutine[Any, Any, List[List[AggregationResult]]] @@ -253,6 +265,10 @@ def get( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Returns: (QueryResultsList[List[AggregationResult]] | Coroutine[Any, Any, List[List[AggregationResult]]]): @@ -270,6 +286,7 @@ def stream( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> ( StreamGenerator[List[AggregationResult]] | AsyncStreamGenerator[List[AggregationResult]] @@ -291,6 +308,10 @@ def stream( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. 
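The alias guard above matters because an aggregation may legitimately be built without an alias (the server then assigns one), and the protobuf string field cannot be assigned None. A sketch of the unaliased path, with illustrative names:

from google.cloud import firestore
from google.cloud.firestore_v1.base_query import FieldFilter

client = firestore.Client()
query = client.collection("orders").where(filter=FieldFilter("state", "==", "OPEN"))
# No alias passed: _to_protobuf() now leaves the field unset instead of
# assigning None; the response carries server-assigned aliases.
aggregation = query.count().sum("total").avg("total")
for batch in aggregation.stream():
    for result in batch:
        print(result.alias, result.value)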
Returns: StreamGenerator[List[AggregationResult]] | AsyncStreamGenerator[List[AggregationResult]]: diff --git a/google/cloud/firestore_v1/base_batch.py b/google/cloud/firestore_v1/base_batch.py index b0d50f1f47..851c7849ff 100644 --- a/google/cloud/firestore_v1/base_batch.py +++ b/google/cloud/firestore_v1/base_batch.py @@ -15,7 +15,7 @@ """Helpers for batch requests to the Google Cloud Firestore API.""" from __future__ import annotations import abc -from typing import Dict, Union +from typing import Any, Dict, Union # Types needed only for Type Hints from google.api_core import retry as retries @@ -67,7 +67,9 @@ def commit(self): write depend on the implementing class.""" raise NotImplementedError() - def create(self, reference: BaseDocumentReference, document_data: dict) -> None: + def create( + self, reference: BaseDocumentReference, document_data: dict[str, Any] + ) -> None: """Add a "change" to this batch to create a document. If the document given by ``reference`` already exists, then this @@ -120,7 +122,7 @@ def set( def update( self, reference: BaseDocumentReference, - field_updates: dict, + field_updates: dict[str, Any], option: _helpers.WriteOption | None = None, ) -> None: """Add a "change" to update a document. diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index 9b1c0bccd4..f3eeeae496 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -25,6 +25,7 @@ """ from __future__ import annotations +import datetime import os from typing import ( Any, @@ -56,7 +57,7 @@ DocumentSnapshot, ) from google.cloud.firestore_v1.base_query import BaseQuery -from google.cloud.firestore_v1.base_transaction import BaseTransaction +from google.cloud.firestore_v1.base_transaction import MAX_ATTEMPTS, BaseTransaction from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions from google.cloud.firestore_v1.field_path import render_field_path from google.cloud.firestore_v1.services.firestore import client as firestore_client @@ -437,6 +438,7 @@ def _prep_get_all( transaction: BaseTransaction | None = None, retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: float | None = None, + read_time: datetime.datetime | None = None, ) -> Tuple[dict, dict, dict]: """Shared setup for async/sync :meth:`get_all`.""" document_paths, reference_map = _reference_info(references) @@ -447,6 +449,8 @@ def _prep_get_all( "mask": mask, "transaction": _helpers.get_transaction_id(transaction), } + if read_time is not None: + request["read_time"] = read_time kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) return request, reference_map, kwargs @@ -458,6 +462,8 @@ def get_all( transaction=None, retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ) -> Union[ AsyncGenerator[DocumentSnapshot, Any], Generator[DocumentSnapshot, Any, Any] ]: @@ -467,9 +473,14 @@ def _prep_collections( self, retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: float | None = None, + read_time: datetime.datetime | None = None, ) -> Tuple[dict, dict]: """Shared setup for async/sync :meth:`collections`.""" - request = {"parent": "{}/documents".format(self._database_string)} + request: dict[str, Any] = { + "parent": "{}/documents".format(self._database_string), + } + if read_time is not None: + request["read_time"] = read_time kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) 
return request, kwargs @@ -478,13 +489,17 @@ def collections( self, retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ): raise NotImplementedError def batch(self) -> BaseWriteBatch: raise NotImplementedError - def transaction(self, **kwargs) -> BaseTransaction: + def transaction( + self, max_attempts: int = MAX_ATTEMPTS, read_only: bool = False + ) -> BaseTransaction: raise NotImplementedError diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 0e5ae6ed1e..be817c5fe9 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -16,6 +16,7 @@ from __future__ import annotations import random + from typing import ( TYPE_CHECKING, Any, @@ -34,6 +35,7 @@ from google.api_core import retry as retries from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.base_document import BaseDocumentReference from google.cloud.firestore_v1.base_query import QueryType if TYPE_CHECKING: # pragma: NO COVER @@ -44,6 +46,7 @@ BaseVectorQuery, DistanceMeasure, ) + from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.query_profile import ExplainOptions @@ -53,6 +56,8 @@ from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1.vector_query import VectorQuery + import datetime + _AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" @@ -129,7 +134,7 @@ def _aggregation_query(self) -> BaseAggregationQuery: def _vector_query(self) -> BaseVectorQuery: raise NotImplementedError - def document(self, document_id: Optional[str] = None) -> DocumentReference: + def document(self, document_id: Optional[str] = None) -> BaseDocumentReference: """Create a sub-document underneath the current collection. Args: @@ -139,7 +144,7 @@ def document(self, document_id: Optional[str] = None) -> DocumentReference: uppercase and lowercase letters. Returns: - :class:`~google.cloud.firestore_v1.document.DocumentReference`: + :class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`: The child document.
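Usage sketch for the auto-ID path (collection name is illustrative); the concrete subclasses narrow the BaseDocumentReference return type back to their own reference class:

from google.cloud import firestore

client = firestore.Client()
# No document_id: a 20-character ID is drawn from _AUTO_ID_CHARS
# (digits plus upper- and lowercase ASCII letters).
ref = client.collection("users").document()
print(ref.id)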
""" if document_id is None: @@ -179,7 +184,7 @@ def _prep_add( document_id: Optional[str] = None, retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: Optional[float] = None, - ) -> Tuple[DocumentReference, dict]: + ): """Shared setup for async / sync :method:`add`""" if document_id is None: document_id = _auto_id() @@ -203,6 +208,7 @@ def _prep_list_documents( page_size: Optional[int] = None, retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: Optional[float] = None, + read_time: Optional[datetime.datetime] = None, ) -> Tuple[dict, dict]: """Shared setup for async / sync :method:`list_documents`""" parent, _ = self._parent_info() @@ -216,6 +222,8 @@ def _prep_list_documents( # to include no fields "mask": {"field_paths": None}, } + if read_time is not None: + request["read_time"] = read_time kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) return request, kwargs @@ -225,8 +233,11 @@ def list_documents( page_size: Optional[int] = None, retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: Optional[float] = None, + *, + read_time: Optional[datetime.datetime] = None, ) -> Union[ - Generator[DocumentReference, Any, Any], AsyncGenerator[DocumentReference, Any] + Generator[DocumentReference, Any, Any], + AsyncGenerator[AsyncDocumentReference, Any], ]: raise NotImplementedError @@ -498,6 +509,7 @@ def get( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> ( QueryResultsList[DocumentSnapshot] | Coroutine[Any, Any, QueryResultsList[DocumentSnapshot]] @@ -511,6 +523,7 @@ def stream( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> StreamGenerator[DocumentSnapshot] | AsyncIterator[DocumentSnapshot]: raise NotImplementedError @@ -602,13 +615,17 @@ def _auto_id() -> str: return "".join(random.choice(_AUTO_ID_CHARS) for _ in range(20)) -def _item_to_document_ref(collection_reference, item) -> DocumentReference: +def _item_to_document_ref(collection_reference, item): """Convert Document resource to document ref. Args: collection_reference (google.api_core.page_iterator.GRPCIterator): iterator response item (dict): document resource + + Returns: + :class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`: + The child document """ document_id = item.name.split(_helpers.DOCUMENT_PATH_DELIMITER)[-1] return collection_reference.document(document_id) diff --git a/google/cloud/firestore_v1/base_document.py b/google/cloud/firestore_v1/base_document.py index b16b8abace..fe6113bfc3 100644 --- a/google/cloud/firestore_v1/base_document.py +++ b/google/cloud/firestore_v1/base_document.py @@ -16,6 +16,7 @@ from __future__ import annotations import copy + from typing import ( TYPE_CHECKING, Any, @@ -37,6 +38,8 @@ if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.types import Document, firestore, write + import datetime + class BaseDocumentReference(object): """A reference to a document in a Firestore database. 
@@ -290,6 +293,7 @@ def _prep_batch_get( transaction=None, retry: retries.Retry | retries.AsyncRetry | None | object = None, timeout: float | None = None, + read_time: datetime.datetime | None = None, ) -> Tuple[dict, dict]: """Shared setup for async/sync :meth:`get`.""" if isinstance(field_paths, str): @@ -306,6 +310,8 @@ def _prep_batch_get( "mask": mask, "transaction": _helpers.get_transaction_id(transaction), } + if read_time is not None: + request["read_time"] = read_time kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) return request, kwargs @@ -316,6 +322,8 @@ def get( transaction=None, retry: retries.Retry | retries.AsyncRetry | None | object = None, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ) -> "DocumentSnapshot" | Awaitable["DocumentSnapshot"]: raise NotImplementedError @@ -324,9 +332,15 @@ def _prep_collections( page_size: int | None = None, retry: retries.Retry | retries.AsyncRetry | None | object = None, timeout: float | None = None, + read_time: datetime.datetime | None = None, ) -> Tuple[dict, dict]: """Shared setup for async/sync :meth:`collections`.""" - request = {"parent": self._document_path, "page_size": page_size} + request = { + "parent": self._document_path, + "page_size": page_size, + } + if read_time is not None: + request["read_time"] = read_time kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) return request, kwargs @@ -336,6 +350,8 @@ def collections( page_size: int | None = None, retry: retries.Retry | retries.AsyncRetry | None | object = None, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ): raise NotImplementedError @@ -402,7 +418,7 @@ def _client(self): return self._reference._client @property - def exists(self): + def exists(self) -> bool: """Existence flag. Indicates if the document existed at the time this snapshot @@ -414,7 +430,7 @@ def exists(self): return self._exists @property - def id(self): + def id(self) -> str: """The document identifier (within its collection). Returns: @@ -423,7 +439,7 @@ def id(self): return self._reference.id @property - def reference(self): + def reference(self) -> BaseDocumentReference: """Document reference corresponding to document that owns this data. 
Returns: diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 2fb9bd895d..2de95b79ad 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -24,6 +24,7 @@ import copy import math import warnings + from typing import ( TYPE_CHECKING, Any, @@ -66,6 +67,8 @@ from google.cloud.firestore_v1.query_results import QueryResultsList from google.cloud.firestore_v1.stream_generator import StreamGenerator + import datetime + _BAD_DIR_STRING: str _BAD_OP_NAN: str @@ -179,7 +182,7 @@ def _validate_opation(op_string, value): class FieldFilter(BaseFilter): """Class representation of a Field Filter.""" - def __init__(self, field_path, op_string, value=None): + def __init__(self, field_path: str, op_string: str, value: Any | None = None): self.field_path = field_path self.value = value self.op_string = _validate_opation(op_string, value) @@ -205,8 +208,8 @@ class BaseCompositeFilter(BaseFilter): def __init__( self, - operator=StructuredQuery.CompositeFilter.Operator.OPERATOR_UNSPECIFIED, - filters=None, + operator: int = StructuredQuery.CompositeFilter.Operator.OPERATOR_UNSPECIFIED, + filters: list[BaseFilter] | None = None, ): self.operator = operator if filters is None: @@ -238,7 +241,7 @@ def _to_pb(self): class Or(BaseCompositeFilter): """Class representation of an OR Filter.""" - def __init__(self, filters): + def __init__(self, filters: list[BaseFilter]): super().__init__( operator=StructuredQuery.CompositeFilter.Operator.OR, filters=filters ) @@ -247,7 +250,7 @@ def __init__(self, filters): class And(BaseCompositeFilter): """Class representation of an AND Filter.""" - def __init__(self, filters): + def __init__(self, filters: list[BaseFilter]): super().__init__( operator=StructuredQuery.CompositeFilter.Operator.AND, filters=filters ) @@ -926,7 +929,7 @@ def _normalize_cursor(self, cursor, orders) -> Tuple[List, bool] | None: if isinstance(document_fields, document.DocumentSnapshot): snapshot = document_fields - document_fields = snapshot.to_dict() + document_fields = copy.deepcopy(snapshot._data) document_fields["__name__"] = snapshot.reference if isinstance(document_fields, dict): @@ -1032,6 +1035,7 @@ def get( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> ( QueryResultsList[DocumentSnapshot] | Coroutine[Any, Any, QueryResultsList[DocumentSnapshot]] @@ -1044,6 +1048,7 @@ def _prep_stream( retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: Optional[float] = None, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> Tuple[dict, str, dict]: """Shared setup for async / sync :meth:`stream`""" if self._limit_to_last: @@ -1060,6 +1065,8 @@ def _prep_stream( } if explain_options is not None: request["explain_options"] = explain_options._to_dict() + if read_time is not None: + request["read_time"] = read_time kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) return request, expected_prefix, kwargs @@ -1071,6 +1078,7 @@ def stream( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> ( StreamGenerator[document.DocumentSnapshot] | AsyncStreamGenerator[DocumentSnapshot] @@ -1427,6 +1435,7 @@ def _prep_get_partitions( partition_count, retry: retries.Retry | object | None = None, timeout: float | None = None, + read_time: datetime.datetime | None = None, ) -> 
Tuple[dict, dict]: self._validate_partition_query() parent_path, expected_prefix = self._parent._parent_info() @@ -1443,6 +1452,8 @@ def _prep_get_partitions( "structured_query": query._to_protobuf(), "partition_count": partition_count, } + if read_time is not None: + request["read_time"] = read_time kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) return request, kwargs @@ -1452,6 +1463,8 @@ def get_partitions( partition_count, retry: Optional[retries.Retry] = None, timeout: Optional[float] = None, + *, + read_time: Optional[datetime.datetime] = None, ): raise NotImplementedError diff --git a/google/cloud/firestore_v1/base_transaction.py b/google/cloud/firestore_v1/base_transaction.py index 92e54c81c4..297c3f572e 100644 --- a/google/cloud/firestore_v1/base_transaction.py +++ b/google/cloud/firestore_v1/base_transaction.py @@ -37,6 +37,8 @@ from google.cloud.firestore_v1.stream_generator import StreamGenerator from google.cloud.firestore_v1.types import write as write_pb + import datetime + MAX_ATTEMPTS = 5 """int: Default number of transaction attempts (with retries).""" @@ -148,6 +150,8 @@ def get_all( references: list, retry: retries.Retry | retries.AsyncRetry | object | None = None, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ) -> ( Generator[DocumentSnapshot, Any, None] | Coroutine[Any, Any, AsyncGenerator[DocumentSnapshot, Any]] @@ -161,6 +165,7 @@ def get( timeout: float | None = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> ( StreamGenerator[DocumentSnapshot] | Generator[DocumentSnapshot, Any, None] diff --git a/google/cloud/firestore_v1/bulk_writer.py b/google/cloud/firestore_v1/bulk_writer.py index eff936300d..6747bc234b 100644 --- a/google/cloud/firestore_v1/bulk_writer.py +++ b/google/cloud/firestore_v1/bulk_writer.py @@ -110,7 +110,7 @@ def wrapper(self, *args, **kwargs): # For code parity, even `SendMode.serial` scenarios should return # a future here. Anything else would badly complicate calling code. 
result = fn(self, *args, **kwargs) - future = concurrent.futures.Future() + future: concurrent.futures.Future = concurrent.futures.Future() future.set_result(result) return future @@ -319,6 +319,7 @@ def __init__( self._total_batches_sent: int = 0 self._total_write_operations: int = 0 + self._executor: concurrent.futures.ThreadPoolExecutor self._ensure_executor() @staticmethod diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index 23c6b36ef2..54943aded4 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -39,6 +39,7 @@ # Types needed only for Type Hints from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1.base_transaction import MAX_ATTEMPTS from google.cloud.firestore_v1.batch import WriteBatch from google.cloud.firestore_v1.collection import CollectionReference from google.cloud.firestore_v1.document import DocumentReference @@ -50,8 +51,9 @@ ) from google.cloud.firestore_v1.transaction import Transaction -if TYPE_CHECKING: - from google.cloud.firestore_v1.bulk_writer import BulkWriter # pragma: NO COVER +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.bulk_writer import BulkWriter + import datetime class Client(BaseClient): @@ -205,6 +207,8 @@ def get_all( transaction: Transaction | None = None, retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ) -> Generator[DocumentSnapshot, Any, None]: """Retrieve a batch of documents. @@ -239,13 +243,17 @@ def get_all( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Yields: .DocumentSnapshot: The next document snapshot that fulfills the query, or :data:`None` if the document does not exist. """ request, reference_map, kwargs = self._prep_get_all( - references, field_paths, transaction, retry, timeout + references, field_paths, transaction, retry, timeout, read_time ) response_iterator = self._firestore_api.batch_get_documents( @@ -261,6 +269,8 @@ def collections( self, retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ) -> Generator[Any, Any, None]: """List top-level collections of the client's database. @@ -269,12 +279,16 @@ def collections( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Returns: Sequence[:class:`~google.cloud.firestore_v1.collection.CollectionReference`]: iterator of subcollections of the current document. 
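Sketch of the sync call with an illustrative timestamp:

import datetime
from google.cloud import firestore

client = firestore.Client()
half_hour_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(minutes=30)
# Collections created after `half_hour_ago` will not appear.
for collection in client.collections(read_time=half_hour_ago):
    print(collection.id)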
""" - request, kwargs = self._prep_collections(retry, timeout) + request, kwargs = self._prep_collections(retry, timeout, read_time) iterator = self._firestore_api.list_collection_ids( request=request, @@ -378,7 +392,9 @@ def batch(self) -> WriteBatch: """ return WriteBatch(self) - def transaction(self, **kwargs) -> Transaction: + def transaction( + self, max_attempts: int = MAX_ATTEMPTS, read_only: bool = False + ) -> Transaction: """Get a transaction that uses this client. See :class:`~google.cloud.firestore_v1.transaction.Transaction` for @@ -394,4 +410,4 @@ def transaction(self, **kwargs) -> Transaction: :class:`~google.cloud.firestore_v1.transaction.Transaction`: A transaction attached to this client. """ - return Transaction(self, **kwargs) + return Transaction(self, max_attempts=max_attempts, read_only=read_only) diff --git a/google/cloud/firestore_v1/collection.py b/google/cloud/firestore_v1/collection.py index cd6929b688..60788dd71e 100644 --- a/google/cloud/firestore_v1/collection.py +++ b/google/cloud/firestore_v1/collection.py @@ -35,6 +35,8 @@ from google.cloud.firestore_v1.query_profile import ExplainOptions from google.cloud.firestore_v1.stream_generator import StreamGenerator + import datetime + class CollectionReference(BaseCollectionReference[query_mod.Query]): """A reference to a collection in a Firestore database. @@ -137,6 +139,8 @@ def list_documents( page_size: Union[int, None] = None, retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: Union[float, None] = None, + *, + read_time: Optional[datetime.datetime] = None, ) -> Generator[Any, Any, None]: """List all subdocuments of the current collection. @@ -148,6 +152,10 @@ def list_documents( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Returns: Sequence[:class:`~google.cloud.firestore_v1.collection.DocumentReference`]: @@ -155,7 +163,9 @@ def list_documents( collection does not exist at the time of `snapshot`, the iterator will be empty """ - request, kwargs = self._prep_list_documents(page_size, retry, timeout) + request, kwargs = self._prep_list_documents( + page_size, retry, timeout, read_time + ) iterator = self._client._firestore_api.list_documents( request=request, @@ -174,6 +184,7 @@ def get( timeout: Union[float, None] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> QueryResultsList[DocumentSnapshot]: """Read the documents in this collection. @@ -192,6 +203,10 @@ def get( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. 
If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. read-after-write is not @@ -204,6 +219,8 @@ def get( query, kwargs = self._prep_get_or_stream(retry, timeout) if explain_options is not None: kwargs["explain_options"] = explain_options + if read_time is not None: + kwargs["read_time"] = read_time return query.get(transaction=transaction, **kwargs) @@ -214,6 +231,7 @@ def stream( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> StreamGenerator[DocumentSnapshot]: """Read the documents in this collection. @@ -245,6 +263,10 @@ def stream( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Returns: `StreamGenerator[DocumentSnapshot]`: A generator of the query results. @@ -252,6 +274,8 @@ def stream( query, kwargs = self._prep_get_or_stream(retry, timeout) if explain_options: kwargs["explain_options"] = explain_options + if read_time is not None: + kwargs["read_time"] = read_time return query.stream(transaction=transaction, **kwargs) diff --git a/google/cloud/firestore_v1/document.py b/google/cloud/firestore_v1/document.py index 0c7d7872fd..4bb6399a7c 100644 --- a/google/cloud/firestore_v1/document.py +++ b/google/cloud/firestore_v1/document.py @@ -169,7 +169,7 @@ def set( def update( self, - field_updates: dict, + field_updates: dict[str, Any], option: _helpers.WriteOption | None = None, retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, @@ -365,6 +365,8 @@ def get( transaction=None, retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ) -> DocumentSnapshot: """Retrieve a snapshot of the current document. @@ -387,6 +389,10 @@ def get( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. 
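Sketch of a point-in-time read of a single document (path and offset are illustrative):

import datetime
from google.cloud import firestore

client = firestore.Client()
ten_minutes_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(minutes=10)
snapshot = client.document("users/alice").get(read_time=ten_minutes_ago)
# exists and to_dict() describe the document as of ten_minutes_ago.
print(snapshot.exists, snapshot.to_dict())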
Returns: :class:`~google.cloud.firestore_v1.base_document.DocumentSnapshot`: @@ -398,7 +404,9 @@ def get( """ from google.cloud.firestore_v1.base_client import _parse_batch_get - request, kwargs = self._prep_batch_get(field_paths, transaction, retry, timeout) + request, kwargs = self._prep_batch_get( + field_paths, transaction, retry, timeout, read_time + ) response_iter = self._client._firestore_api.batch_get_documents( request=request, @@ -434,6 +442,8 @@ def collections( page_size: int | None = None, retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ) -> Generator[Any, Any, None]: """List subcollections of the current document. @@ -445,6 +455,10 @@ def collections( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Returns: Sequence[:class:`~google.cloud.firestore_v1.collection.CollectionReference`]: @@ -452,7 +466,7 @@ def collections( document does not exist at the time of `snapshot`, the iterator will be empty """ - request, kwargs = self._prep_collections(page_size, retry, timeout) + request, kwargs = self._prep_collections(page_size, retry, timeout, read_time) iterator = self._client._firestore_api.list_collection_ids( request=request, diff --git a/google/cloud/firestore_v1/field_path.py b/google/cloud/firestore_v1/field_path.py index 048eb64d08..27ac6cc459 100644 --- a/google/cloud/firestore_v1/field_path.py +++ b/google/cloud/firestore_v1/field_path.py @@ -263,7 +263,7 @@ class FieldPath(object): Indicating path of the key to be used. """ - def __init__(self, *parts): + def __init__(self, *parts: str): for part in parts: if not isinstance(part, str) or not part: error = "One or more components is not a string or is empty." @@ -271,7 +271,7 @@ def __init__(self, *parts): self.parts = tuple(parts) @classmethod - def from_api_repr(cls, api_repr: str): + def from_api_repr(cls, api_repr: str) -> "FieldPath": """Factory: create a FieldPath from the string formatted per the API. Args: @@ -288,7 +288,7 @@ def from_api_repr(cls, api_repr: str): return cls(*parse_field_path(api_repr)) @classmethod - def from_string(cls, path_string: str): + def from_string(cls, path_string: str) -> "FieldPath": """Factory: create a FieldPath from a unicode string representation. This method splits on the character `.` and disallows the @@ -351,7 +351,7 @@ def __add__(self, other): else: return NotImplemented - def to_api_repr(self): + def to_api_repr(self) -> str: """Render a quoted string representation of the FieldPath Returns: @@ -360,7 +360,7 @@ def to_api_repr(self): """ return render_field_path(self.parts) - def eq_or_parent(self, other): + def eq_or_parent(self, other) -> bool: """Check whether ``other`` is an ancestor. Returns: @@ -369,7 +369,7 @@ def eq_or_parent(self, other): """ return self.parts[: len(other.parts)] == other.parts[: len(self.parts)] - def lineage(self): + def lineage(self) -> set["FieldPath"]: """Return field paths for all parents. 
Returns: Set[:class:`FieldPath`] @@ -378,7 +378,7 @@ def lineage(self): return {FieldPath(*self.parts[:index]) for index in indexes} @staticmethod - def document_id(): + def document_id() -> str: """A special FieldPath value to refer to the ID of a document. It can be used in queries to sort or filter by the document ID. diff --git a/google/cloud/firestore_v1/gapic_metadata.json b/google/cloud/firestore_v1/gapic_metadata.json index d0462f9640..03a6e428b4 100644 --- a/google/cloud/firestore_v1/gapic_metadata.json +++ b/google/cloud/firestore_v1/gapic_metadata.json @@ -40,6 +40,11 @@ "delete_document" ] }, + "ExecutePipeline": { + "methods": [ + "execute_pipeline" + ] + }, "GetDocument": { "methods": [ "get_document" @@ -125,6 +130,11 @@ "delete_document" ] }, + "ExecutePipeline": { + "methods": [ + "execute_pipeline" + ] + }, "GetDocument": { "methods": [ "get_document" @@ -210,6 +220,11 @@ "delete_document" ] }, + "ExecutePipeline": { + "methods": [ + "execute_pipeline" + ] + }, "GetDocument": { "methods": [ "get_document" diff --git a/google/cloud/firestore_v1/gapic_version.py b/google/cloud/firestore_v1/gapic_version.py index e546bae053..ced4e0faf0 100644 --- a/google/cloud/firestore_v1/gapic_version.py +++ b/google/cloud/firestore_v1/gapic_version.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "2.21.0" # {x-release-please-version} +__version__ = "2.22.0" # {x-release-please-version} diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index a8b821bdc4..8b6018b6a5 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -59,6 +59,8 @@ from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.query_profile import ExplainMetrics, ExplainOptions + import datetime + class Query(BaseQuery): """Represents a query to the Firestore API. @@ -151,6 +153,7 @@ def get( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> QueryResultsList[DocumentSnapshot]: """Read the documents in the collection that match this query. @@ -172,6 +175,10 @@ def get( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. For the most accurate results, use UTC timezone. 
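The same keyword now threads through ``Query.get`` and ``Query.stream`` as well; a hedged sketch reusing the ``client`` and ``snapshot_time`` from the earlier example (field name and value hypothetical):

.. code-block:: python

    from google.cloud.firestore_v1 import FieldFilter

    query = client.collection("cities").where(
        filter=FieldFilter("population", ">", 1000000)
    )
    # Stream the query results as they existed at snapshot_time.
    for snapshot in query.stream(read_time=snapshot_time):
        print(snapshot.id, snapshot.to_dict())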
Returns: QueryResultsList[DocumentSnapshot]: The documents in the collection @@ -198,6 +205,7 @@ def get( retry=retry, timeout=timeout, explain_options=explain_options, + read_time=read_time, ) result_list = list(result) if is_limited_to_last: @@ -248,13 +256,12 @@ def _chunkify( ): return - def _get_stream_iterator(self, transaction, retry, timeout, explain_options=None): + def _get_stream_iterator( + self, transaction, retry, timeout, explain_options=None, read_time=None + ): """Helper method for :meth:`stream`.""" request, expected_prefix, kwargs = self._prep_stream( - transaction, - retry, - timeout, - explain_options, + transaction, retry, timeout, explain_options, read_time ) response_iterator = self._client._firestore_api.run_query( @@ -363,6 +370,7 @@ def _make_stream( retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> Generator[DocumentSnapshot, Any, Optional[ExplainMetrics]]: """Internal method for stream(). Read the documents in the collection that match this query. @@ -396,6 +404,10 @@ def _make_stream( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. For the most accurate results, use UTC timezone. Yields: DocumentSnapshot: @@ -412,6 +424,7 @@ def _make_stream( retry, timeout, explain_options, + read_time, ) last_snapshot = None @@ -426,6 +439,7 @@ def _make_stream( transaction, retry, timeout, + read_time=read_time, ) continue else: @@ -458,6 +472,7 @@ def stream( timeout: float | None = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> StreamGenerator[DocumentSnapshot]: """Read the documents in the collection that match this query. @@ -489,6 +504,10 @@ def stream( (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. For the most accurate results, use UTC timezone. Returns: `StreamGenerator[DocumentSnapshot]`: A generator of the query results. @@ -498,6 +517,7 @@ def stream( retry=retry, timeout=timeout, explain_options=explain_options, + read_time=read_time, ) return StreamGenerator(inner_generator, explain_options) @@ -590,6 +610,8 @@ def get_partitions( partition_count, retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, + *, + read_time: Optional[datetime.datetime] = None, ) -> Generator[QueryPartition, None, None]: """Partition a query for parallelization. @@ -605,8 +627,14 @@ def get_partitions( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. 
+ read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. For the most accurate results, use UTC timezone. """ - request, kwargs = self._prep_get_partitions(partition_count, retry, timeout) + request, kwargs = self._prep_get_partitions( + partition_count, retry, timeout, read_time + ) pager = self._client._firestore_api.partition_query( request=request, diff --git a/google/cloud/firestore_v1/services/firestore/async_client.py b/google/cloud/firestore_v1/services/firestore/async_client.py index b904229b04..3557eb94c9 100644 --- a/google/cloud/firestore_v1/services/firestore/async_client.py +++ b/google/cloud/firestore_v1/services/firestore/async_client.py @@ -53,6 +53,7 @@ from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats from google.cloud.firestore_v1.types import firestore from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.types import query_profile @@ -1248,6 +1249,109 @@ async def sample_run_query(): # Done; return the response. return response + def execute_pipeline( + self, + request: Optional[Union[firestore.ExecutePipelineRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> Awaitable[AsyncIterable[firestore.ExecutePipelineResponse]]: + r"""Executes a pipeline query. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import firestore_v1 + + async def sample_execute_pipeline(): + # Create a client + client = firestore_v1.FirestoreAsyncClient() + + # Initialize request argument(s) + structured_pipeline = firestore_v1.StructuredPipeline() + structured_pipeline.pipeline.stages.name = "name_value" + + request = firestore_v1.ExecutePipelineRequest( + structured_pipeline=structured_pipeline, + transaction=b'transaction_blob', + database="database_value", + ) + + # Make the request + stream = await client.execute_pipeline(request=request) + + # Handle the response + async for response in stream: + print(response) + + Args: + request (Optional[Union[google.cloud.firestore_v1.types.ExecutePipelineRequest, dict]]): + The request object. The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. 
+
+        Returns:
+            AsyncIterable[google.cloud.firestore_v1.types.ExecutePipelineResponse]:
+                The response for [Firestore.Execute][].
+        """
+        # Create or coerce a protobuf request object.
+        # - Use the request object if provided (there's no risk of modifying the input as
+        #   there are no flattened fields), or create one.
+        if not isinstance(request, firestore.ExecutePipelineRequest):
+            request = firestore.ExecutePipelineRequest(request)
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._client._transport._wrapped_methods[
+            self._client._transport.execute_pipeline
+        ]
+
+        header_params = {}
+
+        routing_param_regex = re.compile("^projects/(?P<project_id>[^/]+)(?:/.*)?$")
+        regex_match = routing_param_regex.match(request.database)
+        if regex_match and regex_match.group("project_id"):
+            header_params["project_id"] = regex_match.group("project_id")
+
+        routing_param_regex = re.compile(
+            "^projects/[^/]+/databases/(?P<database_id>[^/]+)(?:/.*)?$"
+        )
+        regex_match = routing_param_regex.match(request.database)
+        if regex_match and regex_match.group("database_id"):
+            header_params["database_id"] = regex_match.group("database_id")
+
+        if header_params:
+            metadata = tuple(metadata) + (
+                gapic_v1.routing_header.to_grpc_metadata(header_params),
+            )
+
+        # Validate the universe domain.
+        self._client._validate_universe_domain()
+
+        # Send the request.
+        response = rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
     def run_aggregation_query(
         self,
         request: Optional[Union[firestore.RunAggregationQueryRequest, dict]] = None,
diff --git a/google/cloud/firestore_v1/services/firestore/client.py b/google/cloud/firestore_v1/services/firestore/client.py
index 8055612429..ac86aaa9ec 100644
--- a/google/cloud/firestore_v1/services/firestore/client.py
+++ b/google/cloud/firestore_v1/services/firestore/client.py
@@ -68,6 +68,7 @@
 from google.cloud.firestore_v1.types import common
 from google.cloud.firestore_v1.types import document
 from google.cloud.firestore_v1.types import document as gf_document
+from google.cloud.firestore_v1.types import explain_stats
 from google.cloud.firestore_v1.types import firestore
 from google.cloud.firestore_v1.types import query
 from google.cloud.firestore_v1.types import query_profile
@@ -168,6 +169,34 @@ def _get_default_mtls_endpoint(api_endpoint):
     _DEFAULT_ENDPOINT_TEMPLATE = "firestore.{UNIVERSE_DOMAIN}"
     _DEFAULT_UNIVERSE = "googleapis.com"
+    @staticmethod
+    def _use_client_cert_effective():
+        """Returns whether client certificate should be used for mTLS if the
+        google-auth version supports should_use_client_cert automatic mTLS enablement.
+
+        Alternatively, read from the GOOGLE_API_USE_CLIENT_CERTIFICATE env var.
+
+        Returns:
+            bool: whether client certificate should be used for mTLS
+        Raises:
+            ValueError: (If using a version of google-auth without should_use_client_cert and
+            GOOGLE_API_USE_CLIENT_CERTIFICATE is set to an unexpected value.)
+ """ + # check if google-auth version supports should_use_client_cert for automatic mTLS enablement + if hasattr(mtls, "should_use_client_cert"): # pragma: NO COVER + return mtls.should_use_client_cert() + else: # pragma: NO COVER + # if unsupported, fallback to reading from env var + use_client_cert_str = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + if use_client_cert_str not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be" + " either `true` or `false`" + ) + return use_client_cert_str == "true" + @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): """Creates an instance of this client using the provided credentials @@ -333,12 +362,8 @@ def get_mtls_endpoint_and_cert_source( ) if client_options is None: client_options = client_options_lib.ClientOptions() - use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_client_cert = FirestoreClient._use_client_cert_effective() use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_client_cert not in ("true", "false"): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) if use_mtls_endpoint not in ("auto", "never", "always"): raise MutualTLSChannelError( "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" @@ -346,7 +371,7 @@ def get_mtls_endpoint_and_cert_source( # Figure out the client cert source to use. client_cert_source = None - if use_client_cert == "true": + if use_client_cert: if client_options.client_cert_source: client_cert_source = client_options.client_cert_source elif mtls.has_default_client_cert_source(): @@ -378,20 +403,14 @@ def _read_environment_variables(): google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT is not any of ["auto", "never", "always"]. """ - use_client_cert = os.getenv( - "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" - ).lower() + use_client_cert = FirestoreClient._use_client_cert_effective() use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower() universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN") - if use_client_cert not in ("true", "false"): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) if use_mtls_endpoint not in ("auto", "never", "always"): raise MutualTLSChannelError( "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" ) - return use_client_cert == "true", use_mtls_endpoint, universe_domain_env + return use_client_cert, use_mtls_endpoint, universe_domain_env @staticmethod def _get_client_cert_source(provided_cert_source, use_cert_flag): @@ -1631,6 +1650,107 @@ def sample_run_query(): # Done; return the response. return response + def execute_pipeline( + self, + request: Optional[Union[firestore.ExecutePipelineRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> Iterable[firestore.ExecutePipelineResponse]: + r"""Executes a pipeline query. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+            #   - It may require specifying regional endpoints when creating the service
+            #     client as shown in:
+            #     https://googleapis.dev/python/google-api-core/latest/client_options.html
+            from google.cloud import firestore_v1
+
+            def sample_execute_pipeline():
+                # Create a client
+                client = firestore_v1.FirestoreClient()
+
+                # Initialize request argument(s)
+                structured_pipeline = firestore_v1.StructuredPipeline()
+                structured_pipeline.pipeline.stages.name = "name_value"
+
+                request = firestore_v1.ExecutePipelineRequest(
+                    structured_pipeline=structured_pipeline,
+                    transaction=b'transaction_blob',
+                    database="database_value",
+                )
+
+                # Make the request
+                stream = client.execute_pipeline(request=request)
+
+                # Handle the response
+                for response in stream:
+                    print(response)
+
+        Args:
+            request (Union[google.cloud.firestore_v1.types.ExecutePipelineRequest, dict]):
+                The request object. The request for
+                [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline].
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be
+                sent along with the request as metadata. Normally, each value must be of type `str`,
+                but for metadata keys ending with the suffix `-bin`, the corresponding values must
+                be of type `bytes`.
+
+        Returns:
+            Iterable[google.cloud.firestore_v1.types.ExecutePipelineResponse]:
+                The response for [Firestore.Execute][].
+        """
+        # Create or coerce a protobuf request object.
+        # - Use the request object if provided (there's no risk of modifying the input as
+        #   there are no flattened fields), or create one.
+        if not isinstance(request, firestore.ExecutePipelineRequest):
+            request = firestore.ExecutePipelineRequest(request)
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._transport._wrapped_methods[self._transport.execute_pipeline]
+
+        header_params = {}
+
+        routing_param_regex = re.compile("^projects/(?P<project_id>[^/]+)(?:/.*)?$")
+        regex_match = routing_param_regex.match(request.database)
+        if regex_match and regex_match.group("project_id"):
+            header_params["project_id"] = regex_match.group("project_id")
+
+        routing_param_regex = re.compile(
+            "^projects/[^/]+/databases/(?P<database_id>[^/]+)(?:/.*)?$"
+        )
+        regex_match = routing_param_regex.match(request.database)
+        if regex_match and regex_match.group("database_id"):
+            header_params["database_id"] = regex_match.group("database_id")
+
+        if header_params:
+            metadata = tuple(metadata) + (
+                gapic_v1.routing_header.to_grpc_metadata(header_params),
+            )
+
+        # Validate the universe domain.
+        self._validate_universe_domain()
+
+        # Send the request.
+        response = rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
     def run_aggregation_query(
         self,
         request: Optional[Union[firestore.RunAggregationQueryRequest, dict]] = None,
diff --git a/google/cloud/firestore_v1/services/firestore/transports/base.py b/google/cloud/firestore_v1/services/firestore/transports/base.py
index 66d81748cd..905dded096 100644
--- a/google/cloud/firestore_v1/services/firestore/transports/base.py
+++ b/google/cloud/firestore_v1/services/firestore/transports/base.py
@@ -75,9 +75,10 @@ def __init__(
                 credentials identify the application to the service; if none
                 are specified, the client will attempt to ascertain the
                 credentials from the environment.
- credentials_file (Optional[str]): A file with credentials that can + credentials_file (Optional[str]): Deprecated. A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. + This argument is mutually exclusive with credentials. This argument will be + removed in the next major version of this library. scopes (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. @@ -290,6 +291,23 @@ def _prep_wrapped_messages(self, client_info): default_timeout=300.0, client_info=client_info, ), + self.execute_pipeline: gapic_v1.method.wrap_method( + self.execute_pipeline, + default_retry=retries.Retry( + initial=0.1, + maximum=60.0, + multiplier=1.3, + predicate=retries.if_exception_type( + core_exceptions.DeadlineExceeded, + core_exceptions.InternalServerError, + core_exceptions.ResourceExhausted, + core_exceptions.ServiceUnavailable, + ), + deadline=300.0, + ), + default_timeout=300.0, + client_info=client_info, + ), self.run_aggregation_query: gapic_v1.method.wrap_method( self.run_aggregation_query, default_retry=retries.Retry( @@ -513,6 +531,18 @@ def run_query( ]: raise NotImplementedError() + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], + Union[ + firestore.ExecutePipelineResponse, + Awaitable[firestore.ExecutePipelineResponse], + ], + ]: + raise NotImplementedError() + @property def run_aggregation_query( self, diff --git a/google/cloud/firestore_v1/services/firestore/transports/grpc.py b/google/cloud/firestore_v1/services/firestore/transports/grpc.py index c302a73c28..f057d16e30 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/grpc.py +++ b/google/cloud/firestore_v1/services/firestore/transports/grpc.py @@ -164,9 +164,10 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if a ``channel`` instance is provided. - credentials_file (Optional[str]): A file with credentials that can + credentials_file (Optional[str]): Deprecated. A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. This argument is ignored if a ``channel`` instance is provided. + This argument will be removed in the next major version of this library. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if a ``channel`` instance is provided. channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): @@ -299,9 +300,10 @@ def create_channel( credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can + credentials_file (Optional[str]): Deprecated. A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. + This argument is mutually exclusive with credentials. This argument will be + removed in the next major version of this library. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. 
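For intuition on the default retry registered above for ``execute_pipeline`` (initial 0.1 s, multiplier 1.3, maximum 60 s, 300 s deadline), a standalone sketch of the nominal delay schedule it implies; illustrative only, since ``google.api_core.retry`` also applies jitter:

.. code-block:: python

    def nominal_backoff(initial=0.1, multiplier=1.3, maximum=60.0, deadline=300.0):
        """Yield the nominal sleep intervals until the retry deadline is spent."""
        delay, elapsed = initial, 0.0
        while elapsed < deadline:
            yield delay
            elapsed += delay
            delay = min(delay * multiplier, maximum)

    # First few nominal delays: 0.1, 0.13, 0.17, 0.22, 0.29, 0.37
    print([round(d, 2) for d in nominal_backoff()][:6])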
@@ -571,6 +573,34 @@ def run_query( ) return self._stubs["run_query"] + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], firestore.ExecutePipelineResponse + ]: + r"""Return a callable for the execute pipeline method over gRPC. + + Executes a pipeline query. + + Returns: + Callable[[~.ExecutePipelineRequest], + ~.ExecutePipelineResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "execute_pipeline" not in self._stubs: + self._stubs["execute_pipeline"] = self._logged_channel.unary_stream( + "/google.firestore.v1.Firestore/ExecutePipeline", + request_serializer=firestore.ExecutePipelineRequest.serialize, + response_deserializer=firestore.ExecutePipelineResponse.deserialize, + ) + return self._stubs["execute_pipeline"] + @property def run_aggregation_query( self, diff --git a/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py b/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py index f461622962..cf6006672f 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py +++ b/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py @@ -161,8 +161,9 @@ def create_channel( credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. + credentials_file (Optional[str]): Deprecated. A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. This argument will be + removed in the next major version of this library. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. @@ -213,9 +214,10 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if a ``channel`` instance is provided. - credentials_file (Optional[str]): A file with credentials that can + credentials_file (Optional[str]): Deprecated. A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. This argument is ignored if a ``channel`` instance is provided. + This argument will be removed in the next major version of this library. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. @@ -587,6 +589,34 @@ def run_query( ) return self._stubs["run_query"] + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], Awaitable[firestore.ExecutePipelineResponse] + ]: + r"""Return a callable for the execute pipeline method over gRPC. + + Executes a pipeline query. + + Returns: + Callable[[~.ExecutePipelineRequest], + Awaitable[~.ExecutePipelineResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "execute_pipeline" not in self._stubs: + self._stubs["execute_pipeline"] = self._logged_channel.unary_stream( + "/google.firestore.v1.Firestore/ExecutePipeline", + request_serializer=firestore.ExecutePipelineRequest.serialize, + response_deserializer=firestore.ExecutePipelineResponse.deserialize, + ) + return self._stubs["execute_pipeline"] + @property def run_aggregation_query( self, @@ -962,6 +992,23 @@ def _prep_wrapped_messages(self, client_info): default_timeout=300.0, client_info=client_info, ), + self.execute_pipeline: self._wrap_method( + self.execute_pipeline, + default_retry=retries.AsyncRetry( + initial=0.1, + maximum=60.0, + multiplier=1.3, + predicate=retries.if_exception_type( + core_exceptions.DeadlineExceeded, + core_exceptions.InternalServerError, + core_exceptions.ResourceExhausted, + core_exceptions.ServiceUnavailable, + ), + deadline=300.0, + ), + default_timeout=300.0, + client_info=client_info, + ), self.run_aggregation_query: self._wrap_method( self.run_aggregation_query, default_retry=retries.AsyncRetry( diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest.py b/google/cloud/firestore_v1/services/firestore/transports/rest.py index 8c038348c7..845569d97e 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest.py @@ -127,6 +127,14 @@ def pre_delete_document(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata + def pre_execute_pipeline(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_execute_pipeline(self, response): + logging.log(f"Received response: {response}") + return response + def pre_get_document(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata @@ -445,6 +453,56 @@ def pre_delete_document( """ return request, metadata + def pre_execute_pipeline( + self, + request: firestore.ExecutePipelineRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore.ExecutePipelineRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for execute_pipeline + + Override in a subclass to manipulate the request or metadata + before they are sent to the Firestore server. + """ + return request, metadata + + def post_execute_pipeline( + self, response: rest_streaming.ResponseIterator + ) -> rest_streaming.ResponseIterator: + """Post-rpc interceptor for execute_pipeline + + DEPRECATED. Please use the `post_execute_pipeline_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response + after it is returned by the Firestore server but before + it is returned to user code. This `post_execute_pipeline` interceptor runs + before the `post_execute_pipeline_with_metadata` interceptor. + """ + return response + + def post_execute_pipeline_with_metadata( + self, + response: rest_streaming.ResponseIterator, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + rest_streaming.ResponseIterator, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Post-rpc interceptor for execute_pipeline + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the Firestore server but before it is returned to user code. + + We recommend only using this `post_execute_pipeline_with_metadata` + interceptor in new development instead of the `post_execute_pipeline` interceptor. 
+ When both interceptors are used, this `post_execute_pipeline_with_metadata` interceptor runs after the + `post_execute_pipeline` interceptor. The (possibly modified) response returned by + `post_execute_pipeline` will be passed to + `post_execute_pipeline_with_metadata`. + """ + return response, metadata + def pre_get_document( self, request: firestore.GetDocumentRequest, @@ -945,9 +1003,10 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can + credentials_file (Optional[str]): Deprecated. A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if ``channel`` is provided. This argument will be + removed in the next major version of this library. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client @@ -1118,6 +1177,22 @@ def __call__( resp, _ = self._interceptor.post_batch_get_documents_with_metadata( resp, response_metadata ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + http_response = { + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore_v1.FirestoreClient.batch_get_documents", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "BatchGetDocuments", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) return resp class _BatchWrite(_BaseFirestoreRestTransport._BaseBatchWrite, FirestoreRestStub): @@ -1856,6 +1931,158 @@ def __call__( if response.status_code >= 400: raise core_exceptions.from_http_response(response) + class _ExecutePipeline( + _BaseFirestoreRestTransport._BaseExecutePipeline, FirestoreRestStub + ): + def __hash__(self): + return hash("FirestoreRestTransport.ExecutePipeline") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + stream=True, + ) + return response + + def __call__( + self, + request: firestore.ExecutePipelineRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> rest_streaming.ResponseIterator: + r"""Call the execute pipeline method over HTTP. + + Args: + request (~.firestore.ExecutePipelineRequest): + The request object. The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. 
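The ``pre_execute_pipeline`` hook added above follows the interceptor pattern sketched at the top of ``rest.py``; a hypothetical subclass that inspects the routed database (wired in through the transport's ``interceptor=`` argument) might look like:

.. code-block:: python

    from google.cloud.firestore_v1.services.firestore.transports.rest import (
        FirestoreRestInterceptor,
    )


    class LoggingFirestoreInterceptor(FirestoreRestInterceptor):
        def pre_execute_pipeline(self, request, metadata):
            # Inspect (or rewrite) the request before it reaches the server.
            print(f"ExecutePipeline against {request.database}")
            return request, metadata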
+ + Returns: + ~.firestore.ExecutePipelineResponse: + The response for [Firestore.Execute][]. + """ + + http_options = ( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_http_options() + ) + + request, metadata = self._interceptor.pre_execute_pipeline( + request, metadata + ) + transcoded_request = _BaseFirestoreRestTransport._BaseExecutePipeline._get_transcoded_request( + http_options, request + ) + + body = ( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_request_body_json( + transcoded_request + ) + ) + + # Jsonify the query params + query_params = ( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_query_params_json( + transcoded_request + ) + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.ExecutePipeline", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "ExecutePipeline", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = FirestoreRestTransport._ExecutePipeline._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = rest_streaming.ResponseIterator( + response, firestore.ExecutePipelineResponse + ) + + resp = self._interceptor.post_execute_pipeline(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_execute_pipeline_with_metadata( + resp, response_metadata + ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + http_response = { + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore_v1.FirestoreClient.execute_pipeline", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "ExecutePipeline", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + class _GetDocument(_BaseFirestoreRestTransport._BaseGetDocument, FirestoreRestStub): def __hash__(self): return hash("FirestoreRestTransport.GetDocument") @@ -2736,6 +2963,22 @@ def __call__( resp, _ = self._interceptor.post_run_aggregation_query_with_metadata( resp, response_metadata ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + http_response = { + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore_v1.FirestoreClient.run_aggregation_query", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "RunAggregationQuery", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) return resp class _RunQuery(_BaseFirestoreRestTransport._BaseRunQuery, FirestoreRestStub): @@ -2866,6 +3109,22 @@ def __call__( resp, _ = 
self._interceptor.post_run_query_with_metadata( resp, response_metadata ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + http_response = { + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore_v1.FirestoreClient.run_query", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "RunQuery", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) return resp class _UpdateDocument( @@ -3094,6 +3353,16 @@ def delete_document( # In C++ this would require a dynamic_cast return self._DeleteDocument(self._session, self._host, self._interceptor) # type: ignore + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], firestore.ExecutePipelineResponse + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._ExecutePipeline(self._session, self._host, self._interceptor) # type: ignore + @property def get_document( self, diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py index 1d95cd16ea..80ce35e495 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py @@ -426,6 +426,63 @@ def _get_query_params_json(transcoded_request): query_params["$alt"] = "json;enum-encoding=int" return query_params + class _BaseExecutePipeline: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{database=projects/*/databases/*}/documents:executePipeline", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.ExecutePipelineRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + class _BaseGetDocument: def __hash__(self): # pragma: NO COVER return NotImplementedError("__hash__ must be implemented.") diff --git a/google/cloud/firestore_v1/transaction.py b/google/cloud/firestore_v1/transaction.py index 37afd5fb00..913fc1d3bc 100644 --- a/google/cloud/firestore_v1/transaction.py +++ b/google/cloud/firestore_v1/transaction.py @@ -40,6 +40,8 @@ from google.cloud.firestore_v1.query_profile import ExplainOptions from google.cloud.firestore_v1.stream_generator 
import StreamGenerator + import datetime + class Transaction(batch.WriteBatch, BaseTransaction): """Accumulate read-and-write operations to be sent in a transaction. @@ -154,6 +156,8 @@ def get_all( references: list, retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, + *, + read_time: datetime.datetime | None = None, ) -> Generator[DocumentSnapshot, Any, None]: """Retrieves multiple documents from Firestore. @@ -164,12 +168,18 @@ def get_all( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Yields: .DocumentSnapshot: The next document snapshot that fulfills the query, or :data:`None` if the document does not exist. """ kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + if read_time is not None: + kwargs["read_time"] = read_time return self._client.get_all(references, transaction=self, **kwargs) def get( @@ -179,6 +189,7 @@ def get( timeout: Optional[float] = None, *, explain_options: Optional[ExplainOptions] = None, + read_time: Optional[datetime.datetime] = None, ) -> StreamGenerator[DocumentSnapshot] | Generator[DocumentSnapshot, Any, None]: """Retrieve a document or a query result from the database. @@ -194,6 +205,10 @@ def get( Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. Can only be used when running a query, not a document reference. + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery + is enabled, can additionally be a whole minute timestamp within the past 7 days. If no + timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. Yields: .DocumentSnapshot: The next document snapshot that fulfills the @@ -205,6 +220,8 @@ def get( reference. 
""" kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + if read_time is not None: + kwargs["read_time"] = read_time if isinstance(ref_or_query, DocumentReference): if explain_options is not None: raise ValueError( diff --git a/google/cloud/firestore_v1/types/__init__.py b/google/cloud/firestore_v1/types/__init__.py index ae1004e132..ed1965d7ff 100644 --- a/google/cloud/firestore_v1/types/__init__.py +++ b/google/cloud/firestore_v1/types/__init__.py @@ -28,9 +28,14 @@ from .document import ( ArrayValue, Document, + Function, MapValue, + Pipeline, Value, ) +from .explain_stats import ( + ExplainStats, +) from .firestore import ( BatchGetDocumentsRequest, BatchGetDocumentsResponse, @@ -42,6 +47,8 @@ CommitResponse, CreateDocumentRequest, DeleteDocumentRequest, + ExecutePipelineRequest, + ExecutePipelineResponse, GetDocumentRequest, ListCollectionIdsRequest, ListCollectionIdsResponse, @@ -62,6 +69,9 @@ WriteRequest, WriteResponse, ) +from .pipeline import ( + StructuredPipeline, +) from .query import ( Cursor, StructuredAggregationQuery, @@ -92,8 +102,11 @@ "TransactionOptions", "ArrayValue", "Document", + "Function", "MapValue", + "Pipeline", "Value", + "ExplainStats", "BatchGetDocumentsRequest", "BatchGetDocumentsResponse", "BatchWriteRequest", @@ -104,6 +117,8 @@ "CommitResponse", "CreateDocumentRequest", "DeleteDocumentRequest", + "ExecutePipelineRequest", + "ExecutePipelineResponse", "GetDocumentRequest", "ListCollectionIdsRequest", "ListCollectionIdsResponse", @@ -123,6 +138,7 @@ "UpdateDocumentRequest", "WriteRequest", "WriteResponse", + "StructuredPipeline", "Cursor", "StructuredAggregationQuery", "StructuredQuery", diff --git a/google/cloud/firestore_v1/types/document.py b/google/cloud/firestore_v1/types/document.py index 0942354f50..8073ad97aa 100644 --- a/google/cloud/firestore_v1/types/document.py +++ b/google/cloud/firestore_v1/types/document.py @@ -31,6 +31,8 @@ "Value", "ArrayValue", "MapValue", + "Function", + "Pipeline", }, ) @@ -72,7 +74,7 @@ class Document(proto.Message): may contain any character. Some characters, including :literal:`\``, must be escaped using a ``\``. For example, :literal:`\`x&y\`` represents ``x&y`` and - :literal:`\`bak\`tik\`` represents :literal:`bak`tik`. + :literal:`\`bak\\`tik\`` represents :literal:`bak`tik`. create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. The time at which the document was created. @@ -183,6 +185,37 @@ class Value(proto.Message): map_value (google.cloud.firestore_v1.types.MapValue): A map value. + This field is a member of `oneof`_ ``value_type``. + field_reference_value (str): + Value which references a field. + + This is considered relative (vs absolute) since it only + refers to a field and not a field within a particular + document. + + **Requires:** + + - Must follow [field reference][FieldReference.field_path] + limitations. + + - Not allowed to be used when writing documents. + + This field is a member of `oneof`_ ``value_type``. + function_value (google.cloud.firestore_v1.types.Function): + A value that represents an unevaluated expression. + + **Requires:** + + - Not allowed to be used when writing documents. + + This field is a member of `oneof`_ ``value_type``. + pipeline_value (google.cloud.firestore_v1.types.Pipeline): + A value that represents an unevaluated pipeline. + + **Requires:** + + - Not allowed to be used when writing documents. + This field is a member of `oneof`_ ``value_type``. 
""" @@ -246,6 +279,23 @@ class Value(proto.Message): oneof="value_type", message="MapValue", ) + field_reference_value: str = proto.Field( + proto.STRING, + number=19, + oneof="value_type", + ) + function_value: "Function" = proto.Field( + proto.MESSAGE, + number=20, + oneof="value_type", + message="Function", + ) + pipeline_value: "Pipeline" = proto.Field( + proto.MESSAGE, + number=21, + oneof="value_type", + message="Pipeline", + ) class ArrayValue(proto.Message): @@ -285,4 +335,119 @@ class MapValue(proto.Message): ) +class Function(proto.Message): + r"""Represents an unevaluated scalar expression. + + For example, the expression ``like(user_name, "%alice%")`` is + represented as: + + :: + + name: "like" + args { field_reference: "user_name" } + args { string_value: "%alice%" } + + Attributes: + name (str): + Required. The name of the function to evaluate. + + **Requires:** + + - must be in snake case (lower case with underscore + separator). + args (MutableSequence[google.cloud.firestore_v1.types.Value]): + Optional. Ordered list of arguments the given + function expects. + options (MutableMapping[str, google.cloud.firestore_v1.types.Value]): + Optional. Optional named arguments that + certain functions may support. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + args: MutableSequence["Value"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Value", + ) + options: MutableMapping[str, "Value"] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=3, + message="Value", + ) + + +class Pipeline(proto.Message): + r"""A Firestore query represented as an ordered list of + operations / stages. + + Attributes: + stages (MutableSequence[google.cloud.firestore_v1.types.Pipeline.Stage]): + Required. Ordered list of stages to evaluate. + """ + + class Stage(proto.Message): + r"""A single operation within a pipeline. + + A stage is made up of a unique name, and a list of arguments. The + exact number of arguments & types is dependent on the stage type. + + To give an example, the stage ``filter(state = "MD")`` would be + encoded as: + + :: + + name: "filter" + args { + function_value { + name: "eq" + args { field_reference_value: "state" } + args { string_value: "MD" } + } + } + + See public documentation for the full list. + + Attributes: + name (str): + Required. The name of the stage to evaluate. + + **Requires:** + + - must be in snake case (lower case with underscore + separator). + args (MutableSequence[google.cloud.firestore_v1.types.Value]): + Optional. Ordered list of arguments the given + stage expects. + options (MutableMapping[str, google.cloud.firestore_v1.types.Value]): + Optional. Optional named arguments that + certain functions may support. 
+ """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + args: MutableSequence["Value"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Value", + ) + options: MutableMapping[str, "Value"] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=3, + message="Value", + ) + + stages: MutableSequence[Stage] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=Stage, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/firestore_v1/types/explain_stats.py b/google/cloud/firestore_v1/types/explain_stats.py new file mode 100644 index 0000000000..b0f9421ba7 --- /dev/null +++ b/google/cloud/firestore_v1/types/explain_stats.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.protobuf import any_pb2 # type: ignore + + +__protobuf__ = proto.module( + package="google.firestore.v1", + manifest={ + "ExplainStats", + }, +) + + +class ExplainStats(proto.Message): + r"""Pipeline explain stats. + + Depending on the explain options in the original request, this + can contain the optimized plan and / or execution stats. + + Attributes: + data (google.protobuf.any_pb2.Any): + The format depends on the ``output_format`` options in the + request. + + Currently there are two supported options: ``TEXT`` and + ``JSON``. Both supply a ``google.protobuf.StringValue``. + """ + + data: any_pb2.Any = proto.Field( + proto.MESSAGE, + number=1, + message=any_pb2.Any, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/firestore_v1/types/firestore.py b/google/cloud/firestore_v1/types/firestore.py index 53a6c6e7af..4e53ba3137 100644 --- a/google/cloud/firestore_v1/types/firestore.py +++ b/google/cloud/firestore_v1/types/firestore.py @@ -22,6 +22,8 @@ from google.cloud.firestore_v1.types import aggregation_result from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats as gf_explain_stats +from google.cloud.firestore_v1.types import pipeline from google.cloud.firestore_v1.types import query as gf_query from google.cloud.firestore_v1.types import query_profile from google.cloud.firestore_v1.types import write @@ -48,6 +50,8 @@ "RollbackRequest", "RunQueryRequest", "RunQueryResponse", + "ExecutePipelineRequest", + "ExecutePipelineResponse", "RunAggregationQueryRequest", "RunAggregationQueryResponse", "PartitionQueryRequest", @@ -835,6 +839,151 @@ class RunQueryResponse(proto.Message): ) +class ExecutePipelineRequest(proto.Message): + r"""The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. 
+ Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + database (str): + Required. Database identifier, in the form + ``projects/{project}/databases/{database}``. + structured_pipeline (google.cloud.firestore_v1.types.StructuredPipeline): + A pipelined operation. + + This field is a member of `oneof`_ ``pipeline_type``. + transaction (bytes): + Run the query within an already active + transaction. + The value here is the opaque transaction ID to + execute the query in. + + This field is a member of `oneof`_ ``consistency_selector``. + new_transaction (google.cloud.firestore_v1.types.TransactionOptions): + Execute the pipeline in a new transaction. + + The identifier of the newly created transaction + will be returned in the first response on the + stream. This defaults to a read-only + transaction. + + This field is a member of `oneof`_ ``consistency_selector``. + read_time (google.protobuf.timestamp_pb2.Timestamp): + Execute the pipeline in a snapshot + transaction at the given time. + This must be a microsecond precision timestamp + within the past one hour, or if Point-in-Time + Recovery is enabled, can additionally be a whole + minute timestamp within the past 7 days. + + This field is a member of `oneof`_ ``consistency_selector``. + """ + + database: str = proto.Field( + proto.STRING, + number=1, + ) + structured_pipeline: pipeline.StructuredPipeline = proto.Field( + proto.MESSAGE, + number=2, + oneof="pipeline_type", + message=pipeline.StructuredPipeline, + ) + transaction: bytes = proto.Field( + proto.BYTES, + number=5, + oneof="consistency_selector", + ) + new_transaction: common.TransactionOptions = proto.Field( + proto.MESSAGE, + number=6, + oneof="consistency_selector", + message=common.TransactionOptions, + ) + read_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=7, + oneof="consistency_selector", + message=timestamp_pb2.Timestamp, + ) + + +class ExecutePipelineResponse(proto.Message): + r"""The response for [Firestore.Execute][]. + + Attributes: + transaction (bytes): + Newly created transaction identifier. + + This field is only specified as part of the first response + from the server, alongside the ``results`` field when the + original request specified + [ExecuteRequest.new_transaction][]. + results (MutableSequence[google.cloud.firestore_v1.types.Document]): + An ordered batch of results returned executing a pipeline. + + The batch size is variable, and can even be zero for when + only a partial progress message is returned. + + The fields present in the returned documents are only those + that were explicitly requested in the pipeline, this + includes those like + [``__name__``][google.firestore.v1.Document.name] and + [``__update_time__``][google.firestore.v1.Document.update_time]. + This is explicitly a divergence from ``Firestore.RunQuery`` + / ``Firestore.GetDocument`` RPCs which always return such + fields even when they are not specified in the + [``mask``][google.firestore.v1.DocumentMask]. + execution_time (google.protobuf.timestamp_pb2.Timestamp): + The time at which the results are valid. + + This is a (not strictly) monotonically increasing value + across multiple responses in the same stream. The API + guarantees that all previously returned results are still + valid at the latest ``execution_time``. 
+            consumer to treat the query as if it ran at the latest
+            ``execution_time`` returned.
+
+            If the query returns no results, a response with
+            ``execution_time`` and no ``results`` will be sent, and this
+            represents the time at which the operation was run.
+        explain_stats (google.cloud.firestore_v1.types.ExplainStats):
+            Query explain stats.
+
+            This is present on the **last** response if the request
+            configured explain to run in 'analyze' or 'explain' mode in
+            the pipeline options. If the query does not return any
+            results, a response with ``explain_stats`` and no
+            ``results`` will still be sent.
+    """
+
+    transaction: bytes = proto.Field(
+        proto.BYTES,
+        number=1,
+    )
+    results: MutableSequence[gf_document.Document] = proto.RepeatedField(
+        proto.MESSAGE,
+        number=2,
+        message=gf_document.Document,
+    )
+    execution_time: timestamp_pb2.Timestamp = proto.Field(
+        proto.MESSAGE,
+        number=3,
+        message=timestamp_pb2.Timestamp,
+    )
+    explain_stats: gf_explain_stats.ExplainStats = proto.Field(
+        proto.MESSAGE,
+        number=4,
+        message=gf_explain_stats.ExplainStats,
+    )
+
+
 class RunAggregationQueryRequest(proto.Message):
     r"""The request for
     [Firestore.RunAggregationQuery][google.firestore.v1.Firestore.RunAggregationQuery].
@@ -1017,8 +1166,8 @@ class PartitionQueryRequest(proto.Message):
             For example, two subsequent calls using a page_token may
             return:
 
-            -  cursor B, cursor M, cursor Q
-            -  cursor A, cursor U, cursor W
+            - cursor B, cursor M, cursor Q
+            - cursor A, cursor U, cursor W
 
             To obtain a complete result set ordered with respect to the
             results of the query supplied to PartitionQuery, the results
@@ -1092,9 +1241,9 @@ class PartitionQueryResponse(proto.Message):
             cursors A and B, running the following three queries will
             return the entire result set of the original query:
 
-            -  query, end_at A
-            -  query, start_at A, end_at B
-            -  query, start_at B
+            - query, end_at A
+            - query, start_at A, end_at B
+            - query, start_at B
 
             An empty result may indicate that the query has too few
             results to be partitioned, or that the query is not yet
@@ -1416,9 +1565,9 @@ class Target(proto.Message):
 
             Note that if the client sends multiple ``AddTarget``
             requests without an ID, the order of IDs returned in
-            ``TargetChage.target_ids`` are undefined. Therefore, clients
-            should provide a target ID instead of relying on the server
-            to assign one.
+            ``TargetChange.target_ids`` is undefined. Therefore,
+            clients should provide a target ID instead of relying on
+            the server to assign one.
 
             If ``target_id`` is non-zero, there must not be an existing
             active target on this stream with the same ID.
diff --git a/google/cloud/firestore_v1/types/pipeline.py b/google/cloud/firestore_v1/types/pipeline.py
new file mode 100644
index 0000000000..07688dda72
--- /dev/null
+++ b/google/cloud/firestore_v1/types/pipeline.py
@@ -0,0 +1,61 @@
+# -*- coding: utf-8 -*-
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
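[Reviewer note, not part of the patch: a minimal sketch of how the new ExecutePipeline types above might fit together, assuming the generated client eventually exposes a matching ``execute_pipeline`` method (not shown in this section). The database path, the ``collection`` stage name, and the field values are illustrative placeholders only; ``Pipeline`` and its nested ``Stage`` are assumed to live in types/document.py, as the field definitions suggest.]

    from google.cloud.firestore_v1.types import document, firestore, pipeline

    # Hypothetical single-stage pipeline selecting a collection.
    stage = document.Pipeline.Stage(
        name="collection",  # assumed stage name, for illustration
        args=[document.Value(reference_value="/users")],
    )
    request = firestore.ExecutePipelineRequest(
        database="projects/my-project/databases/(default)",
        structured_pipeline=pipeline.StructuredPipeline(
            pipeline=document.Pipeline(stages=[stage]),
        ),
        # At most one consistency_selector member (transaction,
        # new_transaction, or read_time) may be set; setting one
        # clears the others.
    )
    # The server streams ExecutePipelineResponse messages: batches of
    # Documents plus an execution_time, with explain_stats arriving on
    # the last response when explain was requested in pipeline options.

[End of reviewer note; the generated pipeline.py module follows.]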
+# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.cloud.firestore_v1.types import document + + +__protobuf__ = proto.module( + package="google.firestore.v1", + manifest={ + "StructuredPipeline", + }, +) + + +class StructuredPipeline(proto.Message): + r"""A Firestore query represented as an ordered list of operations / + stages. + + This is considered the top-level function which plans and executes a + query. It is logically equivalent to ``query(stages, options)``, but + prevents the client from having to build a function wrapper. + + Attributes: + pipeline (google.cloud.firestore_v1.types.Pipeline): + Required. The pipeline query to execute. + options (MutableMapping[str, google.cloud.firestore_v1.types.Value]): + Optional. Optional query-level arguments. + """ + + pipeline: document.Pipeline = proto.Field( + proto.MESSAGE, + number=1, + message=document.Pipeline, + ) + options: MutableMapping[str, document.Value] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=2, + message=document.Value, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/firestore_v1/types/query.py b/google/cloud/firestore_v1/types/query.py index 9aa8977ddb..d507427858 100644 --- a/google/cloud/firestore_v1/types/query.py +++ b/google/cloud/firestore_v1/types/query.py @@ -66,25 +66,25 @@ class StructuredQuery(proto.Message): Firestore guarantees a stable ordering through the following rules: - - The ``order_by`` is required to reference all fields used - with an inequality filter. - - All fields that are required to be in the ``order_by`` - but are not already present are appended in - lexicographical ordering of the field name. - - If an order on ``__name__`` is not specified, it is - appended by default. + - The ``order_by`` is required to reference all fields used + with an inequality filter. + - All fields that are required to be in the ``order_by`` but + are not already present are appended in lexicographical + ordering of the field name. + - If an order on ``__name__`` is not specified, it is + appended by default. Fields are appended with the same sort direction as the last order specified, or 'ASCENDING' if no order was specified. For example: - - ``ORDER BY a`` becomes ``ORDER BY a ASC, __name__ ASC`` - - ``ORDER BY a DESC`` becomes - ``ORDER BY a DESC, __name__ DESC`` - - ``WHERE a > 1`` becomes - ``WHERE a > 1 ORDER BY a ASC, __name__ ASC`` - - ``WHERE __name__ > ... AND a > 1`` becomes - ``WHERE __name__ > ... AND a > 1 ORDER BY a ASC, __name__ ASC`` + - ``ORDER BY a`` becomes ``ORDER BY a ASC, __name__ ASC`` + - ``ORDER BY a DESC`` becomes + ``ORDER BY a DESC, __name__ DESC`` + - ``WHERE a > 1`` becomes + ``WHERE a > 1 ORDER BY a ASC, __name__ ASC`` + - ``WHERE __name__ > ... AND a > 1`` becomes + ``WHERE __name__ > ... AND a > 1 ORDER BY a ASC, __name__ ASC`` start_at (google.cloud.firestore_v1.types.Cursor): A potential prefix of a position in the result set to start the query at. @@ -106,10 +106,10 @@ class StructuredQuery(proto.Message): Continuing off the example above, attaching the following start cursors will have varying impact: - - ``START BEFORE (2, /k/123)``: start the query right - before ``a = 1 AND b > 2 AND __name__ > /k/123``. - - ``START AFTER (10)``: start the query right after - ``a = 1 AND b > 10``. + - ``START BEFORE (2, /k/123)``: start the query right before + ``a = 1 AND b > 2 AND __name__ > /k/123``. 
+ - ``START AFTER (10)``: start the query right after + ``a = 1 AND b > 10``. Unlike ``OFFSET`` which requires scanning over the first N results to skip, a start cursor allows the query to begin at @@ -119,8 +119,8 @@ class StructuredQuery(proto.Message): Requires: - - The number of values cannot be greater than the number of - fields specified in the ``ORDER BY`` clause. + - The number of values cannot be greater than the number of + fields specified in the ``ORDER BY`` clause. end_at (google.cloud.firestore_v1.types.Cursor): A potential prefix of a position in the result set to end the query at. @@ -130,8 +130,8 @@ class StructuredQuery(proto.Message): Requires: - - The number of values cannot be greater than the number of - fields specified in the ``ORDER BY`` clause. + - The number of values cannot be greater than the number of + fields specified in the ``ORDER BY`` clause. offset (int): The number of documents to skip before returning the first result. @@ -142,8 +142,8 @@ class StructuredQuery(proto.Message): Requires: - - The value must be greater than or equal to zero if - specified. + - The value must be greater than or equal to zero if + specified. limit (google.protobuf.wrappers_pb2.Int32Value): The maximum number of results to return. @@ -151,8 +151,8 @@ class StructuredQuery(proto.Message): Requires: - - The value must be greater than or equal to zero if - specified. + - The value must be greater than or equal to zero if + specified. find_nearest (google.cloud.firestore_v1.types.StructuredQuery.FindNearest): Optional. A potential nearest neighbors search. @@ -256,7 +256,7 @@ class CompositeFilter(proto.Message): Requires: - - At least one filter is present. + - At least one filter is present. """ class Operator(proto.Enum): @@ -310,27 +310,27 @@ class Operator(proto.Enum): Requires: - - That ``field`` come first in ``order_by``. + - That ``field`` come first in ``order_by``. LESS_THAN_OR_EQUAL (2): The given ``field`` is less than or equal to the given ``value``. Requires: - - That ``field`` come first in ``order_by``. + - That ``field`` come first in ``order_by``. GREATER_THAN (3): The given ``field`` is greater than the given ``value``. Requires: - - That ``field`` come first in ``order_by``. + - That ``field`` come first in ``order_by``. GREATER_THAN_OR_EQUAL (4): The given ``field`` is greater than or equal to the given ``value``. Requires: - - That ``field`` come first in ``order_by``. + - That ``field`` come first in ``order_by``. EQUAL (5): The given ``field`` is equal to the given ``value``. NOT_EQUAL (6): @@ -338,9 +338,9 @@ class Operator(proto.Enum): Requires: - - No other ``NOT_EQUAL``, ``NOT_IN``, ``IS_NOT_NULL``, or - ``IS_NOT_NAN``. - - That ``field`` comes first in the ``order_by``. + - No other ``NOT_EQUAL``, ``NOT_IN``, ``IS_NOT_NULL``, or + ``IS_NOT_NAN``. + - That ``field`` comes first in the ``order_by``. ARRAY_CONTAINS (7): The given ``field`` is an array that contains the given ``value``. @@ -350,31 +350,31 @@ class Operator(proto.Enum): Requires: - - That ``value`` is a non-empty ``ArrayValue``, subject to - disjunction limits. - - No ``NOT_IN`` filters in the same query. + - That ``value`` is a non-empty ``ArrayValue``, subject to + disjunction limits. + - No ``NOT_IN`` filters in the same query. ARRAY_CONTAINS_ANY (9): The given ``field`` is an array that contains any of the values in the given array. Requires: - - That ``value`` is a non-empty ``ArrayValue``, subject to - disjunction limits. 
- - No other ``ARRAY_CONTAINS_ANY`` filters within the same - disjunction. - - No ``NOT_IN`` filters in the same query. + - That ``value`` is a non-empty ``ArrayValue``, subject to + disjunction limits. + - No other ``ARRAY_CONTAINS_ANY`` filters within the same + disjunction. + - No ``NOT_IN`` filters in the same query. NOT_IN (10): The value of the ``field`` is not in the given array. Requires: - - That ``value`` is a non-empty ``ArrayValue`` with at most - 10 values. - - No other ``OR``, ``IN``, ``ARRAY_CONTAINS_ANY``, - ``NOT_IN``, ``NOT_EQUAL``, ``IS_NOT_NULL``, or - ``IS_NOT_NAN``. - - That ``field`` comes first in the ``order_by``. + - That ``value`` is a non-empty ``ArrayValue`` with at most + 10 values. + - No other ``OR``, ``IN``, ``ARRAY_CONTAINS_ANY``, + ``NOT_IN``, ``NOT_EQUAL``, ``IS_NOT_NULL``, or + ``IS_NOT_NAN``. + - That ``field`` comes first in the ``order_by``. """ OPERATOR_UNSPECIFIED = 0 LESS_THAN = 1 @@ -433,17 +433,17 @@ class Operator(proto.Enum): Requires: - - No other ``NOT_EQUAL``, ``NOT_IN``, ``IS_NOT_NULL``, or - ``IS_NOT_NAN``. - - That ``field`` comes first in the ``order_by``. + - No other ``NOT_EQUAL``, ``NOT_IN``, ``IS_NOT_NULL``, or + ``IS_NOT_NAN``. + - That ``field`` comes first in the ``order_by``. IS_NOT_NULL (5): The given ``field`` is not equal to ``NULL``. Requires: - - A single ``NOT_EQUAL``, ``NOT_IN``, ``IS_NOT_NULL``, or - ``IS_NOT_NAN``. - - That ``field`` comes first in the ``order_by``. + - A single ``NOT_EQUAL``, ``NOT_IN``, ``IS_NOT_NULL``, or + ``IS_NOT_NAN``. + - That ``field`` comes first in the ``order_by``. """ OPERATOR_UNSPECIFIED = 0 IS_NAN = 2 @@ -493,9 +493,9 @@ class FieldReference(proto.Message): Requires: - - MUST be a dot-delimited (``.``) string of segments, where - each segment conforms to [document field - name][google.firestore.v1.Document.fields] limitations. + - MUST be a dot-delimited (``.``) string of segments, where + each segment conforms to [document field + name][google.firestore.v1.Document.fields] limitations. """ field_path: str = proto.Field( @@ -555,9 +555,9 @@ class FindNearest(proto.Message): when the vectors are more similar, the comparison is inverted. - - For EUCLIDEAN, COSINE: WHERE distance <= - distance_threshold - - For DOT_PRODUCT: WHERE distance >= distance_threshold + - For EUCLIDEAN, COSINE: + ``WHERE distance <= distance_threshold`` + - For DOT_PRODUCT: ``WHERE distance >= distance_threshold`` """ class DistanceMeasure(proto.Enum): @@ -688,8 +688,8 @@ class StructuredAggregationQuery(proto.Message): Requires: - - A minimum of one and maximum of five aggregations per - query. + - A minimum of one and maximum of five aggregations per + query. """ class Aggregation(proto.Message): @@ -749,9 +749,9 @@ class Aggregation(proto.Message): Requires: - - Must be unique across all aggregation aliases. - - Conform to [document field - name][google.firestore.v1.Document.fields] limitations. + - Must be unique across all aggregation aliases. + - Conform to [document field + name][google.firestore.v1.Document.fields] limitations. """ class Count(proto.Message): @@ -778,7 +778,7 @@ class Count(proto.Message): Requires: - - Must be greater than zero when present. + - Must be greater than zero when present. """ up_to: wrappers_pb2.Int64Value = proto.Field( @@ -790,26 +790,26 @@ class Count(proto.Message): class Sum(proto.Message): r"""Sum of the values of the requested field. - - Only numeric values will be aggregated. All non-numeric values - including ``NULL`` are skipped. 
+ - Only numeric values will be aggregated. All non-numeric values + including ``NULL`` are skipped. - - If the aggregated values contain ``NaN``, returns ``NaN``. - Infinity math follows IEEE-754 standards. + - If the aggregated values contain ``NaN``, returns ``NaN``. + Infinity math follows IEEE-754 standards. - - If the aggregated value set is empty, returns 0. + - If the aggregated value set is empty, returns 0. - - Returns a 64-bit integer if all aggregated numbers are integers - and the sum result does not overflow. Otherwise, the result is - returned as a double. Note that even if all the aggregated values - are integers, the result is returned as a double if it cannot fit - within a 64-bit signed integer. When this occurs, the returned - value will lose precision. + - Returns a 64-bit integer if all aggregated numbers are integers + and the sum result does not overflow. Otherwise, the result is + returned as a double. Note that even if all the aggregated values + are integers, the result is returned as a double if it cannot fit + within a 64-bit signed integer. When this occurs, the returned + value will lose precision. - - When underflow occurs, floating-point aggregation is - non-deterministic. This means that running the same query - repeatedly without any changes to the underlying values could - produce slightly different results each time. In those cases, - values should be stored as integers over floating-point numbers. + - When underflow occurs, floating-point aggregation is + non-deterministic. This means that running the same query + repeatedly without any changes to the underlying values could + produce slightly different results each time. In those cases, + values should be stored as integers over floating-point numbers. Attributes: field (google.cloud.firestore_v1.types.StructuredQuery.FieldReference): @@ -825,15 +825,15 @@ class Sum(proto.Message): class Avg(proto.Message): r"""Average of the values of the requested field. - - Only numeric values will be aggregated. All non-numeric values - including ``NULL`` are skipped. + - Only numeric values will be aggregated. All non-numeric values + including ``NULL`` are skipped. - - If the aggregated values contain ``NaN``, returns ``NaN``. - Infinity math follows IEEE-754 standards. + - If the aggregated values contain ``NaN``, returns ``NaN``. + Infinity math follows IEEE-754 standards. - - If the aggregated value set is empty, returns ``NULL``. + - If the aggregated value set is empty, returns ``NULL``. - - Always returns the result as a double. + - Always returns the result as a double. Attributes: field (google.cloud.firestore_v1.types.StructuredQuery.FieldReference): diff --git a/google/cloud/firestore_v1/watch.py b/google/cloud/firestore_v1/watch.py index 79933aecae..9714856559 100644 --- a/google/cloud/firestore_v1/watch.py +++ b/google/cloud/firestore_v1/watch.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations import collections import functools @@ -232,7 +233,7 @@ def __init__( def _init_stream(self): rpc_request = self._get_rpc_request - self._rpc = ResumableBidiRpc( + self._rpc: ResumableBidiRpc | None = ResumableBidiRpc( start_rpc=self._api._transport.listen, should_recover=_should_recover, should_terminate=_should_terminate, @@ -243,7 +244,9 @@ def _init_stream(self): self._rpc.add_done_callback(self._on_rpc_done) # The server assigns and updates the resume token. - self._consumer = BackgroundConsumer(self._rpc, self.on_snapshot) + self._consumer: BackgroundConsumer | None = BackgroundConsumer( + self._rpc, self.on_snapshot + ) self._consumer.start() @classmethod @@ -330,16 +333,18 @@ def close(self, reason=None): return # Stop consuming messages. - if self.is_active: - _LOGGER.debug("Stopping consumer.") - self._consumer.stop() - self._consumer._on_response = None + if self._consumer: + if self.is_active: + _LOGGER.debug("Stopping consumer.") + self._consumer.stop() + self._consumer._on_response = None self._consumer = None self._snapshot_callback = None - self._rpc.close() - self._rpc._initial_request = None - self._rpc._callbacks = [] + if self._rpc: + self._rpc.close() + self._rpc._initial_request = None + self._rpc._callbacks = [] self._rpc = None self._closed = True _LOGGER.debug("Finished stopping manager.") @@ -460,13 +465,13 @@ def on_snapshot(self, proto): message = f"Unknown target change type: {target_change_type}" _LOGGER.info(f"on_snapshot: {message}") self.close(reason=ValueError(message)) - - try: - # Use 'proto' vs 'pb' for datetime handling - meth(self, proto.target_change) - except Exception as exc2: - _LOGGER.debug(f"meth(proto) exc: {exc2}") - raise + else: + try: + # Use 'proto' vs 'pb' for datetime handling + meth(self, proto.target_change) + except Exception as exc2: + _LOGGER.debug(f"meth(proto) exc: {exc2}") + raise # NOTE: # in other implementations, such as node, the backoff is reset here diff --git a/librarian.py b/librarian.py new file mode 100644 index 0000000000..ec92a93451 --- /dev/null +++ b/librarian.py @@ -0,0 +1,118 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This script is used to synthesize generated parts of this library.""" +from pathlib import Path +from typing import List, Optional + +import synthtool as s +from synthtool import gcp +from synthtool.languages import python + +common = gcp.CommonTemplates() + +# This library ships clients for 3 different APIs, +# firestore, firestore_admin and firestore_bundle. 
+# firestore_bundle is not versioned +firestore_default_version = "v1" +firestore_admin_default_version = "v1" + +def update_fixup_scripts(path): + # Add message for missing 'libcst' dependency + s.replace( + library / "scripts" / path, + """import libcst as cst""", + """try: + import libcst as cst +except ImportError: + raise ImportError('Run `python -m pip install "libcst >= 0.2.5"` to install libcst.') + + + """, + ) + +for library in s.get_staging_dirs(default_version=firestore_default_version): + s.move(library / f"google/cloud/firestore_{library.name}", excludes=[f"__init__.py", "noxfile.py"]) + s.move(library / f"tests/", f"tests") + fixup_script_path = "fixup_firestore_v1_keywords.py" + update_fixup_scripts(fixup_script_path) + s.move(library / "scripts" / fixup_script_path) + +for library in s.get_staging_dirs(default_version=firestore_admin_default_version): + s.move(library / f"google/cloud/firestore_admin_{library.name}", excludes=[f"__init__.py", "noxfile.py"]) + s.move(library / f"tests", f"tests") + fixup_script_path = "fixup_firestore_admin_v1_keywords.py" + update_fixup_scripts(fixup_script_path) + s.move(library / "scripts" / fixup_script_path) + +for library in s.get_staging_dirs(): + s.replace( + library / "google/cloud/bundle/types/bundle.py", + "from google.firestore.v1 import document_pb2 # type: ignore\n" + "from google.firestore.v1 import query_pb2 # type: ignore", + "from google.cloud.firestore_v1.types import document as document_pb2 # type: ignore\n" + "from google.cloud.firestore_v1.types import query as query_pb2 # type: ignore" + ) + + s.replace( + library / "google/cloud/bundle/__init__.py", + "from .types.bundle import BundleMetadata\n" + "from .types.bundle import NamedQuery\n", + "from .types.bundle import BundleMetadata\n" + "from .types.bundle import NamedQuery\n" + "\n" + "from .bundle import FirestoreBundle\n", + ) + + s.replace( + library / "google/cloud/bundle/__init__.py", + "from google.cloud.bundle import gapic_version as package_version\n", + "from google.cloud.firestore_bundle import gapic_version as package_version\n", + ) + + s.replace( + library / "google/cloud/bundle/__init__.py", + "\'BundledQuery\',", + "\"BundledQuery\",\n\"FirestoreBundle\",",) + + s.move( + library / f"google/cloud/bundle", + f"google/cloud/firestore_bundle", + excludes=["noxfile.py"], + ) + s.move(library / f"tests", f"tests") + +s.remove_staging_dirs() + +# ---------------------------------------------------------------------------- +# Add templated files +# ---------------------------------------------------------------------------- +templated_files = common.py_library( + samples=False, # set to True only if there are samples + unit_test_external_dependencies=["aiounittest", "six", "freezegun"], + system_test_external_dependencies=["pytest-asyncio", "six"], + microgenerator=True, + cov_level=100, + split_system_tests=True, + default_python_version="3.14", + system_test_python_versions=["3.14"], + unit_test_python_versions=["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13", "3.14"], +) + +s.move(templated_files, + excludes=[".github/**", ".kokoro/**", "renovate.json"]) + +python.py_samples(skip_readmes=True) + +s.shell.run(["nox", "-s", "blacken"], hide_output=False) diff --git a/mypy.ini b/mypy.ini index 4505b48543..59a6e4d37a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,3 +1,7 @@ [mypy] -python_version = 3.6 +python_version = 3.13 namespace_packages = True + +# ignore gapic files +[mypy-google.cloud.firestore_v1.services.*] +ignore_errors = True \ No newline at end 
of file diff --git a/noxfile.py b/noxfile.py index 7ef3ed5b88..3e8b807701 100644 --- a/noxfile.py +++ b/noxfile.py @@ -14,6 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +# DO NOT EDIT THIS FILE OUTSIDE OF `.librarian/generator-input` +# The source of truth for this file is `.librarian/generator-input` + + # Generated by synthtool. DO NOT EDIT! from __future__ import absolute_import @@ -33,7 +37,7 @@ ISORT_VERSION = "isort==5.11.0" LINT_PATHS = ["docs", "google", "tests", "noxfile.py", "setup.py"] -DEFAULT_PYTHON_VERSION = "3.8" +DEFAULT_PYTHON_VERSION = "3.14" UNIT_TEST_PYTHON_VERSIONS: List[str] = [ "3.7", @@ -43,6 +47,7 @@ "3.11", "3.12", "3.13", + "3.14", ] UNIT_TEST_STANDARD_DEPENDENCIES = [ "mock", @@ -61,7 +66,7 @@ UNIT_TEST_EXTRAS: List[str] = [] UNIT_TEST_EXTRAS_BY_PYTHON: Dict[str, List[str]] = {} -SYSTEM_TEST_PYTHON_VERSIONS: List[str] = ["3.7"] +SYSTEM_TEST_PYTHON_VERSIONS: List[str] = ["3.14"] SYSTEM_TEST_STANDARD_DEPENDENCIES: List[str] = [ "mock", "pytest", @@ -79,7 +84,12 @@ CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() nox.options.sessions = [ - "unit", + "unit-3.9", + "unit-3.10", + "unit-3.11", + "unit-3.12", + "unit-3.13", + "unit-3.14", "system_emulated", "system", "mypy", @@ -155,15 +165,22 @@ def pytype(session): def mypy(session): """Verify type hints are mypy compatible.""" session.install("-e", ".") - session.install("mypy", "types-setuptools") - # TODO: also verify types on tests, all of google package - session.run("mypy", "-p", "google.cloud.firestore", "--no-incremental") + session.install("mypy", "types-setuptools", "types-protobuf") + session.run( + "mypy", + "-p", + "google.cloud.firestore_v1", + "--no-incremental", + "--check-untyped-defs", + "--exclude", + "services", + ) @nox.session(python=DEFAULT_PYTHON_VERSION) def lint_setup_py(session): """Verify that setup.py is valid (including RST check).""" - session.install("docutils", "pygments") + session.install("setuptools", "docutils", "pygments") session.run("python", "setup.py", "check", "--restructuredtext", "--strict") @@ -203,7 +220,8 @@ def install_unittest_dependencies(session, *constraints): def unit(session, protobuf_implementation): # Install all test dependencies, then install this package in-place. - if protobuf_implementation == "cpp" and session.python in ("3.11", "3.12", "3.13"): + py_version = tuple([int(v) for v in session.python.split(".")]) + if protobuf_implementation == "cpp" and py_version >= (3, 11): session.skip("cpp implementation is not supported in python 3.11+") constraints_path = str( @@ -368,7 +386,13 @@ def cover(session): test runs (not system test runs), and then erases coverage data. 
""" session.install("coverage", "pytest-cov") - session.run("coverage", "report", "--show-missing", "--fail-under=100") + session.run( + "coverage", + "report", + "--show-missing", + "--fail-under=100", + "--omit=tests/*", + ) session.run("coverage", "erase") @@ -454,7 +478,7 @@ def docfx(session): ) -@nox.session(python="3.13") +@nox.session(python=DEFAULT_PYTHON_VERSION) @nox.parametrize( "protobuf_implementation", ["python", "upb", "cpp"], @@ -462,7 +486,8 @@ def docfx(session): def prerelease_deps(session, protobuf_implementation): """Run all tests with prerelease versions of dependencies installed.""" - if protobuf_implementation == "cpp" and session.python in ("3.11", "3.12", "3.13"): + py_version = tuple([int(v) for v in session.python.split(".")]) + if protobuf_implementation == "cpp" and py_version >= (3, 11): session.skip("cpp implementation is not supported in python 3.11+") # Install all dependencies diff --git a/pytest.ini b/pytest.ini index 099cbd3ad2..308d1b494d 100644 --- a/pytest.ini +++ b/pytest.ini @@ -18,3 +18,9 @@ filterwarnings = ignore:After January 1, 2024, new releases of this library will drop support for Python 3.7:DeprecationWarning # Remove warning once https://github.com/googleapis/gapic-generator-python/issues/1939 is fixed ignore:get_mtls_endpoint_and_cert_source is deprecated.:DeprecationWarning + # Remove once credential file support is removed + ignore:.*The \`credentials_file\` argument is deprecated.*:DeprecationWarning + # Remove after updating test dependencies that use asyncio.iscoroutinefunction + ignore:.*\'asyncio.iscoroutinefunction\' is deprecated.*:DeprecationWarning + ignore:.*\'asyncio.get_event_loop_policy\' is deprecated.*:DeprecationWarning + ignore:.*Please upgrade to the latest Python version.*:FutureWarning diff --git a/scripts/fixup_firestore_admin_v1_keywords.py b/scripts/fixup_firestore_admin_v1_keywords.py deleted file mode 100644 index f4672d2da5..0000000000 --- a/scripts/fixup_firestore_admin_v1_keywords.py +++ /dev/null @@ -1,212 +0,0 @@ -#! /usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import argparse -import os -try: - import libcst as cst -except ImportError: - raise ImportError('Run `python -m pip install "libcst >= 0.2.5"` to install libcst.') - - - -import pathlib -import sys -from typing import (Any, Callable, Dict, List, Sequence, Tuple) - - -def partition( - predicate: Callable[[Any], bool], - iterator: Sequence[Any] -) -> Tuple[List[Any], List[Any]]: - """A stable, out-of-place partition.""" - results = ([], []) - - for i in iterator: - results[int(predicate(i))].append(i) - - # Returns trueList, falseList - return results[1], results[0] - - -class firestore_adminCallTransformer(cst.CSTTransformer): - CTRL_PARAMS: Tuple[str] = ('retry', 'timeout', 'metadata') - METHOD_TO_PARAMS: Dict[str, Tuple[str]] = { - 'bulk_delete_documents': ('name', 'collection_ids', 'namespace_ids', ), - 'create_backup_schedule': ('parent', 'backup_schedule', ), - 'create_database': ('parent', 'database', 'database_id', ), - 'create_index': ('parent', 'index', ), - 'create_user_creds': ('parent', 'user_creds', 'user_creds_id', ), - 'delete_backup': ('name', ), - 'delete_backup_schedule': ('name', ), - 'delete_database': ('name', 'etag', ), - 'delete_index': ('name', ), - 'delete_user_creds': ('name', ), - 'disable_user_creds': ('name', ), - 'enable_user_creds': ('name', ), - 'export_documents': ('name', 'collection_ids', 'output_uri_prefix', 'namespace_ids', 'snapshot_time', ), - 'get_backup': ('name', ), - 'get_backup_schedule': ('name', ), - 'get_database': ('name', ), - 'get_field': ('name', ), - 'get_index': ('name', ), - 'get_user_creds': ('name', ), - 'import_documents': ('name', 'collection_ids', 'input_uri_prefix', 'namespace_ids', ), - 'list_backups': ('parent', 'filter', ), - 'list_backup_schedules': ('parent', ), - 'list_databases': ('parent', 'show_deleted', ), - 'list_fields': ('parent', 'filter', 'page_size', 'page_token', ), - 'list_indexes': ('parent', 'filter', 'page_size', 'page_token', ), - 'list_user_creds': ('parent', ), - 'reset_user_password': ('name', ), - 'restore_database': ('parent', 'database_id', 'backup', 'encryption_config', ), - 'update_backup_schedule': ('backup_schedule', 'update_mask', ), - 'update_database': ('database', 'update_mask', ), - 'update_field': ('field', 'update_mask', ), - } - - def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode: - try: - key = original.func.attr.value - kword_params = self.METHOD_TO_PARAMS[key] - except (AttributeError, KeyError): - # Either not a method from the API or too convoluted to be sure. - return updated - - # If the existing code is valid, keyword args come after positional args. - # Therefore, all positional args must map to the first parameters. - args, kwargs = partition(lambda a: not bool(a.keyword), updated.args) - if any(k.keyword.value == "request" for k in kwargs): - # We've already fixed this file, don't fix it again. - return updated - - kwargs, ctrl_kwargs = partition( - lambda a: a.keyword.value not in self.CTRL_PARAMS, - kwargs - ) - - args, ctrl_args = args[:len(kword_params)], args[len(kword_params):] - ctrl_kwargs.extend(cst.Arg(value=a.value, keyword=cst.Name(value=ctrl)) - for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS)) - - request_arg = cst.Arg( - value=cst.Dict([ - cst.DictElement( - cst.SimpleString("'{}'".format(name)), -cst.Element(value=arg.value) - ) - # Note: the args + kwargs looks silly, but keep in mind that - # the control parameters had to be stripped out, and that - # those could have been passed positionally or by keyword. 
- for name, arg in zip(kword_params, args + kwargs)]), - keyword=cst.Name("request") - ) - - return updated.with_changes( - args=[request_arg] + ctrl_kwargs - ) - - -def fix_files( - in_dir: pathlib.Path, - out_dir: pathlib.Path, - *, - transformer=firestore_adminCallTransformer(), -): - """Duplicate the input dir to the output dir, fixing file method calls. - - Preconditions: - * in_dir is a real directory - * out_dir is a real, empty directory - """ - pyfile_gen = ( - pathlib.Path(os.path.join(root, f)) - for root, _, files in os.walk(in_dir) - for f in files if os.path.splitext(f)[1] == ".py" - ) - - for fpath in pyfile_gen: - with open(fpath, 'r') as f: - src = f.read() - - # Parse the code and insert method call fixes. - tree = cst.parse_module(src) - updated = tree.visit(transformer) - - # Create the path and directory structure for the new file. - updated_path = out_dir.joinpath(fpath.relative_to(in_dir)) - updated_path.parent.mkdir(parents=True, exist_ok=True) - - # Generate the updated source file at the corresponding path. - with open(updated_path, 'w') as f: - f.write(updated.code) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser( - description="""Fix up source that uses the firestore_admin client library. - -The existing sources are NOT overwritten but are copied to output_dir with changes made. - -Note: This tool operates at a best-effort level at converting positional - parameters in client method calls to keyword based parameters. - Cases where it WILL FAIL include - A) * or ** expansion in a method call. - B) Calls via function or method alias (includes free function calls) - C) Indirect or dispatched calls (e.g. the method is looked up dynamically) - - These all constitute false negatives. The tool will also detect false - positives when an API method shares a name with another method. -""") - parser.add_argument( - '-d', - '--input-directory', - required=True, - dest='input_dir', - help='the input directory to walk for python files to fix up', - ) - parser.add_argument( - '-o', - '--output-directory', - required=True, - dest='output_dir', - help='the directory to output files fixed via un-flattening', - ) - args = parser.parse_args() - input_dir = pathlib.Path(args.input_dir) - output_dir = pathlib.Path(args.output_dir) - if not input_dir.is_dir(): - print( - f"input directory '{input_dir}' does not exist or is not a directory", - file=sys.stderr, - ) - sys.exit(-1) - - if not output_dir.is_dir(): - print( - f"output directory '{output_dir}' does not exist or is not a directory", - file=sys.stderr, - ) - sys.exit(-1) - - if os.listdir(output_dir): - print( - f"output directory '{output_dir}' is not empty", - file=sys.stderr, - ) - sys.exit(-1) - - fix_files(input_dir, output_dir) diff --git a/scripts/fixup_firestore_v1_keywords.py b/scripts/fixup_firestore_v1_keywords.py deleted file mode 100644 index 6481e76bb7..0000000000 --- a/scripts/fixup_firestore_v1_keywords.py +++ /dev/null @@ -1,197 +0,0 @@ -#! /usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# -import argparse -import os -try: - import libcst as cst -except ImportError: - raise ImportError('Run `python -m pip install "libcst >= 0.2.5"` to install libcst.') - - - -import pathlib -import sys -from typing import (Any, Callable, Dict, List, Sequence, Tuple) - - -def partition( - predicate: Callable[[Any], bool], - iterator: Sequence[Any] -) -> Tuple[List[Any], List[Any]]: - """A stable, out-of-place partition.""" - results = ([], []) - - for i in iterator: - results[int(predicate(i))].append(i) - - # Returns trueList, falseList - return results[1], results[0] - - -class firestoreCallTransformer(cst.CSTTransformer): - CTRL_PARAMS: Tuple[str] = ('retry', 'timeout', 'metadata') - METHOD_TO_PARAMS: Dict[str, Tuple[str]] = { - 'batch_get_documents': ('database', 'documents', 'mask', 'transaction', 'new_transaction', 'read_time', ), - 'batch_write': ('database', 'writes', 'labels', ), - 'begin_transaction': ('database', 'options', ), - 'commit': ('database', 'writes', 'transaction', ), - 'create_document': ('parent', 'collection_id', 'document', 'document_id', 'mask', ), - 'delete_document': ('name', 'current_document', ), - 'get_document': ('name', 'mask', 'transaction', 'read_time', ), - 'list_collection_ids': ('parent', 'page_size', 'page_token', 'read_time', ), - 'list_documents': ('parent', 'collection_id', 'page_size', 'page_token', 'order_by', 'mask', 'transaction', 'read_time', 'show_missing', ), - 'listen': ('database', 'add_target', 'remove_target', 'labels', ), - 'partition_query': ('parent', 'structured_query', 'partition_count', 'page_token', 'page_size', 'read_time', ), - 'rollback': ('database', 'transaction', ), - 'run_aggregation_query': ('parent', 'structured_aggregation_query', 'transaction', 'new_transaction', 'read_time', 'explain_options', ), - 'run_query': ('parent', 'structured_query', 'transaction', 'new_transaction', 'read_time', 'explain_options', ), - 'update_document': ('document', 'update_mask', 'mask', 'current_document', ), - 'write': ('database', 'stream_id', 'writes', 'stream_token', 'labels', ), - } - - def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode: - try: - key = original.func.attr.value - kword_params = self.METHOD_TO_PARAMS[key] - except (AttributeError, KeyError): - # Either not a method from the API or too convoluted to be sure. - return updated - - # If the existing code is valid, keyword args come after positional args. - # Therefore, all positional args must map to the first parameters. - args, kwargs = partition(lambda a: not bool(a.keyword), updated.args) - if any(k.keyword.value == "request" for k in kwargs): - # We've already fixed this file, don't fix it again. - return updated - - kwargs, ctrl_kwargs = partition( - lambda a: a.keyword.value not in self.CTRL_PARAMS, - kwargs - ) - - args, ctrl_args = args[:len(kword_params)], args[len(kword_params):] - ctrl_kwargs.extend(cst.Arg(value=a.value, keyword=cst.Name(value=ctrl)) - for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS)) - - request_arg = cst.Arg( - value=cst.Dict([ - cst.DictElement( - cst.SimpleString("'{}'".format(name)), -cst.Element(value=arg.value) - ) - # Note: the args + kwargs looks silly, but keep in mind that - # the control parameters had to be stripped out, and that - # those could have been passed positionally or by keyword. 
- for name, arg in zip(kword_params, args + kwargs)]), - keyword=cst.Name("request") - ) - - return updated.with_changes( - args=[request_arg] + ctrl_kwargs - ) - - -def fix_files( - in_dir: pathlib.Path, - out_dir: pathlib.Path, - *, - transformer=firestoreCallTransformer(), -): - """Duplicate the input dir to the output dir, fixing file method calls. - - Preconditions: - * in_dir is a real directory - * out_dir is a real, empty directory - """ - pyfile_gen = ( - pathlib.Path(os.path.join(root, f)) - for root, _, files in os.walk(in_dir) - for f in files if os.path.splitext(f)[1] == ".py" - ) - - for fpath in pyfile_gen: - with open(fpath, 'r') as f: - src = f.read() - - # Parse the code and insert method call fixes. - tree = cst.parse_module(src) - updated = tree.visit(transformer) - - # Create the path and directory structure for the new file. - updated_path = out_dir.joinpath(fpath.relative_to(in_dir)) - updated_path.parent.mkdir(parents=True, exist_ok=True) - - # Generate the updated source file at the corresponding path. - with open(updated_path, 'w') as f: - f.write(updated.code) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser( - description="""Fix up source that uses the firestore client library. - -The existing sources are NOT overwritten but are copied to output_dir with changes made. - -Note: This tool operates at a best-effort level at converting positional - parameters in client method calls to keyword based parameters. - Cases where it WILL FAIL include - A) * or ** expansion in a method call. - B) Calls via function or method alias (includes free function calls) - C) Indirect or dispatched calls (e.g. the method is looked up dynamically) - - These all constitute false negatives. The tool will also detect false - positives when an API method shares a name with another method. -""") - parser.add_argument( - '-d', - '--input-directory', - required=True, - dest='input_dir', - help='the input directory to walk for python files to fix up', - ) - parser.add_argument( - '-o', - '--output-directory', - required=True, - dest='output_dir', - help='the directory to output files fixed via un-flattening', - ) - args = parser.parse_args() - input_dir = pathlib.Path(args.input_dir) - output_dir = pathlib.Path(args.output_dir) - if not input_dir.is_dir(): - print( - f"input directory '{input_dir}' does not exist or is not a directory", - file=sys.stderr, - ) - sys.exit(-1) - - if not output_dir.is_dir(): - print( - f"output directory '{output_dir}' does not exist or is not a directory", - file=sys.stderr, - ) - sys.exit(-1) - - if os.listdir(output_dir): - print( - f"output directory '{output_dir}' is not empty", - file=sys.stderr, - ) - sys.exit(-1) - - fix_files(input_dir, output_dir) diff --git a/setup.py b/setup.py index 2a47080a15..72a6f53bd8 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
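[Reviewer note, not part of the patch: the two fixup scripts deleted above were stand-alone helpers; per their argparse definitions, they walked an existing input directory and wrote keyword-converted copies of each ``.py`` file into an empty output directory. A hypothetical direct use of the deleted ``fix_files`` helper, with placeholder paths:]

    import pathlib

    # Preconditions enforced by the deleted CLI: the input directory
    # exists, and the output directory exists and is empty.
    fix_files(pathlib.Path("./samples_old"), pathlib.Path("./samples_fixed"))

[The setup.py hunk below drops these scripts from ``scripts=[...]``, so they are no longer installed as executables with the package.]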
+# DO NOT EDIT THIS FILE OUTSIDE OF `.librarian/generator-input` +# The source of truth for this file is `.librarian/generator-input` + + import io import os @@ -79,6 +83,8 @@ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Operating System :: OS Independent", "Topic :: Internet", "Topic :: Software Development :: Libraries :: Python Modules", @@ -88,10 +94,6 @@ install_requires=dependencies, extras_require=extras, python_requires=">=3.7", - scripts=[ - "scripts/fixup_firestore_v1_keywords.py", - "scripts/fixup_firestore_admin_v1_keywords.py", - ], include_package_data=True, zip_safe=False, ) diff --git a/tests/unit/gapic/v1/__init__.py b/testing/constraints-3.14.txt similarity index 100% rename from tests/unit/gapic/v1/__init__.py rename to testing/constraints-3.14.txt diff --git a/tests/system/test_system.py b/tests/system/test_system.py index b96ed04715..c66340de1e 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -217,6 +217,39 @@ def test_collection_stream_or_get_w_explain_options_analyze_true( assert len(execution_stats.debug_stats) > 0 +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_collections_w_read_time(client, cleanup, database): + first_collection_id = "doc-create" + UNIQUE_RESOURCE_ID + first_document_id = "doc" + UNIQUE_RESOURCE_ID + first_document = client.document(first_collection_id, first_document_id) + # Add to clean-up before API request (in case ``create()`` fails). + cleanup(first_document.delete) + + data = {"status": "new"} + write_result = first_document.create(data) + read_time = write_result.update_time + num_collections = len(list(client.collections())) + + second_collection_id = "doc-create" + UNIQUE_RESOURCE_ID + "-2" + second_document_id = "doc" + UNIQUE_RESOURCE_ID + "-2" + second_document = client.document(second_collection_id, second_document_id) + cleanup(second_document.delete) + second_document.create(data) + + # Test that listing current collections does have the second id. + curr_collections = list(client.collections()) + assert len(curr_collections) > num_collections + ids = [collection.id for collection in curr_collections] + assert second_collection_id in ids + assert first_collection_id in ids + + # We're just testing that we added one collection at read_time, not two. + collections = list(client.collections(read_time=read_time)) + ids = [collection.id for collection in collections] + assert second_collection_id not in ids + assert first_collection_id in ids + + @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) def test_create_document(client, cleanup, database): now = datetime.datetime.now(tz=datetime.timezone.utc) @@ -708,6 +741,42 @@ def assert_timestamp_less(timestamp_pb1, timestamp_pb2): assert timestamp_pb1 < timestamp_pb2 +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_document_collections_w_read_time(client, cleanup, database): + collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID + document_id = "doc" + UNIQUE_RESOURCE_ID + document = client.document(collection_id, document_id) + # Add to clean-up before API request (in case ``create()`` fails). 
+ cleanup(document.delete) + + data = {"now": firestore.SERVER_TIMESTAMP} + document.create(data) + + original_child_ids = ["child1", "child2"] + read_time = None + + for child_id in original_child_ids: + subcollection = document.collection(child_id) + update_time, subdoc = subcollection.add({"foo": "bar"}) + read_time = ( + update_time if read_time is None or update_time > read_time else read_time + ) + cleanup(subdoc.delete) + + update_time, newdoc = document.collection("child3").add({"foo": "bar"}) + cleanup(newdoc.delete) + assert update_time > read_time + + # Compare the query at read_time to the query at new update time. + original_children = document.collections(read_time=read_time) + assert sorted(child.id for child in original_children) == sorted(original_child_ids) + + original_children = document.collections() + assert sorted(child.id for child in original_children) == sorted( + original_child_ids + ["child3"] + ) + + @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) def test_no_document(client, database): document_id = "no_document" + UNIQUE_RESOURCE_ID @@ -1072,6 +1141,31 @@ def test_collection_add(client, cleanup, database): assert set(collection3.list_documents()) == {document_ref5} +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_list_collections_with_read_time(client, cleanup, database): + # TODO(microgen): list_documents is returning a generator, not a list. + # Consider if this is desired. Also, Document isn't hashable. + collection_id = "coll-add" + UNIQUE_RESOURCE_ID + collection = client.collection(collection_id) + + assert set(collection.list_documents()) == set() + + data1 = {"foo": "bar"} + update_time1, document_ref1 = collection.add(data1) + cleanup(document_ref1.delete) + assert set(collection.list_documents()) == {document_ref1} + + data2 = {"bar": "baz"} + update_time2, document_ref2 = collection.add(data2) + cleanup(document_ref2.delete) + assert set(collection.list_documents()) == {document_ref1, document_ref2} + assert set(collection.list_documents(read_time=update_time1)) == {document_ref1} + assert set(collection.list_documents(read_time=update_time2)) == { + document_ref1, + document_ref2, + } + + @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) def test_unicode_doc(client, cleanup, database): collection_id = "coll-unicode" + UNIQUE_RESOURCE_ID @@ -1477,6 +1571,44 @@ def test_query_stream_or_get_w_explain_options_analyze_false( explain_metrics.execution_stats +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_query_stream_w_read_time(query_docs, cleanup, database): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + + # Find the most recent read_time in collections + read_time = max(docref.get().read_time for docref in collection.list_documents()) + new_data = { + "a": 9000, + "b": 1, + "c": [10000, 1000], + "stats": {"sum": 9001, "product": 9000}, + } + _, new_ref = collection.add(new_data) + # Add to clean-up. + cleanup(new_ref.delete) + stored[new_ref.id] = new_data + + # Compare query at read_time to query at current time. 
+    query = collection.where(filter=FieldFilter("b", "==", 1))
+    values = {
+        snapshot.id: snapshot.to_dict()
+        for snapshot in query.stream(read_time=read_time)
+    }
+    assert len(values) == num_vals
+    assert new_ref.id not in values
+    for key, value in values.items():
+        assert stored[key] == value
+        assert value["b"] == 1
+        assert value["a"] != 9000
+        assert key != new_ref
+
+    new_values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()}
+    assert len(new_values) == num_vals + 1
+    assert new_ref.id in new_values
+    assert new_values[new_ref.id] == new_data
+
+
 @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
 def test_query_with_order_dot_key(client, cleanup, database):
     db = client
@@ -1787,6 +1919,7 @@ def test_get_all(client, cleanup, database):
     document3 = client.document(collection_name, "c")
     # Add to clean-up before API requests (in case ``create()`` fails).
     cleanup(document1.delete)
+    cleanup(document2.delete)
     cleanup(document3.delete)
 
     data1 = {"a": {"b": 2, "c": 3}, "d": 4, "e": 0}
@@ -1794,6 +1927,8 @@ def test_get_all(client, cleanup, database):
     data3 = {"a": {"b": 5, "c": 6}, "d": 7, "e": 100}
     write_result3 = document3.create(data3)
 
+    read_time = write_result3.update_time
+
     # 0. Get 3 unique documents, one of which is missing.
     snapshots = list(client.get_all([document1, document2, document3]))
 
@@ -1829,6 +1964,27 @@ def test_get_all(client, cleanup, database):
     restricted3 = {"a": {"b": data3["a"]["b"]}, "d": data3["d"]}
     check_snapshot(snapshot3, document3, restricted3, write_result3)
 
+    # 3. Use ``read_time`` in ``get_all``
+    new_data = {"a": {"b": 8, "c": 9}, "d": 10, "e": 1010}
+    document1.update(new_data)
+    document2.create(new_data)
+    document3.update(new_data)
+
+    snapshots = list(
+        client.get_all([document1, document2, document3], read_time=read_time)
+    )
+    assert snapshots[0].exists
+    assert snapshots[1].exists
+    assert not snapshots[2].exists
+
+    snapshots = [snapshot for snapshot in snapshots if snapshot.exists]
+    id_attr = operator.attrgetter("id")
+    snapshots.sort(key=id_attr)
+
+    snapshot1, snapshot3 = snapshots
+    check_snapshot(snapshot1, document1, data1, write_result1)
+    check_snapshot(snapshot3, document3, data3, write_result3)
+
 
 @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
 def test_batch(client, cleanup, database):
@@ -3042,6 +3198,48 @@ def test_query_with_or_composite_filter(collection, database):
     assert lt_10 > 0
 
 
+@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize(
+    "aggregation_type,expected_value", [("count", 5), ("sum", 100), ("avg", 4.0)]
+)
+def test_aggregation_queries_with_read_time(
+    collection, query, cleanup, database, aggregation_type, expected_value
+):
+    """
+    Ensure that all aggregation queries work when read_time is passed into
+    an aggregation query's get() method
+    """
+    # Find the most recent read_time in collections
+    read_time = max(docref.get().read_time for docref in collection.list_documents())
+    document_data = {
+        "a": 1,
+        "b": 9000,
+        "c": [1, 123123123],
+        "stats": {"sum": 9001, "product": 9000},
+    }
+
+    _, doc_ref = collection.add(document_data)
+    cleanup(doc_ref.delete)
+
+    if aggregation_type == "count":
+        aggregation_query = query.count()
+    elif aggregation_type == "sum":
+        aggregation_query = collection.sum("stats.product")
+    elif aggregation_type == "avg":
+        aggregation_query = collection.avg("stats.product")
+
+    # Check that adding the new document data affected the results of the aggregation queries.
+    new_result = aggregation_query.get()
+    assert len(new_result) == 1
+    for r in new_result[0]:
+        assert r.value != expected_value
+
+    old_result = aggregation_query.get(read_time=read_time)
+    assert len(old_result) == 1
+    for r in old_result[0]:
+        assert r.value == expected_value
+
+
 @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
 def test_query_with_complex_composite_filter(collection, database):
     field_filter = FieldFilter("b", "==", 0)
@@ -3255,6 +3453,52 @@ def in_transaction(transaction):
     assert inner_fn_ran is True
 
 
+@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+def test_query_in_transaction_with_read_time(client, cleanup, database):
+    """
+    Test queries with read_time in transactions.
+    """
+    collection_id = "doc-create" + UNIQUE_RESOURCE_ID
+    doc_ids = [f"doc{i}" + UNIQUE_RESOURCE_ID for i in range(5)]
+    doc_refs = [client.document(collection_id, doc_id) for doc_id in doc_ids]
+    for doc_ref in doc_refs:
+        cleanup(doc_ref.delete)
+    doc_refs[0].create({"a": 1, "b": 2})
+    doc_refs[1].create({"a": 1, "b": 1})
+
+    read_time = max(docref.get().read_time for docref in doc_refs)
+    doc_refs[2].create({"a": 1, "b": 3})
+
+    collection = client.collection(collection_id)
+    query = collection.where(filter=FieldFilter("a", "==", 1))
+
+    with client.transaction() as transaction:
+        # should work when transaction is initiated through transactional decorator
+        @firestore.transactional
+        def in_transaction(transaction):
+            global inner_fn_ran
+
+            new_b_values = [
+                docs.get("b") for docs in transaction.get(query, read_time=read_time)
+            ]
+            assert len(new_b_values) == 2
+            assert 1 in new_b_values
+            assert 2 in new_b_values
+            assert 3 not in new_b_values
+
+            new_b_values = [docs.get("b") for docs in transaction.get(query)]
+            assert len(new_b_values) == 3
+            assert 1 in new_b_values
+            assert 2 in new_b_values
+            assert 3 in new_b_values
+
+            inner_fn_ran = True
+
+        in_transaction(transaction)
+        # make sure we didn't skip assertions in inner function
+        assert inner_fn_ran is True
+
+
 @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
 def test_update_w_uuid(client, cleanup, database):
     """
diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py
index 200be7d8ab..945e7cb128 100644
--- a/tests/system/test_system_async.py
+++ b/tests/system/test_system_async.py
@@ -234,6 +234,41 @@ async def test_create_document(client, cleanup, database):
     assert stored_data == expected_data
 
 
+@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+async def test_collections_w_read_time(client, cleanup, database):
+    first_collection_id = "doc-create" + UNIQUE_RESOURCE_ID
+    first_document_id = "doc" + UNIQUE_RESOURCE_ID
+    first_document = client.document(first_collection_id, first_document_id)
+    # Add to clean-up before API request (in case ``create()`` fails).
+    cleanup(first_document.delete)
+
+    data = {"status": "new"}
+    write_result = await first_document.create(data)
+    read_time = write_result.update_time
+    num_collections = len([x async for x in client.collections(retry=RETRIES)])
+
+    second_collection_id = "doc-create" + UNIQUE_RESOURCE_ID + "-2"
+    second_document_id = "doc" + UNIQUE_RESOURCE_ID + "-2"
+    second_document = client.document(second_collection_id, second_document_id)
+    cleanup(second_document.delete)
+    await second_document.create(data)
+
+    # Test that listing current collections does have the second id.
+ curr_collections = [x async for x in client.collections(retry=RETRIES)] + assert len(curr_collections) > num_collections + ids = [collection.id for collection in curr_collections] + assert second_collection_id in ids + assert first_collection_id in ids + + # We're just testing that we added one collection at read_time, not two. + collections = [ + x async for x in client.collections(retry=RETRIES, read_time=read_time) + ] + ids = [collection.id for collection in collections] + assert second_collection_id not in ids + assert first_collection_id in ids + + @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_create_document_w_subcollection(client, cleanup, database): collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID @@ -260,6 +295,42 @@ def assert_timestamp_less(timestamp_pb1, timestamp_pb2): assert timestamp_pb1 < timestamp_pb2 +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_document_collections_w_read_time(client, cleanup, database): + collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID + document_id = "doc" + UNIQUE_RESOURCE_ID + document = client.document(collection_id, document_id) + # Add to clean-up before API request (in case ``create()`` fails). + cleanup(document.delete) + + data = {"now": firestore.SERVER_TIMESTAMP} + await document.create(data) + + original_child_ids = ["child1", "child2"] + read_time = None + + for child_id in original_child_ids: + subcollection = document.collection(child_id) + update_time, subdoc = await subcollection.add({"foo": "bar"}) + read_time = ( + update_time if read_time is None or update_time > read_time else read_time + ) + cleanup(subdoc.delete) + + update_time, newdoc = await document.collection("child3").add({"foo": "bar"}) + cleanup(newdoc.delete) + assert update_time > read_time + + # Compare the query at read_time to the query at new update time. + original_children = [doc async for doc in document.collections(read_time=read_time)] + assert sorted(child.id for child in original_children) == sorted(original_child_ids) + + original_children = [doc async for doc in document.collections()] + assert sorted(child.id for child in original_children) == sorted( + original_child_ids + ["child3"] + ) + + @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_no_document(client, database): document_id = "no_document" + UNIQUE_RESOURCE_ID @@ -1062,6 +1133,38 @@ async def test_collection_add(client, cleanup, database): assert set([i async for i in collection3.list_documents()]) == {document_ref5} +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_list_collections_with_read_time(client, cleanup, database): + # TODO(microgen): list_documents is returning a generator, not a list. + # Consider if this is desired. Also, Document isn't hashable. 
+ collection_id = "coll-add" + UNIQUE_RESOURCE_ID + collection = client.collection(collection_id) + + assert set([i async for i in collection.list_documents()]) == set() + + data1 = {"foo": "bar"} + update_time1, document_ref1 = await collection.add(data1) + cleanup(document_ref1.delete) + assert set([i async for i in collection.list_documents()]) == {document_ref1} + + data2 = {"bar": "baz"} + update_time2, document_ref2 = await collection.add(data2) + cleanup(document_ref2.delete) + assert set([i async for i in collection.list_documents()]) == { + document_ref1, + document_ref2, + } + assert set( + [i async for i in collection.list_documents(read_time=update_time1)] + ) == {document_ref1} + assert set( + [i async for i in collection.list_documents(read_time=update_time2)] + ) == { + document_ref1, + document_ref2, + } + + @pytest_asyncio.fixture async def query_docs(client): collection_id = "qs" + UNIQUE_RESOURCE_ID @@ -1389,6 +1492,46 @@ async def test_query_stream_or_get_w_explain_options_analyze_false( _verify_explain_metrics_analyze_false(explain_metrics) +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_query_stream_w_read_time(query_docs, cleanup, database): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + + # Find the most recent read_time in collections + read_time = max( + [(await docref.get()).read_time async for docref in collection.list_documents()] + ) + new_data = { + "a": 9000, + "b": 1, + "c": [10000, 1000], + "stats": {"sum": 9001, "product": 9000}, + } + _, new_ref = await collection.add(new_data) + # Add to clean-up. + cleanup(new_ref.delete) + stored[new_ref.id] = new_data + + # Compare query at read_time to query at current time. + query = collection.where(filter=FieldFilter("b", "==", 1)) + values = { + snapshot.id: snapshot.to_dict() + async for snapshot in query.stream(read_time=read_time) + } + assert len(values) == num_vals + assert new_ref.id not in values + for key, value in values.items(): + assert stored[key] == value + assert value["b"] == 1 + assert value["a"] != 9000 + assert key != new_ref + + new_values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()} + assert len(new_values) == num_vals + 1 + assert new_ref.id in new_values + assert new_values[new_ref.id] == new_data + + @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_query_with_order_dot_key(client, cleanup, database): db = client @@ -1853,6 +1996,8 @@ async def test_get_all(client, cleanup, database): data3 = {"a": {"b": 5, "c": 6}, "d": 7, "e": 100} write_result3 = await document3.create(data3) + read_time = write_result3.update_time + # 0. Get 3 unique documents, one of which is missing. snapshots = [i async for i in client.get_all([document1, document2, document3])] @@ -1891,6 +2036,22 @@ async def test_get_all(client, cleanup, database): restricted3 = {"a": {"b": data3["a"]["b"]}, "d": data3["d"]} check_snapshot(snapshot3, document3, restricted3, write_result3) + # 3. 
Use ``read_time`` in ``get_all``
+    new_data = {"a": {"b": 8, "c": 9}, "d": 10, "e": 1010}
+    await document1.update(new_data)
+    await document2.create(new_data)
+    await document3.update(new_data)
+
+    snapshots = [
+        i
+        async for i in client.get_all(
+            [document1, document2, document3], read_time=read_time
+        )
+    ]
+    assert snapshots[0].exists
+    assert snapshots[1].exists
+    assert not snapshots[2].exists
+
 
 @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
 async def test_live_bulk_writer(client, cleanup, database):
@@ -2765,6 +2926,50 @@ async def test_async_avg_query_stream_w_explain_options_analyze_false(
     explain_metrics.execution_stats
 
 
+@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize(
+    "aggregation_type,expected_value", [("count", 5), ("sum", 100), ("avg", 4.0)]
+)
+async def test_aggregation_queries_with_read_time(
+    collection, async_query, cleanup, database, aggregation_type, expected_value
+):
+    """
+    Ensure that all aggregation queries work when read_time is passed into
+    an aggregation query's get() method.
+    """
+    # Find the most recent read_time in collections
+    read_time = max(
+        [(await docref.get()).read_time async for docref in collection.list_documents()]
+    )
+    document_data = {
+        "a": 1,
+        "b": 9000,
+        "c": [1, 123123123],
+        "stats": {"sum": 9001, "product": 9000},
+    }
+
+    _, doc_ref = await collection.add(document_data)
+    cleanup(doc_ref.delete)
+
+    if aggregation_type == "count":
+        aggregation_query = async_query.count()
+    elif aggregation_type == "sum":
+        aggregation_query = collection.sum("stats.product")
+    elif aggregation_type == "avg":
+        aggregation_query = collection.avg("stats.product")
+
+    # Check that adding the new document data affected the results of the aggregation queries.
+    new_result = await aggregation_query.get()
+    assert len(new_result) == 1
+    for r in new_result[0]:
+        assert r.value != expected_value
+
+    old_result = await aggregation_query.get(read_time=read_time)
+    assert len(old_result) == 1
+    for r in old_result[0]:
+        assert r.value == expected_value
+
+
+@firestore.async_transactional
+async def create_in_transaction_helper(
+    transaction, client, collection_id, cleanup, database
+):
@@ -3176,3 +3381,53 @@ async def in_transaction(transaction):
     # make sure we didn't skip assertions in inner function
     assert inner_fn_ran is True
+
+
+@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+async def test_query_in_transaction_with_read_time(client, cleanup, database):
+    """
+    Test queries with read_time in transactions.
+ """ + collection_id = "doc-create" + UNIQUE_RESOURCE_ID + doc_ids = [f"doc{i}" + UNIQUE_RESOURCE_ID for i in range(5)] + doc_refs = [client.document(collection_id, doc_id) for doc_id in doc_ids] + for doc_ref in doc_refs: + cleanup(doc_ref.delete) + await doc_refs[0].create({"a": 1, "b": 2}) + await doc_refs[1].create({"a": 1, "b": 1}) + + read_time = max([(await docref.get()).read_time for docref in doc_refs]) + await doc_refs[2].create({"a": 1, "b": 3}) + + collection = client.collection(collection_id) + query = collection.where(filter=FieldFilter("a", "==", 1)) + + # should work when transaction is initiated through transactional decorator + async with client.transaction() as transaction: + + @firestore.async_transactional + async def in_transaction(transaction): + global inner_fn_ran + + new_b_values = [ + docs.get("b") + async for docs in await transaction.get(query, read_time=read_time) + ] + assert len(new_b_values) == 2 + assert 1 in new_b_values + assert 2 in new_b_values + assert 3 not in new_b_values + + new_b_values = [ + docs.get("b") async for docs in await transaction.get(query) + ] + assert len(new_b_values) == 3 + assert 1 in new_b_values + assert 2 in new_b_values + assert 3 in new_b_values + + inner_fn_ran = True + + await in_transaction(transaction) + # make sure we didn't skip assertions in inner function + assert inner_fn_ran is True diff --git a/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py b/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py index 421f45a70e..7ef138d151 100644 --- a/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py +++ b/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py @@ -75,6 +75,7 @@ from google.cloud.firestore_admin_v1.types import index as gfa_index from google.cloud.firestore_admin_v1.types import operation as gfa_operation from google.cloud.firestore_admin_v1.types import schedule +from google.cloud.firestore_admin_v1.types import snapshot from google.cloud.firestore_admin_v1.types import user_creds from google.cloud.firestore_admin_v1.types import user_creds as gfa_user_creds from google.cloud.location import locations_pb2 @@ -185,12 +186,19 @@ def test__read_environment_variables(): with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): - with pytest.raises(ValueError) as excinfo: - FirestoreAdminClient._read_environment_variables() - assert ( - str(excinfo.value) - == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with pytest.raises(ValueError) as excinfo: + FirestoreAdminClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + else: + assert FirestoreAdminClient._read_environment_variables() == ( + False, + "auto", + None, + ) with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): assert FirestoreAdminClient._read_environment_variables() == ( @@ -229,6 +237,105 @@ def test__read_environment_variables(): ) +def test_use_client_cert_effective(): + # Test case 1: Test when `should_use_client_cert` returns True. + # We mock the `should_use_client_cert` function to simulate a scenario where + # the google-auth library supports automatic mTLS and determines that a + # client certificate should be used. 
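
Before the twelve cases below play out: the recurring hasattr guard is what lets this suite pass against both older and newer google-auth releases. The client defers to google.auth.transport.mtls.should_use_client_cert when that helper exists, and otherwise parses the environment variable itself. The idiom in isolation, as a sketch; the fallback mirrors the behaviour these tests assert, and the helper is assumed to take no arguments, as the mocks suggest:

    import os

    import google.auth.transport.mtls as mtls

    def use_client_cert() -> bool:
        # Newer google-auth: delegate the decision to the library helper.
        if hasattr(mtls, "should_use_client_cert"):
            return mtls.should_use_client_cert()
        # Older google-auth: parse the env var ourselves, case-insensitively.
        value = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false").lower()
        if value not in ("true", "false"):
            raise ValueError(
                "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` "
                "must be either `true` or `false`"
            )
        return value == "true"
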
+ if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch( + "google.auth.transport.mtls.should_use_client_cert", return_value=True + ): + assert FirestoreAdminClient._use_client_cert_effective() is True + + # Test case 2: Test when `should_use_client_cert` returns False. + # We mock the `should_use_client_cert` function to simulate a scenario where + # the google-auth library supports automatic mTLS and determines that a + # client certificate should NOT be used. + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch( + "google.auth.transport.mtls.should_use_client_cert", return_value=False + ): + assert FirestoreAdminClient._use_client_cert_effective() is False + + # Test case 3: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "true". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + assert FirestoreAdminClient._use_client_cert_effective() is True + + # Test case 4: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "false". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"} + ): + assert FirestoreAdminClient._use_client_cert_effective() is False + + # Test case 5: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "True". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "True"}): + assert FirestoreAdminClient._use_client_cert_effective() is True + + # Test case 6: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "False". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "False"} + ): + assert FirestoreAdminClient._use_client_cert_effective() is False + + # Test case 7: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "TRUE". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "TRUE"}): + assert FirestoreAdminClient._use_client_cert_effective() is True + + # Test case 8: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "FALSE". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "FALSE"} + ): + assert FirestoreAdminClient._use_client_cert_effective() is False + + # Test case 9: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not set. + # In this case, the method should return False, which is the default value. 
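
One testing note while we are in this block: every case patches os.environ through mock.patch.dict rather than mutating it, so each scenario is rolled back automatically. A compact reminder of the two modes used here (DEMO_FLAG is a hypothetical variable):

    import os
    from unittest import mock

    os.environ.pop("DEMO_FLAG", None)

    # Overlay mode: add or override keys for the duration of the block.
    with mock.patch.dict(os.environ, {"DEMO_FLAG": "true"}):
        assert os.environ["DEMO_FLAG"] == "true"
    assert "DEMO_FLAG" not in os.environ  # restored on exit

    # clear=True empties the mapping first -- how "variable unset" is simulated.
    with mock.patch.dict(os.environ, clear=True):
        assert "PATH" not in os.environ
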
+ if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, clear=True): + assert FirestoreAdminClient._use_client_cert_effective() is False + + # Test case 10: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to an invalid value. + # The method should raise a ValueError as the environment variable must be either + # "true" or "false". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "unsupported"} + ): + with pytest.raises(ValueError): + FirestoreAdminClient._use_client_cert_effective() + + # Test case 11: Test when `should_use_client_cert` is available and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to an invalid value. + # The method should return False as the environment variable is set to an invalid value. + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "unsupported"} + ): + assert FirestoreAdminClient._use_client_cert_effective() is False + + # Test case 12: Test when `should_use_client_cert` is available and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is unset. Also, + # the GOOGLE_API_CONFIG environment variable is unset. + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": ""}): + with mock.patch.dict(os.environ, {"GOOGLE_API_CERTIFICATE_CONFIG": ""}): + assert FirestoreAdminClient._use_client_cert_effective() is False + + def test__get_client_cert_source(): mock_provided_cert_source = mock.Mock() mock_default_cert_source = mock.Mock() @@ -594,17 +701,6 @@ def test_firestore_admin_client_client_options( == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" ) - # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): - with pytest.raises(ValueError) as excinfo: - client = client_class(transport=transport_name) - assert ( - str(excinfo.value) - == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: @@ -820,6 +916,119 @@ def test_firestore_admin_client_get_mtls_endpoint_and_cert_source(client_class): assert api_endpoint == mock_api_endpoint assert cert_source is None + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "Unsupported". + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, + api_endpoint=mock_api_endpoint, + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test cases for mTLS enablement when GOOGLE_API_USE_CLIENT_CERTIFICATE is unset. + test_cases = [ + ( + # With workloads present in config, mTLS is enabled. 
+ { + "version": 1, + "cert_configs": { + "workload": { + "cert_path": "path/to/cert/file", + "key_path": "path/to/key/file", + } + }, + }, + mock_client_cert_source, + ), + ( + # With workloads not present in config, mTLS is disabled. + { + "version": 1, + "cert_configs": {}, + }, + None, + ), + ] + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + for config_data, expected_cert_source in test_cases: + env = os.environ.copy() + env.pop("GOOGLE_API_USE_CLIENT_CERTIFICATE", None) + with mock.patch.dict(os.environ, env, clear=True): + config_filename = "mock_certificate_config.json" + config_file_content = json.dumps(config_data) + m = mock.mock_open(read_data=config_file_content) + with mock.patch("builtins.open", m): + with mock.patch.dict( + os.environ, {"GOOGLE_API_CERTIFICATE_CONFIG": config_filename} + ): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, + api_endpoint=mock_api_endpoint, + ) + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source is expected_cert_source + + # Test cases for mTLS enablement when GOOGLE_API_USE_CLIENT_CERTIFICATE is unset(empty). + test_cases = [ + ( + # With workloads present in config, mTLS is enabled. + { + "version": 1, + "cert_configs": { + "workload": { + "cert_path": "path/to/cert/file", + "key_path": "path/to/key/file", + } + }, + }, + mock_client_cert_source, + ), + ( + # With workloads not present in config, mTLS is disabled. + { + "version": 1, + "cert_configs": {}, + }, + None, + ), + ] + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + for config_data, expected_cert_source in test_cases: + env = os.environ.copy() + env.pop("GOOGLE_API_USE_CLIENT_CERTIFICATE", "") + with mock.patch.dict(os.environ, env, clear=True): + config_filename = "mock_certificate_config.json" + config_file_content = json.dumps(config_data) + m = mock.mock_open(read_data=config_file_content) + with mock.patch("builtins.open", m): + with mock.patch.dict( + os.environ, {"GOOGLE_API_CERTIFICATE_CONFIG": config_filename} + ): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, + api_endpoint=mock_api_endpoint, + ) + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source is expected_cert_source + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() @@ -870,18 +1079,6 @@ def test_firestore_admin_client_get_mtls_endpoint_and_cert_source(client_class): == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" ) - # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): - with pytest.raises(ValueError) as excinfo: - client_class.get_mtls_endpoint_and_cert_source() - - assert ( - str(excinfo.value) - == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - @pytest.mark.parametrize( "client_class", [FirestoreAdminClient, FirestoreAdminAsyncClient] @@ -11859,6 +12056,192 @@ async def test_delete_backup_schedule_flattened_error_async(): ) +@pytest.mark.parametrize( + "request_type", + [ + firestore_admin.CloneDatabaseRequest, + dict, + ], +) +def test_clone_database(request_type, transport: str = "grpc"): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.clone_database(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = firestore_admin.CloneDatabaseRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_clone_database_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = firestore_admin.CloneDatabaseRequest( + parent="parent_value", + database_id="database_id_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.clone_database(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == firestore_admin.CloneDatabaseRequest( + parent="parent_value", + database_id="database_id_value", + ) + + +def test_clone_database_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.clone_database in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[client._transport.clone_database] = mock_rpc + request = {} + client.clone_database(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods call wrapper_fn to build a cached + # client._transport.operations_client instance on first rpc call. + # Subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.clone_database(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_clone_database_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FirestoreAdminAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.clone_database + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.clone_database + ] = mock_rpc + + request = {} + await client.clone_database(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods call wrapper_fn to build a cached + # client._transport.operations_client instance on first rpc call. + # Subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.clone_database(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_clone_database_async( + transport: str = "grpc_asyncio", request_type=firestore_admin.CloneDatabaseRequest +): + client = FirestoreAdminAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.clone_database(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = firestore_admin.CloneDatabaseRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. 
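
The assertion just below pins the return type: clone_database is a long-running operation, so callers get back an operation future rather than a finished database. Typical consumption, sketched with an illustrative dict request (the tests above list parent, database_id and pitr_snapshot as the required fields; the values here are made up):

    from google.cloud import firestore_admin_v1

    client = firestore_admin_v1.FirestoreAdminClient()
    operation = client.clone_database(
        request={
            "parent": "projects/my-project",
            "database_id": "my-clone",
            # pitr_snapshot is also required per the tests above; shape elided.
        }
    )
    # Block until the server finishes the clone (or raise on failure).
    database = operation.result()
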
+ assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_clone_database_async_from_dict(): + await test_clone_database_async(request_type=dict) + + def test_create_index_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call @@ -17625,31 +18008,166 @@ def test_delete_backup_schedule_rest_flattened_error(transport: str = "rest"): ) -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.FirestoreAdminGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): +def test_clone_database_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: client = FirestoreAdminClient( credentials=ga_credentials.AnonymousCredentials(), - transport=transport, + transport="rest", ) - # It is an error to provide a credentials file and a transport instance. - transport = transports.FirestoreAdminGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = FirestoreAdminClient( - client_options={"credentials_file": "credentials.json"}, - transport=transport, + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.clone_database in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. ) + client._transport._wrapped_methods[client._transport.clone_database] = mock_rpc - # It is an error to provide an api_key and a transport instance. - transport = transports.FirestoreAdminGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) + request = {} + client.clone_database(request) + + # Establish that the underlying gRPC stub method was called. 
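
For context on what these use_cached_wrapped_rpc tests guard: GAPIC transports wrap every stub method once, at construction, with its retry/timeout defaults, and cache the result in _wrapped_methods; call sites then look the wrapper up instead of re-wrapping per request. The shape of that cache, reduced to a sketch with hypothetical names (wrap stands in for gapic_v1.method.wrap_method):

    from typing import Any, Callable, Dict

    def wrap(rpc: Callable) -> Callable:
        # Stand-in for gapic_v1.method.wrap_method (retries, timeouts, metadata).
        def wrapped(request: Any, **kwargs: Any) -> Any:
            return rpc(request, **kwargs)
        return wrapped

    class Transport:
        def __init__(self) -> None:
            # Built once; the tests assert no further wrapping happens per call.
            self._wrapped_methods: Dict[Callable, Callable] = {
                self.clone_database: wrap(self.clone_database),
            }

        def clone_database(self, request: Any, **kwargs: Any) -> Any:
            ...  # raw stub invocation lives here

    transport = Transport()
    rpc = transport._wrapped_methods[transport.clone_database]  # cached lookup
    rpc(request={})
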
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.clone_database(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_clone_database_rest_required_fields( + request_type=firestore_admin.CloneDatabaseRequest, +): + transport_class = transports.FirestoreAdminRestTransport + + request_init = {} + request_init["parent"] = "" + request_init["database_id"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).clone_database._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + jsonified_request["databaseId"] = "database_id_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).clone_database._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + assert "databaseId" in jsonified_request + assert jsonified_request["databaseId"] == "database_id_value" + + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.clone_database(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_clone_database_rest_unset_required_fields(): + transport = transports.FirestoreAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.clone_database._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "parent", + "databaseId", + "pitrSnapshot", + ) + ) + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.FirestoreAdminGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.FirestoreAdminGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = FirestoreAdminClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.FirestoreAdminGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) options = client_options.ClientOptions() options.api_key = "api_key" with pytest.raises(ValueError): @@ -18404,6 +18922,91 @@ def test_delete_backup_schedule_empty_call_grpc(): assert args[0] == request_msg +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_clone_database_empty_call_grpc(): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.clone_database(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = firestore_admin.CloneDatabaseRequest() + + assert args[0] == request_msg + + +def test_clone_database_routing_parameters_request_1_grpc(): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.clone_database( + request={"pitr_snapshot": {"database": "projects/sample1/sample2"}} + ) + + # Establish that the underlying stub method was called. 
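
The routing-parameter tests around this point verify implicit dynamic routing: the client mines pitr_snapshot.database for project_id (and database_id when the path contains a databases/ segment) and ships them in the x-goog-request-params metadata entry. Roughly the transformation being asserted, as a sketch:

    import re
    from urllib.parse import urlencode

    def routing_metadata(database_path: str) -> tuple:
        params = {}
        m = re.match(r"^projects/(?P<project_id>[^/]+)", database_path)
        if m:
            params["project_id"] = m.group("project_id")
        m = re.match(r"^projects/[^/]+/databases/(?P<database_id>[^/]+)", database_path)
        if m:
            params["database_id"] = m.group("database_id")
        return ("x-goog-request-params", urlencode(params))

    assert routing_metadata("projects/sample1/sample2") == (
        "x-goog-request-params", "project_id=sample1"
    )
    assert routing_metadata("projects/sample1/databases/sample2/sample3") == (
        "x-goog-request-params", "project_id=sample1&database_id=sample2"
    )
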
+ call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore_admin.CloneDatabaseRequest( + **{"pitr_snapshot": {"database": "projects/sample1/sample2"}} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + +def test_clone_database_routing_parameters_request_2_grpc(): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.clone_database( + request={ + "pitr_snapshot": { + "database": "projects/sample1/databases/sample2/sample3" + } + } + ) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore_admin.CloneDatabaseRequest( + **{ + "pitr_snapshot": { + "database": "projects/sample1/databases/sample2/sample3" + } + } + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1", "database_id": "sample2"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + def test_transport_kind_grpc_asyncio(): transport = FirestoreAdminAsyncClient.get_transport_class("grpc_asyncio")( credentials=async_anonymous_credentials() @@ -19270,6 +19873,103 @@ async def test_delete_backup_schedule_empty_call_grpc_asyncio(): assert args[0] == request_msg +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_clone_database_empty_call_grpc_asyncio(): + client = FirestoreAdminAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + await client.clone_database(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = firestore_admin.CloneDatabaseRequest() + + assert args[0] == request_msg + + +@pytest.mark.asyncio +async def test_clone_database_routing_parameters_request_1_grpc_asyncio(): + client = FirestoreAdminAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + await client.clone_database( + request={"pitr_snapshot": {"database": "projects/sample1/sample2"}} + ) + + # Establish that the underlying stub method was called. 
+ call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore_admin.CloneDatabaseRequest( + **{"pitr_snapshot": {"database": "projects/sample1/sample2"}} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + +@pytest.mark.asyncio +async def test_clone_database_routing_parameters_request_2_grpc_asyncio(): + client = FirestoreAdminAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + await client.clone_database( + request={ + "pitr_snapshot": { + "database": "projects/sample1/databases/sample2/sample3" + } + } + ) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore_admin.CloneDatabaseRequest( + **{ + "pitr_snapshot": { + "database": "projects/sample1/databases/sample2/sample3" + } + } + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1", "database_id": "sample2"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + def test_transport_kind_rest(): transport = FirestoreAdminClient.get_transport_class("rest")( credentials=ga_credentials.AnonymousCredentials() @@ -20792,6 +21492,7 @@ def test_create_database_rest_call_success(request_type): "backup": {"backup": "backup_value"}, "operation": "operation_value", }, + "tags": {}, "free_tier": True, "etag": "etag_value", "database_edition": 1, @@ -21303,6 +22004,7 @@ def test_update_database_rest_call_success(request_type): "backup": {"backup": "backup_value"}, "operation": "operation_value", }, + "tags": {}, "free_tier": True, "etag": "etag_value", "database_edition": 1, @@ -23837,6 +24539,129 @@ def test_delete_backup_schedule_rest_interceptors(null_interceptor): pre.assert_called_once() +def test_clone_database_rest_bad_request( + request_type=firestore_admin.CloneDatabaseRequest, +): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.clone_database(request) + + +@pytest.mark.parametrize( + "request_type", + [ + firestore_admin.CloneDatabaseRequest, + dict, + ], +) +def test_clone_database_rest_call_success(request_type): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.clone_database(request) + + # Establish that the response is the type that we expect. + json_return_value = json_format.MessageToJson(return_value) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_clone_database_rest_interceptors(null_interceptor): + transport = transports.FirestoreAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.FirestoreAdminRestInterceptor(), + ) + client = FirestoreAdminClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.FirestoreAdminRestInterceptor, "post_clone_database" + ) as post, mock.patch.object( + transports.FirestoreAdminRestInterceptor, "post_clone_database_with_metadata" + ) as post_with_metadata, mock.patch.object( + transports.FirestoreAdminRestInterceptor, "pre_clone_database" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + post_with_metadata.assert_not_called() + pb_message = firestore_admin.CloneDatabaseRequest.pb( + firestore_admin.CloneDatabaseRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = json_format.MessageToJson(operations_pb2.Operation()) + req.return_value.content = return_value + + request = firestore_admin.CloneDatabaseRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + post_with_metadata.return_value = operations_pb2.Operation(), metadata + + client.clone_database( + request, + metadata=[ + ("key", "val"), + 
("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + post_with_metadata.assert_called_once() + + def test_cancel_operation_rest_bad_request( request_type=operations_pb2.CancelOperationRequest, ): @@ -24734,6 +25559,88 @@ def test_delete_backup_schedule_empty_call_rest(): assert args[0] == request_msg +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_clone_database_empty_call_rest(): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + client.clone_database(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = firestore_admin.CloneDatabaseRequest() + + assert args[0] == request_msg + + +def test_clone_database_routing_parameters_request_1_rest(): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + client.clone_database( + request={"pitr_snapshot": {"database": "projects/sample1/sample2"}} + ) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore_admin.CloneDatabaseRequest( + **{"pitr_snapshot": {"database": "projects/sample1/sample2"}} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + +def test_clone_database_routing_parameters_request_2_rest(): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + client.clone_database( + request={ + "pitr_snapshot": { + "database": "projects/sample1/databases/sample2/sample3" + } + } + ) + + # Establish that the underlying stub method was called. 
+ call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore_admin.CloneDatabaseRequest( + **{ + "pitr_snapshot": { + "database": "projects/sample1/databases/sample2/sample3" + } + } + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1", "database_id": "sample2"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + def test_firestore_admin_rest_lro_client(): client = FirestoreAdminClient( credentials=ga_credentials.AnonymousCredentials(), @@ -24815,6 +25722,7 @@ def test_firestore_admin_base_transport(): "list_backup_schedules", "update_backup_schedule", "delete_backup_schedule", + "clone_database", "get_operation", "cancel_operation", "delete_operation", @@ -25187,6 +26095,9 @@ def test_firestore_admin_client_transport_session_collision(transport_name): session1 = client1.transport.delete_backup_schedule._session session2 = client2.transport.delete_backup_schedule._session assert session1 != session2 + session1 = client1.transport.clone_database._session + session2 = client2.transport.clone_database._session + assert session1 != session2 def test_firestore_admin_grpc_transport_channel(): @@ -25217,6 +26128,7 @@ def test_firestore_admin_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.filterwarnings("ignore::FutureWarning") @pytest.mark.parametrize( "transport_class", [ diff --git a/tests/unit/gapic/firestore_v1/test_firestore.py b/tests/unit/gapic/firestore_v1/test_firestore.py index eac609cab4..e3821e772d 100644 --- a/tests/unit/gapic/firestore_v1/test_firestore.py +++ b/tests/unit/gapic/firestore_v1/test_firestore.py @@ -61,7 +61,9 @@ from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats from google.cloud.firestore_v1.types import firestore +from google.cloud.firestore_v1.types import pipeline from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.types import query_profile from google.cloud.firestore_v1.types import write as gf_write @@ -160,12 +162,19 @@ def test__read_environment_variables(): with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): - with pytest.raises(ValueError) as excinfo: - FirestoreClient._read_environment_variables() - assert ( - str(excinfo.value) - == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with pytest.raises(ValueError) as excinfo: + FirestoreClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + else: + assert FirestoreClient._read_environment_variables() == ( + False, + "auto", + None, + ) with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): assert FirestoreClient._read_environment_variables() == (False, "never", None) @@ -192,6 +201,105 @@ def test__read_environment_variables(): ) +def test_use_client_cert_effective(): + # Test case 1: Test when `should_use_client_cert` returns True. 
+ # We mock the `should_use_client_cert` function to simulate a scenario where + # the google-auth library supports automatic mTLS and determines that a + # client certificate should be used. + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch( + "google.auth.transport.mtls.should_use_client_cert", return_value=True + ): + assert FirestoreClient._use_client_cert_effective() is True + + # Test case 2: Test when `should_use_client_cert` returns False. + # We mock the `should_use_client_cert` function to simulate a scenario where + # the google-auth library supports automatic mTLS and determines that a + # client certificate should NOT be used. + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch( + "google.auth.transport.mtls.should_use_client_cert", return_value=False + ): + assert FirestoreClient._use_client_cert_effective() is False + + # Test case 3: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "true". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + assert FirestoreClient._use_client_cert_effective() is True + + # Test case 4: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "false". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"} + ): + assert FirestoreClient._use_client_cert_effective() is False + + # Test case 5: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "True". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "True"}): + assert FirestoreClient._use_client_cert_effective() is True + + # Test case 6: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "False". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "False"} + ): + assert FirestoreClient._use_client_cert_effective() is False + + # Test case 7: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "TRUE". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "TRUE"}): + assert FirestoreClient._use_client_cert_effective() is True + + # Test case 8: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "FALSE". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "FALSE"} + ): + assert FirestoreClient._use_client_cert_effective() is False + + # Test case 9: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not set. + # In this case, the method should return False, which is the default value. 
+ if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, clear=True): + assert FirestoreClient._use_client_cert_effective() is False + + # Test case 10: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to an invalid value. + # The method should raise a ValueError as the environment variable must be either + # "true" or "false". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "unsupported"} + ): + with pytest.raises(ValueError): + FirestoreClient._use_client_cert_effective() + + # Test case 11: Test when `should_use_client_cert` is available and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to an invalid value. + # The method should return False as the environment variable is set to an invalid value. + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "unsupported"} + ): + assert FirestoreClient._use_client_cert_effective() is False + + # Test case 12: Test when `should_use_client_cert` is available and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is unset. Also, + # the GOOGLE_API_CONFIG environment variable is unset. + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": ""}): + with mock.patch.dict(os.environ, {"GOOGLE_API_CERTIFICATE_CONFIG": ""}): + assert FirestoreClient._use_client_cert_effective() is False + + def test__get_client_cert_source(): mock_provided_cert_source = mock.Mock() mock_default_cert_source = mock.Mock() @@ -555,17 +663,6 @@ def test_firestore_client_client_options(client_class, transport_class, transpor == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" ) - # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): - with pytest.raises(ValueError) as excinfo: - client = client_class(transport=transport_name) - assert ( - str(excinfo.value) - == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: @@ -777,6 +874,119 @@ def test_firestore_client_get_mtls_endpoint_and_cert_source(client_class): assert api_endpoint == mock_api_endpoint assert cert_source is None + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "Unsupported". + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, + api_endpoint=mock_api_endpoint, + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test cases for mTLS enablement when GOOGLE_API_USE_CLIENT_CERTIFICATE is unset. + test_cases = [ + ( + # With workloads present in config, mTLS is enabled. 
+ { + "version": 1, + "cert_configs": { + "workload": { + "cert_path": "path/to/cert/file", + "key_path": "path/to/key/file", + } + }, + }, + mock_client_cert_source, + ), + ( + # With workloads not present in config, mTLS is disabled. + { + "version": 1, + "cert_configs": {}, + }, + None, + ), + ] + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + for config_data, expected_cert_source in test_cases: + env = os.environ.copy() + env.pop("GOOGLE_API_USE_CLIENT_CERTIFICATE", None) + with mock.patch.dict(os.environ, env, clear=True): + config_filename = "mock_certificate_config.json" + config_file_content = json.dumps(config_data) + m = mock.mock_open(read_data=config_file_content) + with mock.patch("builtins.open", m): + with mock.patch.dict( + os.environ, {"GOOGLE_API_CERTIFICATE_CONFIG": config_filename} + ): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, + api_endpoint=mock_api_endpoint, + ) + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source is expected_cert_source + + # Test cases for mTLS enablement when GOOGLE_API_USE_CLIENT_CERTIFICATE is unset(empty). + test_cases = [ + ( + # With workloads present in config, mTLS is enabled. + { + "version": 1, + "cert_configs": { + "workload": { + "cert_path": "path/to/cert/file", + "key_path": "path/to/key/file", + } + }, + }, + mock_client_cert_source, + ), + ( + # With workloads not present in config, mTLS is disabled. + { + "version": 1, + "cert_configs": {}, + }, + None, + ), + ] + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + for config_data, expected_cert_source in test_cases: + env = os.environ.copy() + env.pop("GOOGLE_API_USE_CLIENT_CERTIFICATE", "") + with mock.patch.dict(os.environ, env, clear=True): + config_filename = "mock_certificate_config.json" + config_file_content = json.dumps(config_data) + m = mock.mock_open(read_data=config_file_content) + with mock.patch("builtins.open", m): + with mock.patch.dict( + os.environ, {"GOOGLE_API_CERTIFICATE_CONFIG": config_filename} + ): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, + api_endpoint=mock_api_endpoint, + ) + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source is expected_cert_source + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() @@ -827,18 +1037,6 @@ def test_firestore_client_get_mtls_endpoint_and_cert_source(client_class): == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" ) - # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): - with pytest.raises(ValueError) as excinfo: - client_class.get_mtls_endpoint_and_cert_source() - - assert ( - str(excinfo.value) - == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - @pytest.mark.parametrize("client_class", [FirestoreClient, FirestoreAsyncClient]) @mock.patch.object( @@ -3884,6 +4082,185 @@ async def test_run_query_field_headers_async(): ) in kw["metadata"] +@pytest.mark.parametrize( + "request_type", + [ + firestore.ExecutePipelineRequest, + dict, + ], +) +def test_execute_pipeline(request_type, transport: str = "grpc"): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = iter([firestore.ExecutePipelineResponse()]) + response = client.execute_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = firestore.ExecutePipelineRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + for message in response: + assert isinstance(message, firestore.ExecutePipelineResponse) + + +def test_execute_pipeline_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = firestore.ExecutePipelineRequest( + database="database_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.execute_pipeline(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == firestore.ExecutePipelineRequest( + database="database_value", + ) + + +def test_execute_pipeline_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.execute_pipeline in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.execute_pipeline + ] = mock_rpc + request = {} + client.execute_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.execute_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_execute_pipeline_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FirestoreAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.execute_pipeline + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.execute_pipeline + ] = mock_rpc + + request = {} + await client.execute_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.execute_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_execute_pipeline_async( + transport: str = "grpc_asyncio", request_type=firestore.ExecutePipelineRequest +): + client = FirestoreAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[firestore.ExecutePipelineResponse()] + ) + response = await client.execute_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = firestore.ExecutePipelineRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. 
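+    # The mocked call mimics a UnaryStreamCall, so messages are consumed with read().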
+ message = await response.read() + assert isinstance(message, firestore.ExecutePipelineResponse) + + +@pytest.mark.asyncio +async def test_execute_pipeline_async_from_dict(): + await test_execute_pipeline_async(request_type=dict) + + @pytest.mark.parametrize( "request_type", [ @@ -7410,7 +7787,7 @@ def test_run_query_rest_unset_required_fields(): assert set(unset_fields) == (set(()) & set(("parent",))) -def test_run_aggregation_query_rest_use_cached_wrapped_rpc(): +def test_execute_pipeline_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -7424,10 +7801,7 @@ def test_run_aggregation_query_rest_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert ( - client._transport.run_aggregation_query - in client._transport._wrapped_methods - ) + assert client._transport.execute_pipeline in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() @@ -7435,29 +7809,29 @@ def test_run_aggregation_query_rest_use_cached_wrapped_rpc(): "foo" # operation_request.operation in compute client(s) expect a string. ) client._transport._wrapped_methods[ - client._transport.run_aggregation_query + client._transport.execute_pipeline ] = mock_rpc request = {} - client.run_aggregation_query(request) + client.execute_pipeline(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - client.run_aggregation_query(request) + client.execute_pipeline(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 assert mock_rpc.call_count == 2 -def test_run_aggregation_query_rest_required_fields( - request_type=firestore.RunAggregationQueryRequest, +def test_execute_pipeline_rest_required_fields( + request_type=firestore.ExecutePipelineRequest, ): transport_class = transports.FirestoreRestTransport request_init = {} - request_init["parent"] = "" + request_init["database"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -7468,21 +7842,21 @@ def test_run_aggregation_query_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).run_aggregation_query._get_unset_required_fields(jsonified_request) + ).execute_pipeline._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["parent"] = "parent_value" + jsonified_request["database"] = "database_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).run_aggregation_query._get_unset_required_fields(jsonified_request) + ).execute_pipeline._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "parent" in jsonified_request - assert jsonified_request["parent"] == "parent_value" + assert "database" in jsonified_request + assert jsonified_request["database"] == "database_value" client = FirestoreClient( credentials=ga_credentials.AnonymousCredentials(), @@ -7491,7 +7865,7 @@ def test_run_aggregation_query_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. 
- return_value = firestore.RunAggregationQueryResponse() + return_value = firestore.ExecutePipelineResponse() # Mock the http request call within the method and fake a response. with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -7513,7 +7887,7 @@ def test_run_aggregation_query_rest_required_fields( response_value.status_code = 200 # Convert return value to protobuf type - return_value = firestore.RunAggregationQueryResponse.pb(return_value) + return_value = firestore.ExecutePipelineResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) json_return_value = "[{}]".format(json_return_value) @@ -7523,23 +7897,23 @@ def test_run_aggregation_query_rest_required_fields( with mock.patch.object(response_value, "iter_content") as iter_content: iter_content.return_value = iter(json_return_value) - response = client.run_aggregation_query(request) + response = client.execute_pipeline(request) expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_run_aggregation_query_rest_unset_required_fields(): +def test_execute_pipeline_rest_unset_required_fields(): transport = transports.FirestoreRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.run_aggregation_query._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("parent",))) + unset_fields = transport.execute_pipeline._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("database",))) -def test_partition_query_rest_use_cached_wrapped_rpc(): +def test_run_aggregation_query_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -7553,30 +7927,35 @@ def test_partition_query_rest_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.partition_query in client._transport._wrapped_methods + assert ( + client._transport.run_aggregation_query + in client._transport._wrapped_methods + ) # Replace cached wrapped function with mock mock_rpc = mock.Mock() mock_rpc.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client._transport._wrapped_methods[client._transport.partition_query] = mock_rpc + client._transport._wrapped_methods[ + client._transport.run_aggregation_query + ] = mock_rpc request = {} - client.partition_query(request) + client.run_aggregation_query(request) # Establish that the underlying gRPC stub method was called. 
assert mock_rpc.call_count == 1 - client.partition_query(request) + client.run_aggregation_query(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 assert mock_rpc.call_count == 2 -def test_partition_query_rest_required_fields( - request_type=firestore.PartitionQueryRequest, +def test_run_aggregation_query_rest_required_fields( + request_type=firestore.RunAggregationQueryRequest, ): transport_class = transports.FirestoreRestTransport @@ -7592,7 +7971,7 @@ def test_partition_query_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).partition_query._get_unset_required_fields(jsonified_request) + ).run_aggregation_query._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present @@ -7601,7 +7980,131 @@ def test_partition_query_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).partition_query._get_unset_required_fields(jsonified_request) + ).run_aggregation_query._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = firestore.RunAggregationQueryResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = firestore.RunAggregationQueryResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + json_return_value = "[{}]".format(json_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + with mock.patch.object(response_value, "iter_content") as iter_content: + iter_content.return_value = iter(json_return_value) + response = client.run_aggregation_query(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_run_aggregation_query_rest_unset_required_fields(): + transport = transports.FirestoreRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.run_aggregation_query._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("parent",))) + + +def test_partition_query_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.partition_query in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.partition_query] = mock_rpc + + request = {} + client.partition_query(request) + + # Establish that the underlying gRPC stub method was called. 
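+        # The mock replaced the cached wrapper, so the call count reflects direct invocations.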
+ assert mock_rpc.call_count == 1 + + client.partition_query(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_partition_query_rest_required_fields( + request_type=firestore.PartitionQueryRequest, +): + transport_class = transports.FirestoreRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).partition_query._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).partition_query._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone @@ -8553,6 +9056,27 @@ def test_run_query_empty_call_grpc(): assert args[0] == request_msg +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_execute_pipeline_empty_call_grpc(): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + call.return_value = iter([firestore.ExecutePipelineResponse()]) + client.execute_pipeline(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest() + + assert args[0] == request_msg + + # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. def test_run_aggregation_query_empty_call_grpc(): @@ -8662,6 +9186,60 @@ def test_create_document_empty_call_grpc(): assert args[0] == request_msg +def test_execute_pipeline_routing_parameters_request_1_grpc(): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + call.return_value = iter([firestore.ExecutePipelineResponse()]) + client.execute_pipeline(request={"database": "projects/sample1/sample2"}) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest( + **{"database": "projects/sample1/sample2"} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + +def test_execute_pipeline_routing_parameters_request_2_grpc(): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. 
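+    # A database path with an explicit databases/ segment should yield both
+    # project_id and database_id routing headers.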
+ with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + call.return_value = iter([firestore.ExecutePipelineResponse()]) + client.execute_pipeline( + request={"database": "projects/sample1/databases/sample2/sample3"} + ) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest( + **{"database": "projects/sample1/databases/sample2/sample3"} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1", "database_id": "sample2"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + def test_transport_kind_grpc_asyncio(): transport = FirestoreAsyncClient.get_transport_class("grpc_asyncio")( credentials=async_anonymous_credentials() @@ -8911,6 +9489,32 @@ async def test_run_query_empty_call_grpc_asyncio(): assert args[0] == request_msg +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_execute_pipeline_empty_call_grpc_asyncio(): + client = FirestoreAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[firestore.ExecutePipelineResponse()] + ) + await client.execute_pipeline(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest() + + assert args[0] == request_msg + + # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @pytest.mark.asyncio @@ -9048,6 +9652,70 @@ async def test_create_document_empty_call_grpc_asyncio(): assert args[0] == request_msg +@pytest.mark.asyncio +async def test_execute_pipeline_routing_parameters_request_1_grpc_asyncio(): + client = FirestoreAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[firestore.ExecutePipelineResponse()] + ) + await client.execute_pipeline(request={"database": "projects/sample1/sample2"}) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest( + **{"database": "projects/sample1/sample2"} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + +@pytest.mark.asyncio +async def test_execute_pipeline_routing_parameters_request_2_grpc_asyncio(): + client = FirestoreAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. 
+ with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[firestore.ExecutePipelineResponse()] + ) + await client.execute_pipeline( + request={"database": "projects/sample1/databases/sample2/sample3"} + ) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest( + **{"database": "projects/sample1/databases/sample2/sample3"} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1", "database_id": "sample2"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + def test_transport_kind_rest(): transport = FirestoreClient.get_transport_class("rest")( credentials=ga_credentials.AnonymousCredentials() @@ -10233,6 +10901,137 @@ def test_run_query_rest_interceptors(null_interceptor): post_with_metadata.assert_called_once() +def test_execute_pipeline_rest_bad_request( + request_type=firestore.ExecutePipelineRequest, +): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"database": "projects/sample1/databases/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.execute_pipeline(request) + + +@pytest.mark.parametrize( + "request_type", + [ + firestore.ExecutePipelineRequest, + dict, + ], +) +def test_execute_pipeline_rest_call_success(request_type): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"database": "projects/sample1/databases/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = firestore.ExecutePipelineResponse( + transaction=b"transaction_blob", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = firestore.ExecutePipelineResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + json_return_value = "[{}]".format(json_return_value) + response_value.iter_content = mock.Mock(return_value=iter(json_return_value)) + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.execute_pipeline(request) + + assert isinstance(response, Iterable) + response = next(response) + + # Establish that the response is the type that we expect. 
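+    # Each element of the streamed JSON array decodes back into an ExecutePipelineResponse.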
+ assert isinstance(response, firestore.ExecutePipelineResponse) + assert response.transaction == b"transaction_blob" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_execute_pipeline_rest_interceptors(null_interceptor): + transport = transports.FirestoreRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None if null_interceptor else transports.FirestoreRestInterceptor(), + ) + client = FirestoreClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.FirestoreRestInterceptor, "post_execute_pipeline" + ) as post, mock.patch.object( + transports.FirestoreRestInterceptor, "post_execute_pipeline_with_metadata" + ) as post_with_metadata, mock.patch.object( + transports.FirestoreRestInterceptor, "pre_execute_pipeline" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + post_with_metadata.assert_not_called() + pb_message = firestore.ExecutePipelineRequest.pb( + firestore.ExecutePipelineRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = firestore.ExecutePipelineResponse.to_json( + firestore.ExecutePipelineResponse() + ) + req.return_value.iter_content = mock.Mock(return_value=iter(return_value)) + + request = firestore.ExecutePipelineRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = firestore.ExecutePipelineResponse() + post_with_metadata.return_value = firestore.ExecutePipelineResponse(), metadata + + client.execute_pipeline( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + post_with_metadata.assert_called_once() + + def test_run_aggregation_query_rest_bad_request( request_type=firestore.RunAggregationQueryRequest, ): @@ -11409,6 +12208,26 @@ def test_run_query_empty_call_rest(): assert args[0] == request_msg +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_execute_pipeline_empty_call_rest(): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + client.execute_pipeline(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest() + + assert args[0] == request_msg + + # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. def test_run_aggregation_query_empty_call_rest(): @@ -11513,6 +12332,58 @@ def test_create_document_empty_call_rest(): assert args[0] == request_msg +def test_execute_pipeline_routing_parameters_request_1_rest(): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. 
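+    # A project-level database path without a databases/ segment yields only the
+    # project_id routing header.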
+ with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + client.execute_pipeline(request={"database": "projects/sample1/sample2"}) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest( + **{"database": "projects/sample1/sample2"} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + +def test_execute_pipeline_routing_parameters_request_2_rest(): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + client.execute_pipeline( + request={"database": "projects/sample1/databases/sample2/sample3"} + ) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest( + **{"database": "projects/sample1/databases/sample2/sample3"} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1", "database_id": "sample2"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + def test_transport_grpc_default(): # A client should use the gRPC transport by default. client = FirestoreClient( @@ -11555,6 +12426,7 @@ def test_firestore_base_transport(): "commit", "rollback", "run_query", + "execute_pipeline", "run_aggregation_query", "partition_query", "write", @@ -11860,6 +12732,9 @@ def test_firestore_client_transport_session_collision(transport_name): session1 = client1.transport.run_query._session session2 = client2.transport.run_query._session assert session1 != session2 + session1 = client1.transport.execute_pipeline._session + session2 = client2.transport.execute_pipeline._session + assert session1 != session2 session1 = client1.transport.run_aggregation_query._session session2 = client2.transport.run_aggregation_query._session assert session1 != session2 @@ -11911,6 +12786,7 @@ def test_firestore_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. 
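+# The deprecated arguments exercised below may emit FutureWarning; ignore it
+# until they are removed.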
+@pytest.mark.filterwarnings("ignore::FutureWarning") @pytest.mark.parametrize( "transport_class", [transports.FirestoreGrpcTransport, transports.FirestoreGrpcAsyncIOTransport], diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py index 4d1eed1980..69ca69ec78 100644 --- a/tests/unit/v1/test_aggregation.py +++ b/tests/unit/v1/test_aggregation.py @@ -26,6 +26,8 @@ from google.cloud.firestore_v1.query_profile import ExplainMetrics, QueryExplainError from google.cloud.firestore_v1.query_results import QueryResultsList from google.cloud.firestore_v1.stream_generator import StreamGenerator +from google.cloud.firestore_v1.types import RunAggregationQueryResponse +from google.protobuf.timestamp_pb2 import Timestamp from tests.unit.v1._test_helpers import ( make_aggregation_query, make_aggregation_query_response, @@ -49,6 +51,12 @@ def test_count_aggregation_to_pb(): assert count_aggregation._to_protobuf() == expected_aggregation_query_pb +def test_count_aggregation_no_alias_to_pb(): + count_aggregation = CountAggregation(alias=None) + got_pb = count_aggregation._to_protobuf() + assert got_pb.alias == "" + + def test_sum_aggregation_w_field_path(): """ SumAggregation should convert FieldPath inputs into strings @@ -86,6 +94,12 @@ def test_sum_aggregation_to_pb(): assert sum_aggregation._to_protobuf() == expected_aggregation_query_pb +def test_sum_aggregation_no_alias_to_pb(): + sum_aggregation = SumAggregation("someref", alias=None) + got_pb = sum_aggregation._to_protobuf() + assert got_pb.alias == "" + + def test_avg_aggregation_to_pb(): from google.cloud.firestore_v1.types import query as query_pb2 @@ -101,6 +115,12 @@ def test_avg_aggregation_to_pb(): assert avg_aggregation._to_protobuf() == expected_aggregation_query_pb +def test_avg_aggregation_no_alias_to_pb(): + avg_aggregation = AvgAggregation("someref", alias=None) + got_pb = avg_aggregation._to_protobuf() + assert got_pb.alias == "" + + def test_aggregation_query_constructor(): client = make_client() parent = client.collection("dee") @@ -384,11 +404,76 @@ def test_aggregation_query_prep_stream_with_explain_options(): assert kwargs == {"retry": None} +def test_aggregation_query_prep_stream_with_read_time(): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.count(alias="all") + aggregation_query.sum("someref", alias="sumall") + aggregation_query.avg("anotherref", alias="avgall") + + # 1800 seconds after epoch + read_time = datetime.now() + + request, kwargs = aggregation_query._prep_stream(read_time=read_time) + + parent_path, _ = parent._parent_info() + expected_request = { + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": None, + "read_time": read_time, + } + assert request == expected_request + assert kwargs == {"retry": None} + + +@pytest.mark.parametrize( + "custom_timezone", [None, timezone.utc, timezone(timedelta(hours=5))] +) +def test_aggregation_query_get_stream_iterator_read_time_different_timezones( + custom_timezone, +): + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + + aggregation_query.count(alias="all") + aggregation_query.sum("someref", alias="sumall") + aggregation_query.avg("anotherref", alias="avgall") + + # 1800 seconds after epoch in user-specified timezone + read_time = datetime.fromtimestamp(1800, tz=custom_timezone) + + # 
The internal firestore API needs to be initialized before it gets mocked. + client._firestore_api + + # Validate that the same timestamp_pb object would be sent in the actual request. + with mock.patch.object( + type(client._firestore_api_internal.transport.run_aggregation_query), "__call__" + ) as call: + call.return_value = iter([RunAggregationQueryResponse()]) + aggregation_query._get_stream_iterator( + transaction=None, retry=None, timeout=None, read_time=read_time + ) + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request_read_time = args[0].read_time + + # Verify that the timestamp is correct. + expected_timestamp = Timestamp(seconds=1800) + assert request_read_time.timestamp_pb() == expected_timestamp + + def _aggregation_query_get_helper( retry=None, timeout=None, - read_time=None, explain_options=None, + response_read_time=None, + query_read_time=None, ): from google.cloud._helpers import _datetime_to_pb_timestamp @@ -411,7 +496,11 @@ def _aggregation_query_get_helper( aggregation_query = make_aggregation_query(query) aggregation_query.count(alias="all") - aggregation_result = AggregationResult(alias="total", value=5, read_time=read_time) + aggregation_result = AggregationResult( + alias="total", + value=5, + read_time=response_read_time, + ) if explain_options is not None: explain_metrics = {"execution_stats": {"results_returned": 1}} @@ -419,14 +508,18 @@ def _aggregation_query_get_helper( explain_metrics = None response_pb = make_aggregation_query_response( [aggregation_result], - read_time=read_time, + read_time=response_read_time, explain_metrics=explain_metrics, ) firestore_api.run_aggregation_query.return_value = iter([response_pb]) kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Execute the query and check the response. - returned = aggregation_query.get(**kwargs, explain_options=explain_options) + returned = aggregation_query.get( + **kwargs, + explain_options=explain_options, + read_time=query_read_time, + ) assert isinstance(returned, QueryResultsList) assert len(returned) == 1 @@ -434,9 +527,9 @@ def _aggregation_query_get_helper( for r in result: assert r.alias == aggregation_result.alias assert r.value == aggregation_result.value - if read_time is not None: + if response_read_time is not None: result_datetime = _datetime_to_pb_timestamp(r.read_time) - assert result_datetime == read_time + assert result_datetime == response_read_time assert returned._explain_options == explain_options assert returned.explain_options == explain_options @@ -457,6 +550,8 @@ def _aggregation_query_get_helper( } if explain_options is not None: expected_request["explain_options"] = explain_options._to_dict() + if query_read_time is not None: + expected_request["read_time"] = query_read_time # Verify the mock call. 
firestore_api.run_aggregation_query.assert_called_once_with( @@ -473,9 +568,11 @@ def test_aggregation_query_get(): def test_aggregation_query_get_with_readtime(): from google.cloud._helpers import _datetime_to_pb_timestamp - one_hour_ago = datetime.now(tz=timezone.utc) - timedelta(hours=1) - read_time = _datetime_to_pb_timestamp(one_hour_ago) - _aggregation_query_get_helper(read_time=read_time) + query_read_time = datetime.now(tz=timezone.utc) - timedelta(hours=1) + response_read_time = _datetime_to_pb_timestamp(query_read_time) + _aggregation_query_get_helper( + response_read_time=response_read_time, query_read_time=query_read_time + ) def test_aggregation_query_get_retry_timeout(): @@ -555,6 +652,7 @@ def _aggregation_query_stream_w_retriable_exc_helper( timeout=None, transaction=None, expect_retry=True, + read_time=None, ): from google.api_core import exceptions, gapic_v1 @@ -598,7 +696,9 @@ def _stream_w_exception(*_args, **_kw): query = make_query(parent) aggregation_query = make_aggregation_query(query) - get_response = aggregation_query.stream(transaction=transaction, **kwargs) + get_response = aggregation_query.stream( + transaction=transaction, **kwargs, read_time=read_time + ) assert isinstance(get_response, stream_generator.StreamGenerator) if expect_retry: @@ -629,23 +729,31 @@ def _stream_w_exception(*_args, **_kw): else: expected_transaction_id = None + expected_request = { + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": expected_transaction_id, + } + if read_time is not None: + expected_request["read_time"] = read_time + assert calls[0] == mock.call( - request={ - "parent": parent_path, - "structured_aggregation_query": aggregation_query._to_protobuf(), - "transaction": expected_transaction_id, - }, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) if expect_retry: + expected_request = { + "parent": parent_path, + "structured_aggregation_query": aggregation_query._to_protobuf(), + "transaction": None, + } + if read_time is not None: + expected_request["read_time"] = read_time + assert calls[1] == mock.call( - request={ - "parent": parent_path, - "structured_aggregation_query": aggregation_query._to_protobuf(), - "transaction": None, - }, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) @@ -661,6 +769,12 @@ def test_aggregation_query_stream_w_retriable_exc_w_retry(): _aggregation_query_stream_w_retriable_exc_helper(retry=retry, expect_retry=False) +def test_aggregation_query_stream_w_retriable_exc_w_read_time(): + _aggregation_query_stream_w_retriable_exc_helper( + read_time=datetime.now(tz=timezone.utc) + ) + + def test_aggregation_query_stream_w_retriable_exc_w_transaction(): from google.cloud.firestore_v1 import transaction @@ -713,7 +827,9 @@ def _aggregation_query_stream_helper( kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Execute the query and check the response. - returned = aggregation_query.stream(**kwargs, explain_options=explain_options) + returned = aggregation_query.stream( + **kwargs, explain_options=explain_options, read_time=read_time + ) assert isinstance(returned, StreamGenerator) results = [] @@ -743,6 +859,8 @@ def _aggregation_query_stream_helper( } if explain_options is not None: expected_request["explain_options"] = explain_options._to_dict() + if read_time is not None: + expected_request["read_time"] = read_time # Verify the mock call. 
     firestore_api.run_aggregation_query.assert_called_once_with(
@@ -756,7 +874,7 @@ def test_aggregation_query_stream():
     _aggregation_query_stream_helper()
 
 
-def test_aggregation_query_stream_with_readtime():
+def test_aggregation_query_stream_with_read_time():
     from google.cloud._helpers import _datetime_to_pb_timestamp
 
     one_hour_ago = datetime.now(tz=timezone.utc) - timedelta(hours=1)
diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py
index 6254c4c87f..9140f53e81 100644
--- a/tests/unit/v1/test_async_aggregation.py
+++ b/tests/unit/v1/test_async_aggregation.py
@@ -321,9 +321,39 @@ def test_async_aggregation_query_prep_stream_with_explain_options():
     assert kwargs == {"retry": None}
 
 
+def test_async_aggregation_query_prep_stream_with_read_time():
+    client = make_async_client()
+    parent = client.collection("dee")
+    query = make_async_query(parent)
+    aggregation_query = make_async_aggregation_query(query)
+
+    aggregation_query.count(alias="all")
+    aggregation_query.sum("someref", alias="sumall")
+    aggregation_query.avg("anotherref", alias="avgall")
+
+    # Use the current time as the read_time.
+    read_time = datetime.now()
+
+    request, kwargs = aggregation_query._prep_stream(read_time=read_time)
+
+    parent_path, _ = parent._parent_info()
+    expected_request = {
+        "parent": parent_path,
+        "structured_aggregation_query": aggregation_query._to_protobuf(),
+        "transaction": None,
+        "read_time": read_time,
+    }
+    assert request == expected_request
+    assert kwargs == {"retry": None}
+
+
 @pytest.mark.asyncio
 async def _async_aggregation_query_get_helper(
-    retry=None, timeout=None, read_time=None, explain_options=None
+    retry=None,
+    timeout=None,
+    explain_options=None,
+    response_read_time=None,
+    query_read_time=None,
 ):
     from google.cloud._helpers import _datetime_to_pb_timestamp
@@ -342,7 +372,11 @@ async def _async_aggregation_query_get_helper(
     aggregation_query = make_async_aggregation_query(query)
     aggregation_query.count(alias="all")
 
-    aggregation_result = AggregationResult(alias="total", value=5, read_time=read_time)
+    aggregation_result = AggregationResult(
+        alias="total",
+        value=5,
+        read_time=response_read_time,
+    )
 
     if explain_options is not None:
         explain_metrics = {"execution_stats": {"results_returned": 1}}
@@ -351,14 +385,18 @@
 
     response_pb = make_aggregation_query_response(
         [aggregation_result],
-        read_time=read_time,
+        read_time=response_read_time,
         explain_metrics=explain_metrics,
     )
     firestore_api.run_aggregation_query.return_value = AsyncIter([response_pb])
     kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)
 
     # Execute the query and check the response.
- returned = await aggregation_query.get(**kwargs, explain_options=explain_options) + returned = await aggregation_query.get( + **kwargs, + explain_options=explain_options, + read_time=query_read_time, + ) assert isinstance(returned, QueryResultsList) assert len(returned) == 1 @@ -366,9 +404,9 @@ async def _async_aggregation_query_get_helper( for r in result: assert r.alias == aggregation_result.alias assert r.value == aggregation_result.value - if read_time is not None: + if response_read_time is not None: result_datetime = _datetime_to_pb_timestamp(r.read_time) - assert result_datetime == read_time + assert result_datetime == response_read_time if explain_options is None: with pytest.raises(QueryExplainError, match="explain_options not set"): @@ -387,6 +425,8 @@ async def _async_aggregation_query_get_helper( } if explain_options is not None: expected_request["explain_options"] = explain_options._to_dict() + if query_read_time is not None: + expected_request["read_time"] = query_read_time firestore_api.run_aggregation_query.assert_called_once_with( request=expected_request, metadata=client._rpc_metadata, @@ -405,7 +445,9 @@ async def test_async_aggregation_query_get_with_readtime(): one_hour_ago = datetime.now(tz=timezone.utc) - timedelta(hours=1) read_time = _datetime_to_pb_timestamp(one_hour_ago) - await _async_aggregation_query_get_helper(read_time=read_time) + await _async_aggregation_query_get_helper( + query_read_time=one_hour_ago, response_read_time=read_time + ) @pytest.mark.asyncio @@ -583,7 +625,11 @@ async def _async_aggregation_query_stream_helper( kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) # Execute the query and check the response. - returned = aggregation_query.stream(**kwargs, explain_options=explain_options) + returned = aggregation_query.stream( + **kwargs, + explain_options=explain_options, + read_time=read_time, + ) assert isinstance(returned, AsyncStreamGenerator) results = [] @@ -611,6 +657,8 @@ async def _async_aggregation_query_stream_helper( } if explain_options is not None: expected_request["explain_options"] = explain_options._to_dict() + if read_time is not None: + expected_request["read_time"] = read_time # Verify the mock call. 
firestore_api.run_aggregation_query.assert_called_once_with( @@ -625,6 +673,15 @@ async def test_aggregation_query_stream(): await _async_aggregation_query_stream_helper() +@pytest.mark.asyncio +async def test_async_aggregation_query_stream_with_read_time(): + from google.cloud._helpers import _datetime_to_pb_timestamp + + one_hour_ago = datetime.now(tz=timezone.utc) - timedelta(hours=1) + read_time = _datetime_to_pb_timestamp(one_hour_ago) + await _async_aggregation_query_stream_helper(read_time=read_time) + + @pytest.mark.asyncio async def test_aggregation_query_stream_w_explain_options_analyze_true(): from google.cloud.firestore_v1.query_profile import ExplainOptions diff --git a/tests/unit/v1/test_async_client.py b/tests/unit/v1/test_async_client.py index ee624d382b..9b49e5bf04 100644 --- a/tests/unit/v1/test_async_client.py +++ b/tests/unit/v1/test_async_client.py @@ -187,7 +187,7 @@ def test_asyncclient_document_factory_w_nested_path(): assert isinstance(document2, AsyncDocumentReference) -async def _collections_helper(retry=None, timeout=None): +async def _collections_helper(retry=None, timeout=None, read_time=None): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.async_collection import AsyncCollectionReference @@ -206,7 +206,7 @@ async def __aiter__(self, **_): client._firestore_api_internal = firestore_api kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - collections = [c async for c in client.collections(**kwargs)] + collections = [c async for c in client.collections(read_time=read_time, **kwargs)] assert len(collections) == len(collection_ids) for collection, collection_id in zip(collections, collection_ids): @@ -215,8 +215,13 @@ async def __aiter__(self, **_): assert collection.id == collection_id base_path = client._database_string + "/documents" + expected_request = { + "parent": base_path, + } + if read_time is not None: + expected_request["read_time"] = read_time firestore_api.list_collection_ids.assert_called_once_with( - request={"parent": base_path}, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) @@ -236,6 +241,12 @@ async def test_asyncclient_collections_w_retry_timeout(): await _collections_helper(retry=retry, timeout=timeout) +@pytest.mark.asyncio +async def test_asyncclient_collections_read_time(): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + await _collections_helper(read_time=read_time) + + async def _invoke_get_all(client, references, document_pbs, **kwargs): # Create a minimal fake GAPIC with a dummy response. 
firestore_api = AsyncMock(spec=["batch_get_documents"]) @@ -252,7 +263,13 @@ async def _invoke_get_all(client, references, document_pbs, **kwargs): return [s async for s in snapshots] -async def _get_all_helper(num_snapshots=2, txn_id=None, retry=None, timeout=None): +async def _get_all_helper( + num_snapshots=2, + txn_id=None, + retry=None, + timeout=None, + read_time=None, +): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.async_document import DocumentSnapshot from google.cloud.firestore_v1.types import common @@ -261,13 +278,13 @@ async def _get_all_helper(num_snapshots=2, txn_id=None, retry=None, timeout=None data1 = {"a": "cheese"} document1 = client.document("pineapple", "lamp1") - document_pb1, read_time = _doc_get_info(document1._document_path, data1) - response1 = _make_batch_response(found=document_pb1, read_time=read_time) + document_pb1, doc_read_time = _doc_get_info(document1._document_path, data1) + response1 = _make_batch_response(found=document_pb1, read_time=doc_read_time) data2 = {"b": True, "c": 18} document2 = client.document("pineapple", "lamp2") - document, read_time = _doc_get_info(document2._document_path, data2) - response2 = _make_batch_response(found=document, read_time=read_time) + document, doc_read_time = _doc_get_info(document2._document_path, data2) + response2 = _make_batch_response(found=document, read_time=doc_read_time) document3 = client.document("pineapple", "lamp3") response3 = _make_batch_response(missing=document3._document_path) @@ -290,6 +307,7 @@ async def _get_all_helper(num_snapshots=2, txn_id=None, retry=None, timeout=None documents, responses, field_paths=field_paths, + read_time=read_time, **kwargs, ) @@ -308,14 +326,17 @@ async def _get_all_helper(num_snapshots=2, txn_id=None, retry=None, timeout=None mask = common.DocumentMask(field_paths=field_paths) kwargs.pop("transaction", None) + expected_request = { + "database": client._database_string, + "documents": doc_paths, + "mask": mask, + "transaction": txn_id, + } + if read_time is not None: + expected_request["read_time"] = read_time client._firestore_api.batch_get_documents.assert_called_once_with( - request={ - "database": client._database_string, - "documents": doc_paths, - "mask": mask, - "transaction": txn_id, - }, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) @@ -347,6 +368,15 @@ async def test_asyncclient_get_all_wrong_order(): @pytest.mark.asyncio +async def test_asyncclient_get_all_read_time(): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + await _get_all_helper(read_time=read_time) + + +@pytest.mark.asyncio +@pytest.mark.filterwarnings( + "ignore:coroutine method 'aclose' of 'AsyncIter' was never awaited:RuntimeWarning" +) async def test_asyncclient_get_all_unknown_result(): from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index 497fc455fa..a0194ace5b 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -17,6 +17,7 @@ import mock import pytest +from datetime import datetime, timezone from tests.unit.v1._test_helpers import DEFAULT_TEST_PROJECT, make_async_client from tests.unit.v1.test__helpers import AsyncIter, AsyncMock @@ -302,7 +303,9 @@ async def _get_chunk(*args, **kwargs): @pytest.mark.asyncio -async def _list_documents_helper(page_size=None, retry=None, timeout=None): +async def _list_documents_helper( + page_size=None, retry=None, 
timeout=None, read_time=None +): from google.api_core.page_iterator import Page from google.api_core.page_iterator_async import AsyncIterator @@ -338,12 +341,13 @@ async def _next_page(self): documents = [ i async for i in collection.list_documents( - page_size=page_size, - **kwargs, + page_size=page_size, **kwargs, read_time=read_time ) ] else: - documents = [i async for i in collection.list_documents(**kwargs)] + documents = [ + i async for i in collection.list_documents(**kwargs, read_time=read_time) + ] # Verify the response and the mocks. assert len(documents) == len(document_ids) @@ -353,14 +357,17 @@ async def _next_page(self): assert document.id == document_id parent, _ = collection._parent_info() + expected_request = { + "parent": parent, + "collection_id": collection.id, + "page_size": page_size, + "show_missing": True, + "mask": {"field_paths": None}, + } + if read_time is not None: + expected_request["read_time"] = read_time firestore_api.list_documents.assert_called_once_with( - request={ - "parent": parent, - "collection_id": collection.id, - "page_size": page_size, - "show_missing": True, - "mask": {"field_paths": None}, - }, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) @@ -385,6 +392,11 @@ async def test_asynccollectionreference_list_documents_w_page_size(): await _list_documents_helper(page_size=25) +@pytest.mark.asyncio +async def test_asynccollectionreference_list_documents_w_read_time(): + await _list_documents_helper(read_time=datetime.now(tz=timezone.utc)) + + @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) @pytest.mark.asyncio async def test_asynccollectionreference_get(query_class): @@ -450,6 +462,21 @@ async def test_asynccollectionreference_get_w_explain_options(query_class): ) +@mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) +@pytest.mark.asyncio +async def test_asynccollectionreference_get_w_read_time(query_class): + read_time = datetime.now(tz=timezone.utc) + collection = _make_async_collection_reference("collection") + await collection.get(read_time=read_time) + + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + query_instance.get.assert_called_once_with( + transaction=None, + read_time=read_time, + ) + + @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) @pytest.mark.asyncio async def test_asynccollectionreference_stream(query_class): @@ -552,6 +579,23 @@ async def response_generator(): assert explain_metrics.execution_stats.results_returned == 1 +@mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) +@pytest.mark.asyncio +async def test_asynccollectionreference_stream_w_read_time(query_class): + read_time = datetime.now(tz=timezone.utc) + collection = _make_async_collection_reference("collection") + get_response = collection.stream(read_time=read_time) + + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + + assert get_response is query_instance.stream.return_value + query_instance.stream.assert_called_once_with( + transaction=None, + read_time=read_time, + ) + + def test_asynccollectionreference_recursive(): from google.cloud.firestore_v1.async_query import AsyncQuery diff --git a/tests/unit/v1/test_async_document.py b/tests/unit/v1/test_async_document.py index 8d67e78f08..45472c6604 100644 --- a/tests/unit/v1/test_async_document.py +++ b/tests/unit/v1/test_async_document.py @@ -17,6 +17,9 @@ import mock import pytest +from datetime 
import datetime + +from google.protobuf import timestamp_pb2 from tests.unit.v1._test_helpers import make_async_client from tests.unit.v1.test__helpers import AsyncIter, AsyncMock @@ -399,6 +402,7 @@ async def _get_helper( return_empty=False, retry=None, timeout=None, + read_time=None, ): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.transaction import Transaction @@ -407,10 +411,14 @@ async def _get_helper( # Create a minimal fake GAPIC with a dummy response. create_time = 123 update_time = 234 - read_time = 345 + if read_time: + response_read_time = timestamp_pb2.Timestamp() + response_read_time.FromDatetime(read_time) + else: + response_read_time = 345 firestore_api = AsyncMock(spec=["batch_get_documents"]) response = mock.create_autospec(firestore.BatchGetDocumentsResponse) - response.read_time = 345 + response.read_time = response_read_time response.found = mock.create_autospec(document.Document) response.found.fields = {} response.found.create_time = create_time @@ -445,6 +453,7 @@ def WhichOneof(val): field_paths=field_paths, transaction=transaction, **kwargs, + read_time=read_time, ) assert snapshot.reference is document_reference @@ -457,7 +466,7 @@ def WhichOneof(val): else: assert snapshot.to_dict() == {} assert snapshot.exists - assert snapshot.read_time is read_time + assert snapshot.read_time is response_read_time assert snapshot.create_time is create_time assert snapshot.update_time is update_time @@ -472,13 +481,17 @@ def WhichOneof(val): else: expected_transaction_id = None + expected_request = { + "database": client._database_string, + "documents": [document_reference._document_path], + "mask": mask, + "transaction": expected_transaction_id, + } + if read_time is not None: + expected_request["read_time"] = read_time + firestore_api.batch_get_documents.assert_called_once_with( - request={ - "database": client._database_string, - "documents": [document_reference._document_path], - "mask": mask, - "transaction": expected_transaction_id, - }, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) @@ -530,7 +543,12 @@ async def test_asyncdocumentreference_get_with_transaction(): @pytest.mark.asyncio -async def _collections_helper(page_size=None, retry=None, timeout=None): +async def test_asyncdocumentreference_get_with_read_time(): + await _get_helper(read_time=datetime.now()) + + +@pytest.mark.asyncio +async def _collections_helper(page_size=None, retry=None, timeout=None, read_time=None): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.async_collection import AsyncCollectionReference @@ -553,10 +571,15 @@ async def __aiter__(self, **_): document = _make_async_document_reference("where", "we-are", client=client) if page_size is not None: collections = [ - c async for c in document.collections(page_size=page_size, **kwargs) + c + async for c in document.collections( + page_size=page_size, **kwargs, read_time=read_time + ) ] else: - collections = [c async for c in document.collections(**kwargs)] + collections = [ + c async for c in document.collections(**kwargs, read_time=read_time) + ] # Verify the response and the mocks. 
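
For reference, the behavior pinned down here is the async document surface forwarding a caller-supplied `read_time` into the underlying `batch_get_documents` / `list_collection_ids` requests. A minimal usage sketch, assuming an already-configured `firestore.AsyncClient` named `client` (the document path is illustrative):

    import datetime
    from google.cloud import firestore

    async def read_document_as_of(client: firestore.AsyncClient):
        read_time = datetime.datetime.now(tz=datetime.timezone.utc)
        doc = client.document("where", "we-are")
        # Both calls observe the database as of `read_time`, not the current state.
        snapshot = await doc.get(read_time=read_time)
        subcollections = [c async for c in doc.collections(read_time=read_time)]
        return snapshot, subcollections
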
    assert len(collections) == len(collection_ids)
@@ -565,8 +588,15 @@ async def __aiter__(self, **_):
         assert collection.parent == document
         assert collection.id == collection_id
 
+    expected_result = {
+        "parent": document._document_path,
+        "page_size": page_size,
+    }
+    if read_time is not None:
+        expected_result["read_time"] = read_time
+
     firestore_api.list_collection_ids.assert_called_once_with(
-        request={"parent": document._document_path, "page_size": page_size},
+        request=expected_result,
         metadata=client._rpc_metadata,
         **kwargs,
     )
@@ -586,6 +616,11 @@ async def test_asyncdocumentreference_collections_w_retry_timeout():
     await _collections_helper(retry=retry, timeout=timeout)
 
 
+@pytest.mark.asyncio
+async def test_asyncdocumentreference_collections_w_read_time():
+    await _collections_helper(read_time=datetime.now())
+
+
 @pytest.mark.asyncio
 async def test_asyncdocumentreference_collections_w_page_size():
     await _collections_helper(page_size=10)
diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py
index efc6c7df78..54c80e5ad4 100644
--- a/tests/unit/v1/test_async_query.py
+++ b/tests/unit/v1/test_async_query.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import datetime
 import types
 
 import mock
@@ -41,7 +42,7 @@ def test_asyncquery_constructor():
     assert not query._all_descendants
 
 
-async def _get_helper(retry=None, timeout=None, explain_options=None):
+async def _get_helper(retry=None, timeout=None, explain_options=None, read_time=None):
     from google.cloud.firestore_v1 import _helpers
 
     # Create a minimal fake GAPIC.
@@ -68,7 +69,9 @@ async def _get_helper(retry=None, timeout=None, explain_options=None):
 
     # Execute the query and check the response.
     query = make_async_query(parent)
-    returned = await query.get(**kwargs, explain_options=explain_options)
+    returned = await query.get(
+        **kwargs, explain_options=explain_options, read_time=read_time
+    )
 
     assert isinstance(returned, QueryResultsList)
     assert len(returned) == 1
@@ -94,6 +97,8 @@ async def _get_helper(retry=None, timeout=None, explain_options=None):
     }
     if explain_options:
         request["explain_options"] = explain_options._to_dict()
+    if read_time:
+        request["read_time"] = read_time
 
     # Verify the mock call.
     firestore_api.run_query.assert_called_once_with(
@@ -117,6 +122,12 @@ async def test_asyncquery_get_w_retry_timeout():
     await _get_helper(retry=retry, timeout=timeout)
 
 
+@pytest.mark.asyncio
+async def test_asyncquery_get_w_read_time():
+    read_time = datetime.datetime.now(tz=datetime.timezone.utc)
+    await _get_helper(read_time=read_time)
+
+
 @pytest.mark.asyncio
 async def test_asyncquery_get_limit_to_last():
     from google.cloud import firestore
@@ -336,7 +347,9 @@ async def test_asyncquery_chunkify_w_chunksize_gt_limit():
     assert [snapshot.id for snapshot in chunks[0]] == expected_ids
 
 
-async def _stream_helper(retry=None, timeout=None, explain_options=None):
+async def _stream_helper(
+    retry=None, timeout=None, explain_options=None, read_time=None
+):
     from google.cloud.firestore_v1 import _helpers
     from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
 
@@ -367,7 +380,9 @@ async def _stream_helper(retry=None, timeout=None, explain_options=None):
 
     # Execute the query and check the response.
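
The public entry points exercised above and below are `AsyncQuery.get(read_time=...)` and `AsyncQuery.stream(read_time=...)`; both simply attach the timestamp to the `run_query` request. A usage sketch, assuming a configured `firestore.AsyncClient` named `client` and an illustrative collection name:

    import datetime
    from google.cloud import firestore

    async def query_as_of(client: firestore.AsyncClient):
        read_time = datetime.datetime.now(tz=datetime.timezone.utc)
        query = client.collection("users").order_by("name")
        results = await query.get(read_time=read_time)  # one-shot fetch at read_time
        async for snap in query.stream(read_time=read_time):  # streaming variant
            print(snap.id, snap.to_dict())
        return results
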
query = make_async_query(parent) - stream_response = query.stream(**kwargs, explain_options=explain_options) + stream_response = query.stream( + **kwargs, explain_options=explain_options, read_time=read_time + ) assert isinstance(stream_response, AsyncStreamGenerator) returned = [x async for x in stream_response] @@ -395,6 +410,8 @@ async def _stream_helper(retry=None, timeout=None, explain_options=None): } if explain_options is not None: request["explain_options"] = explain_options._to_dict() + if read_time is not None: + request["read_time"] = read_time # Verify the mock call. firestore_api.run_query.assert_called_once_with( @@ -418,6 +435,12 @@ async def test_asyncquery_stream_w_retry_timeout(): await _stream_helper(retry=retry, timeout=timeout) +@pytest.mark.asyncio +async def test_asyncquery_stream_w_read_time(): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + await _stream_helper(read_time=read_time) + + @pytest.mark.asyncio async def test_asyncquery_stream_with_limit_to_last(): # Attach the fake GAPIC to a real client. @@ -481,6 +504,57 @@ async def test_asyncquery_stream_with_transaction(): ) +@pytest.mark.asyncio +async def test_asyncquery_stream_with_transaction_and_read_time(): + from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator + + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = make_async_client() + client._firestore_api_internal = firestore_api + + # Create a real-ish transaction for this client. + transaction = client.transaction() + txn_id = b"\x00\x00\x01-work-\xf2" + transaction._id = txn_id + + # Create a read_time for this client. + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + + # Make a **real** collection reference as parent. + parent = client.collection("declaration") + + # Add a dummy response to the minimal fake GAPIC. + parent_path, expected_prefix = parent._parent_info() + name = "{}/burger".format(expected_prefix) + data = {"lettuce": b"\xee\x87"} + response_pb = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = AsyncIter([response_pb]) + + # Execute the query and check the response. + query = make_async_query(parent) + get_response = query.stream(transaction=transaction, read_time=read_time) + assert isinstance(get_response, AsyncStreamGenerator) + returned = [x async for x in get_response] + assert len(returned) == 1 + snapshot = returned[0] + assert snapshot.reference._path == ("declaration", "burger") + assert snapshot.to_dict() == data + + # Verify the mock call. + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": txn_id, + "read_time": read_time, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.asyncio async def test_asyncquery_stream_no_results(): from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator @@ -718,7 +792,7 @@ def test_asynccollectiongroup_constructor_all_descendents_is_false(): @pytest.mark.asyncio -async def _get_partitions_helper(retry=None, timeout=None): +async def _get_partitions_helper(retry=None, timeout=None, read_time=None): from google.cloud.firestore_v1 import _helpers # Create a minimal fake GAPIC. @@ -743,7 +817,7 @@ async def _get_partitions_helper(retry=None, timeout=None): # Execute the query and check the response. 
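
`get_partitions` takes the same `read_time`, so every partition of a collection-group scan can target one consistent snapshot. A sketch, assuming each yielded `QueryPartition` is turned back into a runnable query via its `query()` method (the group name is illustrative):

    import datetime
    from google.cloud import firestore

    async def scan_partitions(client: firestore.AsyncClient):
        read_time = datetime.datetime.now(tz=datetime.timezone.utc)
        group = client.collection_group("declaration")
        async for partition in group.get_partitions(2, read_time=read_time):
            # Each partition query reads from the same point-in-time snapshot.
            async for snap in partition.query().stream(read_time=read_time):
                print(snap.reference.path)
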
query = _make_async_collection_group(parent) - get_response = query.get_partitions(2, **kwargs) + get_response = query.get_partitions(2, read_time=read_time, **kwargs) assert isinstance(get_response, types.AsyncGeneratorType) returned = [i async for i in get_response] @@ -755,12 +829,15 @@ async def _get_partitions_helper(retry=None, timeout=None): parent, orders=(query._make_order("__name__", query.ASCENDING),), ) + expected_request = { + "parent": parent_path, + "structured_query": partition_query._to_protobuf(), + "partition_count": 2, + } + if read_time is not None: + expected_request["read_time"] = read_time firestore_api.partition_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": partition_query._to_protobuf(), - "partition_count": 2, - }, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) @@ -780,6 +857,12 @@ async def test_asynccollectiongroup_get_partitions_w_retry_timeout(): await _get_partitions_helper(retry=retry, timeout=timeout) +@pytest.mark.asyncio +async def test_asynccollectiongroup_get_partitions_w_read_time(): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + await _get_partitions_helper(read_time=read_time) + + @pytest.mark.asyncio async def test_asynccollectiongroup_get_partitions_w_filter(): # Make a **real** collection reference as parent. diff --git a/tests/unit/v1/test_async_transaction.py b/tests/unit/v1/test_async_transaction.py index e4bb788e3d..d357e3482a 100644 --- a/tests/unit/v1/test_async_transaction.py +++ b/tests/unit/v1/test_async_transaction.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime import mock import pytest @@ -294,13 +295,15 @@ async def test_asynctransaction__commit_failure(): ) -async def _get_all_helper(retry=None, timeout=None): +async def _get_all_helper(retry=None, timeout=None, read_time=None): from google.cloud.firestore_v1 import _helpers client = AsyncMock(spec=["get_all"]) transaction = _make_async_transaction(client) ref1, ref2 = mock.Mock(), mock.Mock() kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + if read_time is not None: + kwargs["read_time"] = read_time result = await transaction.get_all([ref1, ref2], **kwargs) @@ -326,7 +329,15 @@ async def test_asynctransaction_get_all_w_retry_timeout(): await _get_all_helper(retry=retry, timeout=timeout) -async def _get_w_document_ref_helper(retry=None, timeout=None, explain_options=None): +@pytest.mark.asyncio +async def test_asynctransaction_get_all_w_read_time(): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + await _get_all_helper(read_time=read_time) + + +async def _get_w_document_ref_helper( + retry=None, timeout=None, explain_options=None, read_time=None +): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.async_document import AsyncDocumentReference @@ -335,7 +346,12 @@ async def _get_w_document_ref_helper(retry=None, timeout=None, explain_options=N ref = AsyncDocumentReference("documents", "doc-id") kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - result = await transaction.get(ref, **kwargs, explain_options=explain_options) + if explain_options is not None: + kwargs["explain_options"] = explain_options + if read_time is not None: + kwargs["read_time"] = read_time + + result = await transaction.get(ref, **kwargs) client.get_all.assert_called_once_with([ref], transaction=transaction, **kwargs) assert result is client.get_all.return_value @@ -356,7 
+372,7 @@ async def test_asynctransaction_get_w_document_ref_w_retry_timeout():
 
 @pytest.mark.asyncio
-async def test_transaction_get_w_document_ref_w_explain_options():
+async def test_asynctransaction_get_w_document_ref_w_explain_options():
     from google.cloud.firestore_v1.query_profile import ExplainOptions
 
     with pytest.raises(ValueError, match="`explain_options` cannot be provided."):
@@ -365,7 +381,16 @@
+@pytest.mark.asyncio
+async def test_asynctransaction_get_w_document_ref_w_read_time():
+    await _get_w_document_ref_helper(
+        read_time=datetime.datetime.now(tz=datetime.timezone.utc)
+    )
+
+
-async def _get_w_query_helper(retry=None, timeout=None, explain_options=None):
+async def _get_w_query_helper(
+    retry=None, timeout=None, explain_options=None, read_time=None
+):
     from google.cloud.firestore_v1 import _helpers
     from google.cloud.firestore_v1.async_query import AsyncQuery
     from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
@@ -407,6 +432,7 @@ async def _get_w_query_helper(retry=None, timeout=None, explain_options=None):
         query,
         **kwargs,
         explain_options=explain_options,
+        read_time=read_time,
     )
 
     # Verify the response.
@@ -435,6 +461,8 @@ async def _get_w_query_helper(retry=None, timeout=None, explain_options=None):
     }
     if explain_options is not None:
         request["explain_options"] = explain_options._to_dict()
+    if read_time is not None:
+        request["read_time"] = read_time
 
     # Verify the mock call.
     firestore_api.run_query.assert_called_once_with(
@@ -462,6 +490,12 @@ async def test_transaction_get_w_query_w_explain_options():
     await _get_w_query_helper(explain_options=ExplainOptions(analyze=True))
 
 
+@pytest.mark.asyncio
+async def test_asynctransaction_get_w_query_w_read_time():
+    read_time = datetime.datetime.now(tz=datetime.timezone.utc)
+    await _get_w_query_helper(read_time=read_time)
+
+
 @pytest.mark.asyncio
 async def test_asynctransaction_get_failure():
     client = _make_client()
diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py
index 7f6b0e5e2e..7804b0430f 100644
--- a/tests/unit/v1/test_base_query.py
+++ b/tests/unit/v1/test_base_query.py
@@ -1400,6 +1400,19 @@ def test_basequery__normalize_cursor_as_snapshot_hit():
     assert query._normalize_cursor(cursor, query._orders) == ([1], True)
 
 
+def test_basequery__normalize_cursor_non_existent_snapshot():
+    from google.cloud.firestore_v1 import document
+
+    values = {"b": 1}
+    docref = _make_docref("here", "doc_id")
+    snapshot = document.DocumentSnapshot(docref, values, False, None, None, None)
+    cursor = (snapshot, True)
+    collection = _make_collection("here")
+    query = _make_base_query(collection).order_by("b", "ASCENDING")
+
+    assert query._normalize_cursor(cursor, query._orders) == ([1], True)
+
+
 def test_basequery__normalize_cursor_w___name___w_reference():
     db_string = "projects/my-project/database/(default)"
     client = mock.Mock(spec=["_database_string"])
diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py
index edb411c9ff..df3ae15b41 100644
--- a/tests/unit/v1/test_client.py
+++ b/tests/unit/v1/test_client.py
@@ -281,7 +281,7 @@ def test_client_document_factory_w_nested_path(database):
     assert isinstance(document2, DocumentReference)
 
 
-def _collections_helper(retry=None, timeout=None, database=None):
+def _collections_helper(retry=None, timeout=None, database=None, read_time=None):
     from google.cloud.firestore_v1 import _helpers
     from google.cloud.firestore_v1.collection import CollectionReference
 
@@
-298,7 +298,7 @@ def __iter__(self): client._firestore_api_internal = firestore_api kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - collections = list(client.collections(**kwargs)) + collections = list(client.collections(read_time=read_time, **kwargs)) assert len(collections) == len(collection_ids) for collection, collection_id in zip(collections, collection_ids): @@ -307,8 +307,13 @@ def __iter__(self): assert collection.id == collection_id base_path = client._database_string + "/documents" + expected_request = { + "parent": base_path, + } + if read_time is not None: + expected_request["read_time"] = read_time firestore_api.list_collection_ids.assert_called_once_with( - request={"parent": base_path}, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) @@ -328,6 +333,12 @@ def test_client_collections_w_retry_timeout(database): _collections_helper(retry=retry, timeout=timeout, database=database) +@pytest.mark.parametrize("database", [None, DEFAULT_DATABASE, "somedb"]) +def test_client_collections_read_time(database): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + _collections_helper(database=database, read_time=read_time) + + def _invoke_get_all(client, references, document_pbs, **kwargs): # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["batch_get_documents"]) @@ -345,7 +356,12 @@ def _invoke_get_all(client, references, document_pbs, **kwargs): def _get_all_helper( - num_snapshots=2, txn_id=None, retry=None, timeout=None, database=None + num_snapshots=2, + txn_id=None, + retry=None, + timeout=None, + database=None, + read_time=None, ): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.async_document import DocumentSnapshot @@ -355,13 +371,13 @@ def _get_all_helper( data1 = {"a": "cheese"} document1 = client.document("pineapple", "lamp1") - document_pb1, read_time = _doc_get_info(document1._document_path, data1) - response1 = _make_batch_response(found=document_pb1, read_time=read_time) + document_pb1, doc_read_time = _doc_get_info(document1._document_path, data1) + response1 = _make_batch_response(found=document_pb1, read_time=doc_read_time) data2 = {"b": True, "c": 18} document2 = client.document("pineapple", "lamp2") - document, read_time = _doc_get_info(document2._document_path, data2) - response2 = _make_batch_response(found=document, read_time=read_time) + document, doc_read_time = _doc_get_info(document2._document_path, data2) + response2 = _make_batch_response(found=document, read_time=doc_read_time) document3 = client.document("pineapple", "lamp3") response3 = _make_batch_response(missing=document3._document_path) @@ -384,6 +400,7 @@ def _get_all_helper( documents, responses, field_paths=field_paths, + read_time=read_time, **kwargs, ) @@ -402,14 +419,17 @@ def _get_all_helper( mask = common.DocumentMask(field_paths=field_paths) kwargs.pop("transaction", None) + expected_request = { + "database": client._database_string, + "documents": doc_paths, + "mask": mask, + "transaction": txn_id, + } + if read_time is not None: + expected_request["read_time"] = read_time client._firestore_api.batch_get_documents.assert_called_once_with( - request={ - "database": client._database_string, - "documents": doc_paths, - "mask": mask, - "transaction": txn_id, - }, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) @@ -440,6 +460,12 @@ def test_client_get_all_wrong_order(database): _get_all_helper(num_snapshots=3, database=database) 
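
On the synchronous client the same parameter threads through `Client.collections()` and `Client.get_all()`, as the helpers above assert. A usage sketch (the document paths mirror the fixtures and are illustrative):

    import datetime
    from google.cloud import firestore

    def batch_read_as_of(client: firestore.Client):
        read_time = datetime.datetime.now(tz=datetime.timezone.utc)
        refs = [
            client.document("pineapple", "lamp1"),
            client.document("pineapple", "lamp2"),
        ]
        snapshots = list(client.get_all(refs, read_time=read_time))
        top_level = list(client.collections(read_time=read_time))
        return snapshots, top_level
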
+@pytest.mark.parametrize("database", [None, DEFAULT_DATABASE, "somedb"]) +def test_client_get_all_read_time(database): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + _get_all_helper(database=database, read_time=read_time) + + @pytest.mark.parametrize("database", [None, DEFAULT_DATABASE, "somedb"]) def test_client_get_all_unknown_result(database): from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py index 29f76108d1..da91651b95 100644 --- a/tests/unit/v1/test_collection.py +++ b/tests/unit/v1/test_collection.py @@ -16,6 +16,7 @@ import mock +from datetime import datetime, timezone from tests.unit.v1._test_helpers import DEFAULT_TEST_PROJECT @@ -266,7 +267,7 @@ def test_add_w_retry_timeout(): _add_helper(retry=retry, timeout=timeout) -def _list_documents_helper(page_size=None, retry=None, timeout=None): +def _list_documents_helper(page_size=None, retry=None, timeout=None, read_time=None): from google.api_core.page_iterator import Iterator, Page from google.cloud.firestore_v1 import _helpers as _fs_v1_helpers @@ -299,9 +300,15 @@ def _next_page(self): kwargs = _fs_v1_helpers.make_retry_timeout_kwargs(retry, timeout) if page_size is not None: - documents = list(collection.list_documents(page_size=page_size, **kwargs)) + documents = list( + collection.list_documents( + page_size=page_size, + **kwargs, + read_time=read_time, + ) + ) else: - documents = list(collection.list_documents(**kwargs)) + documents = list(collection.list_documents(**kwargs, read_time=read_time)) # Verify the response and the mocks. assert len(documents) == len(document_ids) @@ -311,14 +318,18 @@ def _next_page(self): assert document.id == document_id parent, _ = collection._parent_info() + expected_request = { + "parent": parent, + "collection_id": collection.id, + "page_size": page_size, + "show_missing": True, + "mask": {"field_paths": None}, + } + if read_time is not None: + expected_request["read_time"] = read_time + api_client.list_documents.assert_called_once_with( - request={ - "parent": parent, - "collection_id": collection.id, - "page_size": page_size, - "show_missing": True, - "mask": {"field_paths": None}, - }, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) @@ -340,6 +351,10 @@ def test_list_documents_w_page_size(): _list_documents_helper(page_size=25) +def test_list_documents_w_read_time(): + _list_documents_helper(read_time=datetime.now()) + + @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) def test_get(query_class): collection = _make_collection_reference("collection") @@ -403,6 +418,22 @@ def test_get_w_explain_options(query_class): ) +@mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) +def test_get_w_read_time(query_class): + read_time = datetime.now(tz=timezone.utc) + collection = _make_collection_reference("collection") + get_response = collection.get(read_time=read_time) + + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + + assert get_response is query_instance.get.return_value + query_instance.get.assert_called_once_with( + transaction=None, + read_time=read_time, + ) + + @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) def test_stream(query_class): collection = _make_collection_reference("collection") @@ -463,6 +494,22 @@ def test_stream_w_explain_options(query_class): ) +@mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) +def 
test_stream_w_read_time(query_class): + read_time = datetime.now(tz=timezone.utc) + collection = _make_collection_reference("collection") + get_response = collection.stream(read_time=read_time) + + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + + assert get_response is query_instance.stream.return_value + query_instance.stream.assert_called_once_with( + transaction=None, + read_time=read_time, + ) + + @mock.patch("google.cloud.firestore_v1.collection.Watch", autospec=True) def test_on_snapshot(watch): collection = _make_collection_reference("collection") diff --git a/tests/unit/v1/test_document.py b/tests/unit/v1/test_document.py index b9116ae61d..3a2a3701e0 100644 --- a/tests/unit/v1/test_document.py +++ b/tests/unit/v1/test_document.py @@ -16,6 +16,9 @@ import mock import pytest +from datetime import datetime + +from google.protobuf import timestamp_pb2 from tests.unit.v1._test_helpers import DEFAULT_TEST_PROJECT @@ -393,6 +396,7 @@ def _get_helper( retry=None, timeout=None, database=None, + read_time=None, ): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.transaction import Transaction @@ -401,10 +405,14 @@ def _get_helper( # Create a minimal fake GAPIC with a dummy response. create_time = 123 update_time = 234 - read_time = 345 + if read_time: + response_read_time = timestamp_pb2.Timestamp() + response_read_time.FromDatetime(read_time) + else: + response_read_time = 345 firestore_api = mock.Mock(spec=["batch_get_documents"]) response = mock.create_autospec(firestore.BatchGetDocumentsResponse) - response.read_time = read_time + response.read_time = response_read_time response.found = mock.create_autospec(document.Document) response.found.fields = {} response.found.create_time = create_time @@ -435,7 +443,10 @@ def WhichOneof(val): kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) snapshot = document_reference.get( - field_paths=field_paths, transaction=transaction, **kwargs + field_paths=field_paths, + transaction=transaction, + **kwargs, + read_time=read_time, ) assert snapshot.reference is document_reference @@ -448,7 +459,7 @@ def WhichOneof(val): else: assert snapshot.to_dict() == {} assert snapshot.exists - assert snapshot.read_time is read_time + assert snapshot.read_time is response_read_time assert snapshot.create_time is create_time assert snapshot.update_time is update_time @@ -463,13 +474,17 @@ def WhichOneof(val): else: expected_transaction_id = None + expected_request = { + "database": client._database_string, + "documents": [document_reference._document_path], + "mask": mask, + "transaction": expected_transaction_id, + } + if read_time is not None: + expected_request["read_time"] = read_time + firestore_api.batch_get_documents.assert_called_once_with( - request={ - "database": client._database_string, - "documents": [document_reference._document_path], - "mask": mask, - "transaction": expected_transaction_id, - }, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) @@ -520,7 +535,14 @@ def test_documentreference_get_with_transaction(database): _get_helper(use_transaction=True, database=database) -def _collections_helper(page_size=None, retry=None, timeout=None, database=None): +@pytest.mark.parametrize("database", [None, "somedb"]) +def test_documentreference_get_with_read_time(database): + _get_helper(read_time=datetime.now(), database=database) + + +def _collections_helper( + page_size=None, retry=None, timeout=None, read_time=None, database=None +): from 
google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.collection import CollectionReference from google.cloud.firestore_v1.services.firestore.client import FirestoreClient @@ -541,9 +563,11 @@ def __iter__(self): # Actually make a document and call delete(). document = _make_document_reference("where", "we-are", client=client) if page_size is not None: - collections = list(document.collections(page_size=page_size, **kwargs)) + collections = list( + document.collections(page_size=page_size, **kwargs, read_time=read_time) + ) else: - collections = list(document.collections(**kwargs)) + collections = list(document.collections(**kwargs, read_time=read_time)) # Verify the response and the mocks. assert len(collections) == len(collection_ids) @@ -552,8 +576,15 @@ def __iter__(self): assert collection.parent == document assert collection.id == collection_id + expected_result = { + "parent": document._document_path, + "page_size": page_size, + } + if read_time is not None: + expected_result["read_time"] = read_time + api_client.list_collection_ids.assert_called_once_with( - request={"parent": document._document_path, "page_size": page_size}, + request=expected_result, metadata=client._rpc_metadata, **kwargs, ) @@ -578,6 +609,11 @@ def test_documentreference_collections_w_retry_timeout(database): _collections_helper(retry=retry, timeout=timeout, database=database) +@pytest.mark.parametrize("database", [None, "somedb"]) +def test_documentreference_collections_w_read_time(database): + _collections_helper(read_time=datetime.now(), database=database) + + @mock.patch("google.cloud.firestore_v1.document.Watch", autospec=True) def test_documentreference_on_snapshot(watch): client = mock.Mock(_database_string="sprinklez", spec=["_database_string"]) diff --git a/tests/unit/v1/test_query.py b/tests/unit/v1/test_query.py index f30a4fcdff..b8c37cf848 100644 --- a/tests/unit/v1/test_query.py +++ b/tests/unit/v1/test_query.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime import types import mock @@ -42,6 +43,7 @@ def _query_get_helper( timeout=None, database=None, explain_options=None, + read_time=None, ): from google.cloud.firestore_v1 import _helpers @@ -71,7 +73,7 @@ def _query_get_helper( # Execute the query and check the response. query = make_query(parent) - returned = query.get(**kwargs, explain_options=explain_options) + returned = query.get(**kwargs, explain_options=explain_options, read_time=read_time) assert isinstance(returned, QueryResultsList) assert len(returned) == 1 @@ -97,6 +99,8 @@ def _query_get_helper( } if explain_options: request["explain_options"] = explain_options._to_dict() + if read_time: + request["read_time"] = read_time # Verify the mock call. 
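
As on the async side, the synchronous `Query.get` and `Query.stream` only add `read_time` to the `run_query` request when the caller passes one. Sketch, assuming a configured `firestore.Client` named `client`:

    import datetime
    from google.cloud import firestore

    def query_history(client: firestore.Client):
        read_time = datetime.datetime.now(tz=datetime.timezone.utc)
        query = client.collection("users").order_by("name")
        results = query.get(read_time=read_time)  # one-shot fetch at read_time
        for snap in query.stream(read_time=read_time):
            print(snap.id, snap.to_dict())
        return results
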
firestore_api.run_query.assert_called_once_with( @@ -118,6 +122,11 @@ def test_query_get_w_retry_timeout(): _query_get_helper(retry=retry, timeout=timeout) +def test_query_get_w_read_time(): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + _query_get_helper(read_time=read_time) + + @pytest.mark.parametrize("database", [None, "somedb"]) def test_query_get_limit_to_last(database): from google.cloud import firestore @@ -338,6 +347,7 @@ def _query_stream_helper( timeout=None, database=None, explain_options=None, + read_time=None, ): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.stream_generator import StreamGenerator @@ -369,7 +379,9 @@ def _query_stream_helper( # Execute the query and check the response. query = make_query(parent) - get_response = query.stream(**kwargs, explain_options=explain_options) + get_response = query.stream( + **kwargs, explain_options=explain_options, read_time=read_time + ) assert isinstance(get_response, StreamGenerator) returned = list(get_response) @@ -396,6 +408,8 @@ def _query_stream_helper( } if explain_options is not None: request["explain_options"] = explain_options._to_dict() + if read_time is not None: + request["read_time"] = read_time # Verify the mock call. firestore_api.run_query.assert_called_once_with( @@ -417,6 +431,11 @@ def test_query_stream_w_retry_timeout(): _query_stream_helper(retry=retry, timeout=timeout) +def test_query_stream_w_read_time(): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + _query_stream_helper(read_time=read_time) + + @pytest.mark.parametrize("database", [None, "somedb"]) def test_query_stream_with_limit_to_last(database): # Attach the fake GAPIC to a real client. @@ -480,6 +499,57 @@ def test_query_stream_with_transaction(database): ) +@pytest.mark.parametrize("database", [None, "somedb"]) +def test_query_stream_with_transaction_and_read_time(database): + from google.cloud.firestore_v1.stream_generator import StreamGenerator + + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = make_client(database=database) + client._firestore_api_internal = firestore_api + + # Create a real-ish transaction for this client. + transaction = client.transaction() + txn_id = b"\x00\x00\x01-work-\xf2" + transaction._id = txn_id + + # Create a read_time for this client. + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + + # Make a **real** collection reference as parent. + parent = client.collection("declaration") + + # Add a dummy response to the minimal fake GAPIC. + parent_path, expected_prefix = parent._parent_info() + name = "{}/burger".format(expected_prefix) + data = {"lettuce": b"\xee\x87"} + response_pb = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = iter([response_pb]) + + # Execute the query and check the response. + query = make_query(parent) + get_response = query.stream(transaction=transaction, read_time=read_time) + assert isinstance(get_response, StreamGenerator) + returned = list(get_response) + assert len(returned) == 1 + snapshot = returned[0] + assert snapshot.reference._path == ("declaration", "burger") + assert snapshot.to_dict() == data + + # Verify the mock call. 
+ firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": txn_id, + "read_time": read_time, + }, + metadata=client._rpc_metadata, + ) + + @pytest.mark.parametrize("database", [None, "somedb"]) def test_query_stream_no_results(database): from google.cloud.firestore_v1.stream_generator import StreamGenerator @@ -690,7 +760,12 @@ def test_query_stream_w_collection_group(database): def _query_stream_w_retriable_exc_helper( - retry=_not_passed, timeout=None, transaction=None, expect_retry=True, database=None + retry=_not_passed, + timeout=None, + transaction=None, + expect_retry=True, + database=None, + read_time=None, ): from google.api_core import exceptions, gapic_v1 @@ -734,7 +809,7 @@ def _stream_w_exception(*_args, **_kw): # Execute the query and check the response. query = make_query(parent) - get_response = query.stream(transaction=transaction, **kwargs) + get_response = query.stream(transaction=transaction, read_time=read_time, **kwargs) assert isinstance(get_response, StreamGenerator) if expect_retry: @@ -763,24 +838,31 @@ def _stream_w_exception(*_args, **_kw): else: expected_transaction_id = None + expected_request = { + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": expected_transaction_id, + } + if read_time is not None: + expected_request["read_time"] = read_time + assert calls[0] == mock.call( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "transaction": expected_transaction_id, - }, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) if expect_retry: new_query = query.start_after(snapshot) + expected_request = { + "parent": parent_path, + "structured_query": new_query._to_protobuf(), + "transaction": None, + } + if read_time is not None: + expected_request["read_time"] = read_time assert calls[1] == mock.call( - request={ - "parent": parent_path, - "structured_query": new_query._to_protobuf(), - "transaction": None, - }, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) @@ -804,6 +886,11 @@ def test_query_stream_w_retriable_exc_w_transaction(): _query_stream_w_retriable_exc_helper(transaction=txn) +def test_query_stream_w_retriable_exc_w_read_time(): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + _query_stream_w_retriable_exc_helper(read_time=read_time) + + def test_query_stream_w_explain_options(): from google.cloud.firestore_v1.query_profile import ExplainOptions @@ -842,7 +929,9 @@ def test_collection_group_constructor_all_descendents_is_false(): _make_collection_group(mock.sentinel.parent, all_descendants=False) -def _collection_group_get_partitions_helper(retry=None, timeout=None, database=None): +def _collection_group_get_partitions_helper( + retry=None, timeout=None, database=None, read_time=None +): from google.cloud.firestore_v1 import _helpers # Create a minimal fake GAPIC. @@ -868,7 +957,7 @@ def _collection_group_get_partitions_helper(retry=None, timeout=None, database=N # Execute the query and check the response. 
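
The transaction-plus-`read_time` case verified above shows both fields can ride in a single `run_query` request. In application code the pairing looks like this sketch (`firestore.transactional` reruns `work` on contention; all names are illustrative):

    import datetime
    from google.cloud import firestore

    def read_in_transaction(client: firestore.Client):
        read_time = datetime.datetime.now(tz=datetime.timezone.utc)

        @firestore.transactional
        def work(txn):
            query = client.collection("declaration")
            return [
                snap.to_dict()
                for snap in query.stream(transaction=txn, read_time=read_time)
            ]

        return work(client.transaction())
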
query = _make_collection_group(parent) - get_response = query.get_partitions(2, **kwargs) + get_response = query.get_partitions(2, read_time=read_time, **kwargs) assert isinstance(get_response, types.GeneratorType) returned = list(get_response) @@ -880,12 +969,15 @@ def _collection_group_get_partitions_helper(retry=None, timeout=None, database=N parent, orders=(query._make_order("__name__", query.ASCENDING),), ) + expected_request = { + "parent": parent_path, + "structured_query": partition_query._to_protobuf(), + "partition_count": 2, + } + if read_time is not None: + expected_request["read_time"] = read_time firestore_api.partition_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": partition_query._to_protobuf(), - "partition_count": 2, - }, + request=expected_request, metadata=client._rpc_metadata, **kwargs, ) @@ -903,6 +995,11 @@ def test_collection_group_get_partitions_w_retry_timeout(): _collection_group_get_partitions_helper(retry=retry, timeout=timeout) +def test_collection_group_get_partitions_w_read_time(): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + _collection_group_get_partitions_helper(read_time=read_time) + + @pytest.mark.parametrize("database", [None, "somedb"]) def test_collection_group_get_partitions_w_filter(database): # Make a **real** collection reference as parent. diff --git a/tests/unit/v1/test_transaction.py b/tests/unit/v1/test_transaction.py index 941e294dbd..2fe215abc9 100644 --- a/tests/unit/v1/test_transaction.py +++ b/tests/unit/v1/test_transaction.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime import mock import pytest @@ -312,13 +313,15 @@ def test_transaction__commit_failure(database): ) -def _transaction_get_all_helper(retry=None, timeout=None): +def _transaction_get_all_helper(retry=None, timeout=None, read_time=None): from google.cloud.firestore_v1 import _helpers client = mock.Mock(spec=["get_all"]) transaction = _make_transaction(client) ref1, ref2 = mock.Mock(), mock.Mock() kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + if read_time is not None: + kwargs["read_time"] = read_time result = transaction.get_all([ref1, ref2], **kwargs) @@ -342,10 +345,16 @@ def test_transaction_get_all_w_retry_timeout(): _transaction_get_all_helper(retry=retry, timeout=timeout) +def test_transaction_get_all_w_read_time(): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + _transaction_get_all_helper(read_time=read_time) + + def _transaction_get_w_document_ref_helper( retry=None, timeout=None, explain_options=None, + read_time=None, ): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.document import DocumentReference @@ -357,6 +366,8 @@ def _transaction_get_w_document_ref_helper( if explain_options is not None: kwargs["explain_options"] = explain_options + if read_time is not None: + kwargs["read_time"] = read_time result = transaction.get(ref, **kwargs) @@ -388,10 +399,17 @@ def test_transaction_get_w_document_ref_w_explain_options(): ) +def test_transaction_get_w_document_ref_w_read_time(): + _transaction_get_w_document_ref_helper( + read_time=datetime.datetime.now(tz=datetime.timezone.utc) + ) + + def _transaction_get_w_query_helper( retry=None, timeout=None, explain_options=None, + read_time=None, ): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.query import Query @@ -434,6 +452,7 @@ def _transaction_get_w_query_helper( query, 
**kwargs, explain_options=explain_options, + read_time=read_time, ) # Verify the response. @@ -462,6 +481,8 @@ def _transaction_get_w_query_helper( } if explain_options is not None: request["explain_options"] = explain_options._to_dict() + if read_time is not None: + request["read_time"] = read_time # Verify the mock call. firestore_api.run_query.assert_called_once_with( @@ -489,6 +510,11 @@ def test_transaction_get_w_query_w_explain_options(): _transaction_get_w_query_helper(explain_options=ExplainOptions(analyze=True)) +def test_transaction_get_w_query_w_read_time(): + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + _transaction_get_w_query_helper(read_time=read_time) + + @pytest.mark.parametrize("database", [None, "somedb"]) def test_transaction_get_failure(database): client = _make_client(database=database) diff --git a/tests/unit/v1/test_watch.py b/tests/unit/v1/test_watch.py index 6d8c12abc0..63e2233a4f 100644 --- a/tests/unit/v1/test_watch.py +++ b/tests/unit/v1/test_watch.py @@ -322,6 +322,15 @@ def test_watch_close(): assert inst._closed +def test_watch_close_w_empty_attrs(): + inst = _make_watch() + inst._consumer = None + inst._rpc = None + inst.close() + assert inst._consumer is None + assert inst._rpc is None + + def test_watch__get_rpc_request_wo_resume_token(): inst = _make_watch()
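
Rounding out the surface, `Transaction.get` and `Transaction.get_all` forward `read_time` to the underlying lookup, as the assertions above require. A closing sketch (paths are illustrative; `Transaction.get` on a document reference yields snapshots, so it is materialized with `list`):

    import datetime
    from google.cloud import firestore

    def transactional_lookup(client: firestore.Client):
        read_time = datetime.datetime.now(tz=datetime.timezone.utc)

        @firestore.transactional
        def work(txn):
            ref = client.document("documents", "doc-id")
            one = list(txn.get(ref, read_time=read_time))
            many = list(txn.get_all([ref], read_time=read_time))
            return one, many

        return work(client.transaction())
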